fennol.md.dynamic

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5
  6import numpy as np
  7from collections import defaultdict
  8import jax
  9import jax.numpy as jnp
 10import pickle
 11import yaml
 12
 13from ..utils.io import (
 14    write_arc_frame,
 15    write_xyz_frame,
 16    write_extxyz_frame,
 17    human_time_duration,
 18)
 19from .utils import wrapbox, save_dynamics_restart,us
 20from ..utils import minmaxone,read_tinker_interval
 21from ..utils.input_parser import parse_input,convert_dict_units, InputFile
 22from .integrate import initialize_dynamics
 23
 24
 25def main():
 26    """
 27    Main entry point for the fennol_md command-line interface.
 28    
 29    Parses command-line arguments and runs a molecular dynamics simulation
 30    based on the provided parameter file.
 31    
 32    Command-line Usage:
 33        fennol_md input.fnl
 34        fennol_md config.yaml
 35    
 36    Returns:
 37        int: Exit code (0 for success)
 38    """
 39    # os.environ["OMP_NUM_THREADS"] = "1"
 40    sys.stdout = io.TextIOWrapper(
 41        open(sys.stdout.fileno(), "wb", 0), write_through=True
 42    )
 43    
 44    ### Read the parameter file
 45    parser = argparse.ArgumentParser(prog="fennol_md")
 46    parser.add_argument("param_file", type=Path, help="Parameter file")
 47    args = parser.parse_args()
 48    param_file = args.param_file
 49
 50    return config_and_run_dynamic(param_file)
 51
 52def config_and_run_dynamic(param_file: Path):
 53    """
 54    Configure and run a molecular dynamics simulation.
 55    
 56    This function loads simulation parameters from a configuration file,
 57    sets up the computation device and precision, and runs the MD simulation.
 58    
 59    Parameters:
 60        param_file (Path): Path to the parameter file (.fnl, .yaml, or .yml)
 61    
 62    Returns:
 63        int: Exit code (0 for success)
 64        
 65    Raises:
 66        FileNotFoundError: If the parameter file doesn't exist
 67        ValueError: If the parameter file format is unsupported
 68        
 69    Supported file formats:
 70        - .fnl: FeNNol native format
 71        - .yaml/.yml: YAML format
 72    
 73    Internal units are specified by UnitSystem(L="ANGSTROM", T="PS", E="KCALPERMOL")
 74
 75    Unit conversion in the parameter file:
 76        - Units specified in brackets: dt[fs] = 0.5
 77        - All specified units converted to internal units
 78        - non-specified units are assumed to be in internal units
 79    """
 80
 81    if not param_file.exists() and not param_file.is_file():
 82        raise FileNotFoundError(f"Parameter file {param_file} not found")
 83
 84    if param_file.suffix in [".yaml", ".yml"]:
 85        with open(param_file, "r") as f:
 86            simulation_parameters = convert_dict_units(yaml.safe_load(f),us=us)
 87            simulation_parameters = InputFile(**simulation_parameters)
 88    elif param_file.suffix == ".fnl":
 89        simulation_parameters = parse_input(param_file,us=us)
 90    else:
 91        raise ValueError(
 92            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
 93        )
 94
 95    ### Set the device
 96    if "FENNOL_DEVICE" in os.environ:
 97        device = os.environ["FENNOL_DEVICE"].lower()
 98        print(f"# Setting device from env FENNOL_DEVICE={device}")
 99    else:
100        device = simulation_parameters.get("device", "cpu").lower()
101        """@keyword[fennol_md] device
102        Computation device. Options: 'cpu', 'cuda:N', 'gpu:N' where N is device number.
103        Default: 'cpu'
104        """
105    if device == "cpu":
106        jax.config.update('jax_platforms', 'cpu')
107        os.environ["CUDA_VISIBLE_DEVICES"] = ""
108    elif device.startswith("cuda") or device.startswith("gpu"):
109        if ":" in device:
110            num = device.split(":")[-1]
111            os.environ["CUDA_VISIBLE_DEVICES"] = num
112        else:
113            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
114        device = "gpu"
115
116    _device = jax.devices(device)[0]
117    jax.config.update("jax_default_device", _device)
118
119    ### Set the precision
120    enable_x64 = simulation_parameters.get("double_precision", False)
121    """@keyword[fennol_md] double_precision
122    Enable double precision (64-bit) calculations. Default is single precision (32-bit).
123    Default: False
124    """
125    jax.config.update("jax_enable_x64", enable_x64)
126    fprec = "float64" if enable_x64 else "float32"
127
128    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
129    """@keyword[fennol_md] matmul_prec
130    Matrix multiplication precision. Options: 'default', 'high', 'highest'.
131    Default: "highest"
132    """
133    assert matmul_precision in [
134        "default",
135        "high",
136        "highest",
137    ], "matmul_prec must be one of 'default','high','highest'"
138    if matmul_precision != "highest":
139        print(f"# Setting matmul precision to '{matmul_precision}'")
140    if matmul_precision == "default" and fprec == "float32":
141        print(
142            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
143        )
144    jax.config.update("jax_default_matmul_precision", matmul_precision)
145
146    # with jax.default_device(_device):
147    return dynamic(simulation_parameters, device, fprec)
148
149
def dynamic(simulation_parameters, device, fprec):
    """
    Execute the molecular dynamics simulation loop.

    This function performs the main MD simulation using the initialized
    system, integrator, and thermostats/barostats. It handles trajectory
    output, property monitoring, and restart file generation.

    Output cadence:
        - every ``nprint`` steps: one line of properties is printed (and the
          colvars / ``save_keys`` files are appended to, when enabled);
        - every ``ndump`` steps (``ndump = int(tdump / dt)``): a restart file
          and a trajectory frame are written, after optional box wrapping;
        - every ``nsummary`` steps: the state is checked for NaNs and the
          averages of the printed properties are reported.

    All output files opened here are closed when the function returns or
    raises (e.g. when the dynamics crash).

    Parameters:
        simulation_parameters: Parsed simulation parameters
        device (str): Computation device ('cpu' or 'gpu')
        fprec (str): Floating point precision ('float32' or 'float64')

    Returns:
        int: Exit code (0 for success)

    Raises:
        ValueError: If ``nsteps`` is missing, if ``tdump`` is shorter than the
            time step, if ``nprint``/``nsummary``/``wrap_groups``/``traj_format``
            are invalid, or if NaNs appear in velocities or coordinates.
    """
    import contextlib
    import itertools

    tstart_dyn = time.time()

    # dtype=np.int64: the default integer type is int32 on some platforms
    # (e.g. Windows with NumPy < 2), where 2**32 - 1 would be out of bounds.
    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1, dtype=np.int64)
    )
    """@keyword[fennol_md] random_seed
    Random seed for reproducible simulations. If not specified, a random seed is generated.
    Default: Random integer between 0 and 2^32-1
    """
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## get number of steps (validated before the costly initialization)
    nsteps = simulation_parameters.get("nsteps")
    """@keyword[fennol_md] nsteps
    Total number of MD steps to perform. Required parameter.
    Type: int, Required
    """
    if nsteps is None:
        raise ValueError("nsteps must be specified in the parameter file")
    nsteps = int(nsteps)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / us.PS)
    """@keyword[fennol_md] tdump
    Time interval between trajectory frames.
    Default: 1.0 ps
    """
    ndump = int(Tdump / dt)
    # ndump == 0 would make `istep % ndump` raise ZeroDivisionError in the loop.
    if ndump < 1:
        raise ValueError(
            f"tdump ({Tdump * us.PS} ps) must not be shorter than the time step (dt = {dt * us.PS} ps)"
        )
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        """@keyword[fennol_md] wrap_box
        Wrap coordinates into primary unit cell.
        Default: False
        """
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            """@keyword[fennol_md] wrap_groups
            Specific atom groups to wrap independently. Dictionary mapping group names to atom indices.
            Default: None
            """
            if wrap_groups_def is not None and not isinstance(wrap_groups_def, dict):
                raise ValueError("wrap_groups must be a dictionary")
            # An empty dict means "no groups": wrap all atoms together
            # (np.concatenate below would fail on an empty list).
            if wrap_groups_def:
                groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # wrap groups must be pairwise disjoint
                for (i_key, i_atoms), (j_key, j_atoms) in itertools.combinations(
                    groups.items(), 2
                ):
                    common = set(i_atoms).intersection(set(j_atoms))
                    if common:
                        raise ValueError(
                            f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                        )
                group_all = np.concatenate(list(groups.values()))
                # all atoms that are not in any wrap group form their own group
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {groups}")
                groups["__other"] = group_none
                # NOTE(review): this is a one-shot generator (hashable by
                # identity). If wrapbox is jitted with wrap_groups as a static
                # argument, it is consumed once at trace time and then served
                # from the jit cache; if wrapbox iterates it on every call, only
                # the first wrap sees the groups. Confirm against wrapbox before
                # turning it into a tuple (np.ndarray values are not hashable).
                wrap_groups = ((k, v) for k, v in groups.items())
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    """@keyword[fennol_md] per_atom_energy
    Print energies per atom instead of total energies.
    Default: True
    """
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        atom_energy_unit /= nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    """@keyword[fennol_md] nprint
    Number of steps between energy/property printing.
    Default: 10
    """
    if nprint <= 0:
        raise ValueError("nprint must be > 0")
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    """@keyword[fennol_md] nsummary
    Number of steps between summary statistics.
    Default: 100 * nprint
    """
    if nsummary <= nprint:
        raise ValueError("nsummary must be > nprint")

    # Every file opened below is registered on this stack so that it is
    # closed (and flushed) even if the dynamics crash mid-run.
    with contextlib.ExitStack() as open_files:
        save_keys = simulation_parameters.get("save_keys", [])
        """@keyword[fennol_md] save_keys
        Additional model output keys to save to trajectory.
        Default: []
        """
        fkeys = None
        if save_keys:
            print(f"# Saving keys: {save_keys}")
            fkeys = open_files.enter_context(open(f"{system_name}.traj.pkl", "wb+"))

        ### initialize colvars
        use_colvars = "colvars" in dyn_state
        if use_colvars:
            print(f"# Colvars: {dyn_state['colvars']}")
            colvars_names = dyn_state["colvars"]
            # open colvars file and print header
            fcolvars = open_files.enter_context(
                open(f"{system_name}.colvars.traj", "a")
            )
            fcolvars.write("#time[ps] ")
            fcolvars.write(" ".join(colvars_names))
            fcolvars.write("\n")
            fcolvars.flush()

        ### Print header
        include_thermostat_energy = "thermostat_energy" in system["thermostat"]
        thermostat_name = dyn_state["thermostat_name"]
        pimd = dyn_state["pimd"]
        variable_cell = dyn_state["variable_cell"]
        nbeads = system_data.get("nbeads", 1)
        dyn_name = "PIMD" if pimd else "MD"
        print("#" * 84)
        print(
            f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
        )
        header = f"#{'Step':>9} {'Time[ps]':>10} {'Etot':>10} {'Epot':>10} {'Ekin':>10} {'Temp[K]':>10}"
        if pimd:
            header += f" {'Temp_c[K]':>10}"
        if include_thermostat_energy:
            header += f" {'Etherm':>10}"
        # Defined unconditionally: it is also read in the variable_cell branch,
        # which would otherwise raise NameError when pressure is not estimated.
        print_aniso_pressure = False
        if estimate_pressure:
            print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
            """@keyword[fennol_md] print_aniso_pressure
            Print anisotropic pressure tensor components.
            Default: False
            """
            pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
            """@keyword[fennol_md] pressure_unit
            Pressure unit for output. Options: 'atm', 'bar', 'Pa', 'GPa'.
            Default: "atm"
            """
            pressure_unit = us.get_multiplier(pressure_unit_str)
            p_str = f"  Press[{pressure_unit_str}]"
            header += f" {p_str:>10}"
        if variable_cell:
            header += f" {'Density':>10}"
        print(header)

        ### Open trajectory file
        traj_format = simulation_parameters.get("traj_format", "arc").lower()
        """@keyword[fennol_md] traj_format
        Trajectory file format. Options: 'arc' (Tinker), 'xyz' (standard), 'extxyz' (extended).
        Default: "arc"
        """
        if traj_format == "xyz":
            traj_ext = ".traj.xyz"
            write_frame = write_xyz_frame
        elif traj_format == "extxyz":
            traj_ext = ".traj.extxyz"
            write_frame = write_extxyz_frame
        elif traj_format == "arc":
            traj_ext = ".arc"
            write_frame = write_arc_frame
        else:
            raise ValueError(
                f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
            )

        write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
        """@keyword[fennol_md] write_all_beads
        Write all PIMD beads to separate trajectory files.
        Default: False
        """

        if write_all_beads:
            # one trajectory file per bead
            fout = [
                open_files.enter_context(
                    open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a")
                )
                for i in range(nbeads)
            ]
        else:
            fout = open_files.enter_context(open(system_name + traj_ext, "a"))

        ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
        """@keyword[fennol_md] etot_ensemble_key
        Key for ensemble weighting in enhanced sampling.
        Default: None
        """
        if ensemble_key is not None:
            fens = open_files.enter_context(
                open(f"{system_name}.ensemble_weights.traj", "a")
            )

        write_centroid = simulation_parameters.get("write_centroid", False) and pimd
        """@keyword[fennol_md] write_centroid
        Write PIMD centroid coordinates to separate file.
        Default: False
        """
        if write_centroid:
            fcentroid = open_files.enter_context(
                open(f"{system_name}_centroid" + traj_ext, "a")
            )

        ### initialize property trajectories (reset at every summary)
        properties_traj = defaultdict(list)

        ### initialize summary timer
        t0full = time.time()
        # NOTE(review): set to True after the first wrap and never reset, so
        # every later step is called with force_preprocess=True. Confirm in
        # `step` whether a single forced preprocessing was intended.
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                # equipartition: Ek = 3/2 N kB T
                temper = 2 * ek / (3.0 * nat) * us.KELVIN

                th_state = system["thermostat"]
                if include_thermostat_energy:
                    etherm = th_state["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * us.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time_ps = start_time_ps + istep * dt * us.PS
                # integer format: `.6g` would print steps >= 1e6 in lossy
                # exponent form (e.g. 1.23457e+06)
                line = f"{istep:10d} {simulated_time_ps:10.3f} {etot*atom_energy_unit:#10.4f}  {epot*atom_energy_unit:#10.4f} {ek*atom_energy_unit:#10.4f} {temper:10.2f}"
                if pimd:
                    line += f" {temper_c:10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit:#10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        # symmetrize before reporting components
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
                            pres_tensor[0, 0]
                        )
                        properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
                            pres_tensor[1, 1]
                        )
                        properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
                            pres_tensor[2, 2]
                        )
                        properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
                            pres_tensor[0, 1]
                        )
                        properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
                            pres_tensor[0, 2]
                        )
                        properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
                            pres_tensor[1, 2]
                        )
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        current_cell = system["cell"]
                        properties_traj["Cell_Ax[Angstrom]"].append(current_cell[0, 0])
                        properties_traj["Cell_Ay[Angstrom]"].append(current_cell[0, 1])
                        properties_traj["Cell_Az[Angstrom]"].append(current_cell[0, 2])
                        properties_traj["Cell_Bx[Angstrom]"].append(current_cell[1, 0])
                        properties_traj["Cell_By[Angstrom]"].append(current_cell[1, 1])
                        properties_traj["Cell_Bz[Angstrom]"].append(current_cell[1, 2])
                        properties_traj["Cell_Cx[Angstrom]"].append(current_cell[2, 0])
                        properties_traj["Cell_Cy[Angstrom]"].append(current_cell[2, 1])
                        properties_traj["Cell_Cz[Angstrom]"].append(current_cell[2, 2])
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time_ps:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # header.split() starts with the lone "#" token, hence [1:]
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    data["properties"]["properties_energy_unit"] = (atom_energy_unit,atom_energy_unit_str)
                    data["properties"]["model_energy_unit"] = (model_energy_unit,model_energy_unit_str)

                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        # only the first entry of the coordinates (written as
                        # the centroid below) is wrapped
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell,wrap_groups=wrap_groups)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell,wrap_groups=wrap_groups
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(line)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt * us.PS,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                nsperday = (24 * 60 * 60 / tperstep) * dt * us.NS
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time      : {istep * dt*us.PS:.3f} ps")
                print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*us.PS:.3f} ps")
                print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for k, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"
                    ksplit = k.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")

    return 0
653
654
if __name__ == "__main__":
    # Propagate main()'s documented return value as the process exit status.
    sys.exit(main())
def main():
def main():
    """
    Main entry point for the fennol_md command-line interface.

    Parses command-line arguments and runs a molecular dynamics simulation
    based on the provided parameter file.

    Command-line Usage:
        fennol_md input.fnl
        fennol_md config.yaml

    Returns:
        int: Exit code (0 for success)
    """
    # Make stdout unbuffered so progress lines appear immediately, even when
    # the output is redirected to a file or a pipe.
    try:
        stdout_fd = sys.stdout.fileno()
    except (AttributeError, io.UnsupportedOperation):
        # stdout is not backed by a real file descriptor (e.g. captured by a
        # test harness or a notebook): leave it untouched.
        stdout_fd = None
    if stdout_fd is not None:
        sys.stdout = io.TextIOWrapper(
            # closefd=False: the original sys.stdout still owns this
            # descriptor, so the wrapper must not close it when collected.
            open(stdout_fd, "wb", 0, closefd=False),
            # keep the encoding/error policy stdout was configured with
            encoding=getattr(sys.stdout, "encoding", None),
            errors=getattr(sys.stdout, "errors", None),
            write_through=True,
        )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    return config_and_run_dynamic(args.param_file)

Main entry point for the fennol_md command-line interface.

Parses command-line arguments and runs a molecular dynamics simulation based on the provided parameter file.

Command-line usage: `fennol_md input.fnl` or `fennol_md config.yaml`

Returns: int — exit code (0 for success)

def config_and_run_dynamic(param_file: pathlib.Path):
 53def config_and_run_dynamic(param_file: Path):
 54    """
 55    Configure and run a molecular dynamics simulation.
 56    
 57    This function loads simulation parameters from a configuration file,
 58    sets up the computation device and precision, and runs the MD simulation.
 59    
 60    Parameters:
 61        param_file (Path): Path to the parameter file (.fnl, .yaml, or .yml)
 62    
 63    Returns:
 64        int: Exit code (0 for success)
 65        
 66    Raises:
 67        FileNotFoundError: If the parameter file doesn't exist
 68        ValueError: If the parameter file format is unsupported
 69        
 70    Supported file formats:
 71        - .fnl: FeNNol native format
 72        - .yaml/.yml: YAML format
 73    
 74    Internal units are specified by UnitSystem(L="ANGSTROM", T="PS", E="KCALPERMOL")
 75
 76    Unit conversion in the parameter file:
 77        - Units specified in brackets: dt[fs] = 0.5
 78        - All specified units converted to internal units
 79        - non-specified units are assumed to be in internal units
 80    """
 81
 82    if not param_file.exists() and not param_file.is_file():
 83        raise FileNotFoundError(f"Parameter file {param_file} not found")
 84
 85    if param_file.suffix in [".yaml", ".yml"]:
 86        with open(param_file, "r") as f:
 87            simulation_parameters = convert_dict_units(yaml.safe_load(f),us=us)
 88            simulation_parameters = InputFile(**simulation_parameters)
 89    elif param_file.suffix == ".fnl":
 90        simulation_parameters = parse_input(param_file,us=us)
 91    else:
 92        raise ValueError(
 93            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
 94        )
 95
 96    ### Set the device
 97    if "FENNOL_DEVICE" in os.environ:
 98        device = os.environ["FENNOL_DEVICE"].lower()
 99        print(f"# Setting device from env FENNOL_DEVICE={device}")
100    else:
101        device = simulation_parameters.get("device", "cpu").lower()
102        """@keyword[fennol_md] device
103        Computation device. Options: 'cpu', 'cuda:N', 'gpu:N' where N is device number.
104        Default: 'cpu'
105        """
106    if device == "cpu":
107        jax.config.update('jax_platforms', 'cpu')
108        os.environ["CUDA_VISIBLE_DEVICES"] = ""
109    elif device.startswith("cuda") or device.startswith("gpu"):
110        if ":" in device:
111            num = device.split(":")[-1]
112            os.environ["CUDA_VISIBLE_DEVICES"] = num
113        else:
114            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
115        device = "gpu"
116
117    _device = jax.devices(device)[0]
118    jax.config.update("jax_default_device", _device)
119
120    ### Set the precision
121    enable_x64 = simulation_parameters.get("double_precision", False)
122    """@keyword[fennol_md] double_precision
123    Enable double precision (64-bit) calculations. Default is single precision (32-bit).
124    Default: False
125    """
126    jax.config.update("jax_enable_x64", enable_x64)
127    fprec = "float64" if enable_x64 else "float32"
128
129    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
130    """@keyword[fennol_md] matmul_prec
131    Matrix multiplication precision. Options: 'default', 'high', 'highest'.
132    Default: "highest"
133    """
134    assert matmul_precision in [
135        "default",
136        "high",
137        "highest",
138    ], "matmul_prec must be one of 'default','high','highest'"
139    if matmul_precision != "highest":
140        print(f"# Setting matmul precision to '{matmul_precision}'")
141    if matmul_precision == "default" and fprec == "float32":
142        print(
143            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
144        )
145    jax.config.update("jax_default_matmul_precision", matmul_precision)
146
147    # with jax.default_device(_device):
148    return dynamic(simulation_parameters, device, fprec)

Configure and run a molecular dynamics simulation.

This function loads simulation parameters from a configuration file, sets up the computation device and precision, and runs the MD simulation.

Parameters: param_file (Path): Path to the parameter file (.fnl, .yaml, or .yml)

Returns: int: Exit code (0 for success)

Raises: `FileNotFoundError` if the parameter file doesn't exist; `ValueError` if the parameter file format is unsupported.

Supported file formats: `.fnl` (FeNNol native format) and `.yaml`/`.yml` (YAML format).

Internal units are specified by UnitSystem(L="ANGSTROM", T="PS", E="KCALPERMOL")

Unit conversion in the parameter file: units may be specified in brackets (e.g. `dt[fs] = 0.5`); all specified units are converted to internal units, and values given without a unit are assumed to already be in internal units.

def dynamic(simulation_parameters, device, fprec):
def dynamic(simulation_parameters, device, fprec):
    """
    Execute the molecular dynamics simulation loop.

    This function performs the main MD simulation using the initialized
    system, integrator, and thermostats/barostats. It handles trajectory
    output, property monitoring, and restart file generation.

    Parameters:
        simulation_parameters: Parsed simulation parameters
        device (str): Computation device ('cpu' or 'gpu'); in this function
            it is only reported in the run banner.
        fprec (str): Floating point precision ('float32' or 'float64')

    Returns:
        int: Exit code (0 for success)

    Raises:
        ValueError: if ``tdump`` is shorter than one time step, if the
            trajectory format is unknown, if two ``wrap_groups`` share atoms,
            or if NaN velocities/coordinates are detected at a summary step.

    Notes:
        Output files are named after the system name. They are opened in
        append mode, except the ``save_keys`` pickle file (``.traj.pkl``)
        which is truncated. Every file opened here is closed when the
        function returns or raises.
    """
    from contextlib import ExitStack
    from itertools import combinations

    tstart_dyn = time.time()

    # Explicit int64: the default integer dtype can be int32 on some
    # platforms, for which 2**32 - 1 is out of bounds.
    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1, dtype=np.int64)
    )
    """@keyword[fennol_md] random_seed
    Random seed for reproducible simulations. If not specified, a random seed is generated.
    Default: Random integer between 0 and 2^32-1
    """
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    """@keyword[fennol_md] nsteps
    Total number of MD steps to perform. Required parameter.
    Type: int, Required
    """
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / us.PS)
    """@keyword[fennol_md] tdump
    Time interval between trajectory frames.
    Default: 1.0 ps
    """
    ndump = int(Tdump / dt)
    # A dump interval shorter than one time step would give ndump == 0 and a
    # ZeroDivisionError at the first `istep % ndump` test in the loop.
    if ndump < 1:
        raise ValueError(
            f"tdump ({Tdump * us.PS} ps) must be at least one time step ({dt * us.PS} ps)"
        )
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        # True if any velocity or coordinate is NaN (used as a crash detector).
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    # Defaults for non-periodic systems; overridden below when pbc is set.
    cell = None
    reciprocal_cell = None
    do_wrap_box = False
    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        """@keyword[fennol_md] wrap_box
        Wrap coordinates into primary unit cell.
        Default: False
        """
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            """@keyword[fennol_md] wrap_groups
            Specific atom groups to wrap independently. Dictionary mapping group names to atom indices.
            Default: None
            """
            # An empty definition is treated like a missing one: with no
            # groups, np.concatenate below would fail on an empty list.
            if wrap_groups_def:
                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
                wrap_groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # check that pairwise intersection of wrap groups is empty
                group_sets = [(k, set(v)) for k, v in wrap_groups.items()]
                for (i_key, w1), (j_key, w2) in combinations(group_sets, 2):
                    common = w1.intersection(w2)
                    if common:
                        raise ValueError(
                            f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                        )
                group_all = np.concatenate(list(wrap_groups.values()))
                # get all atoms that are not in any wrap group
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {wrap_groups}")
                wrap_groups["__other"] = group_none
                # NOTE(review): this is a one-shot generator, yet it is passed
                # to wrapbox at every dump. That is only correct if wrapbox
                # consumes it once (e.g. as a hashable jit static argument
                # whose compiled result is cached). If wrapbox iterates it on
                # every call, the groups are silently ignored after the first
                # frame. Confirm against wrapbox before turning it into a tuple.
                wrap_groups = ((k, v) for k, v in wrap_groups.items())

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    """@keyword[fennol_md] per_atom_energy
    Print energies per atom instead of total energies.
    Default: True
    """
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # Out-of-place division so that energy_unit is never modified, even
        # if it happens to be a mutable array.
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    """@keyword[fennol_md] nprint
    Number of steps between energy/property printing.
    Default: 10
    """
    assert nprint > 0, "nprint must be > 0"
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    """@keyword[fennol_md] nsummary
    Number of steps between summary statistics.
    Default: 100 * nprint
    """
    assert nsummary > nprint, "nsummary must be > nprint"

    save_keys = simulation_parameters.get("save_keys", [])
    """@keyword[fennol_md] save_keys
    Additional model output keys to save to trajectory.
    Default: []
    """
    if save_keys:
        print(f"# Saving keys: {save_keys}")

    ### initialize colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = f"#{'Step':>9} {'Time[ps]':>10} {'Etot':>10} {'Epot':>10} {'Ekin':>10} {'Temp[K]':>10}"
    if pimd:
        header += f" {'Temp_c[K]':>10}"
    if include_thermostat_energy:
        header += f" {'Etherm':>10}"
    # Only read from the input when the pressure is estimated, but it is also
    # consulted in the variable-cell branch of the loop, so it needs a default.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        """@keyword[fennol_md] print_aniso_pressure
        Print anisotropic pressure tensor components.
        Default: False
        """
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        """@keyword[fennol_md] pressure_unit
        Pressure unit for output. Options: 'atm', 'bar', 'Pa', 'GPa'.
        Default: "atm"
        """
        pressure_unit = us.get_multiplier(pressure_unit_str)
        p_str = f"  Press[{pressure_unit_str}]"
        header += f" {p_str:>10}"
    if variable_cell:
        header += f" {'Density':>10}"
    print(header)

    ### Select trajectory format (validated before any file is created)
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    """@keyword[fennol_md] traj_format
    Trajectory file format. Options: 'arc' (Tinker), 'xyz' (standard), 'extxyz' (extended).
    Default: "arc"
    """
    traj_writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    if traj_format not in traj_writers:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )
    traj_ext, write_frame = traj_writers[traj_format]

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
    """@keyword[fennol_md] write_all_beads
    Write all PIMD beads to separate trajectory files.
    Default: False
    """

    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    """@keyword[fennol_md] etot_ensemble_key
    Key for ensemble weighting in enhanced sampling.
    Default: None
    """

    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
    """@keyword[fennol_md] write_centroid
    Write PIMD centroid coordinates to separate file.
    Default: False
    """

    ### initialize property trajectories
    properties_traj = defaultdict(list)
    force_preprocess = False

    # Every output file is registered on the ExitStack, so all of them are
    # closed when the run ends, including when an exception (e.g. the NaN
    # crash check) interrupts the loop. This also covers the per-bead list.
    with ExitStack() as open_files:
        fkeys = (
            open_files.enter_context(open(f"{system_name}.traj.pkl", "wb+"))
            if save_keys
            else None
        )

        if use_colvars:
            # open colvars file and print header
            fcolvars = open_files.enter_context(
                open(f"{system_name}.colvars.traj", "a")
            )
            fcolvars.write("#time[ps] ")
            fcolvars.write(" ".join(colvars_names))
            fcolvars.write("\n")
            fcolvars.flush()

        if write_all_beads:
            fout = [
                open_files.enter_context(
                    open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a")
                )
                for i in range(nbeads)
            ]
        else:
            fout = open_files.enter_context(open(system_name + traj_ext, "a"))

        if ensemble_key is not None:
            fens = open_files.enter_context(
                open(f"{system_name}.ensemble_weights.traj", "a")
            )

        if write_centroid:
            fcentroid = open_files.enter_context(
                open(f"{system_name}_centroid" + traj_ext, "a")
            )

        ### initialize timer for the summaries
        t0full = time.time()

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * us.KELVIN

                th_state = system["thermostat"]
                if include_thermostat_energy:
                    etherm = th_state["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * us.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time_ps = start_time_ps + istep * dt * us.PS
                line = f"{istep:10.6g} {simulated_time_ps:10.3f} {etot*atom_energy_unit:#10.4f}  {epot*atom_energy_unit:#10.4f} {ek*atom_energy_unit:#10.4f} {temper:10.2f}"
                if pimd:
                    line += f" {temper_c:10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit:#10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        # symmetrize before reporting the off-diagonal terms
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
                            pres_tensor[0, 0]
                        )
                        properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
                            pres_tensor[1, 1]
                        )
                        properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
                            pres_tensor[2, 2]
                        )
                        properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
                            pres_tensor[0, 1]
                        )
                        properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
                            pres_tensor[0, 2]
                        )
                        properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
                            pres_tensor[1, 2]
                        )
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        cell = system["cell"]
                        properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
                        properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
                        properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
                        properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
                        properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
                        properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
                        properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
                        properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
                        properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time_ps:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # Pair each header column name (skipping the leading "#")
                    # with the corresponding value printed on `line`.
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    data["properties"]["properties_energy_unit"] = (atom_energy_unit,atom_energy_unit_str)
                    data["properties"]["model_energy_unit"] = (model_energy_unit,model_energy_unit_str)

                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        # only the first entry of the coordinates (used as the
                        # centroid elsewhere in this function) is wrapped
                        centroid = wrapbox(
                            system["coordinates"][0],
                            cell,
                            reciprocal_cell,
                            wrap_groups=wrap_groups,
                        )
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"],
                            cell,
                            reciprocal_cell,
                            wrap_groups=wrap_groups,
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    # passed to `step` from now on; presumably requests a fresh
                    # preprocessing of the wrapped conformation (confirm in step)
                    force_preprocess = True
                print(line)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt * us.PS,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                nsperday = (24 * 60 * 60 / tperstep) * dt * us.NS
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time      : {istep * dt*us.PS:.3f} ps")
                print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*us.PS:.3f} ps")
                print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for k, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"; split them for display
                    ksplit = k.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")

    return 0

Execute the molecular dynamics simulation loop.

This function performs the main MD simulation using the initialized system, integrator, and thermostats/barostats. It handles trajectory output, property monitoring, and restart file generation.

Parameters:
- simulation_parameters: Parsed simulation parameters
- device (str): Computation device ('cpu' or 'gpu')
- fprec (str): Floating point precision ('float32' or 'float64')

Returns: int: Exit code (0 for success)