fennol.md.integrate

  1import time
  2import math
  3import os
  4
  5import numpy as np
  6import jax
  7import jax.numpy as jnp
  8
  9from .thermostats import get_thermostat
 10from .barostats import get_barostat
 11from .colvars import setup_colvars
 12from .spectra import initialize_ir_spectrum
 13
 14from .utils import load_dynamics_restart, get_restart_file,optimize_fire2, us
 15from .initial import load_model, load_system_data, initialize_preprocessing
 16
 17
 18def initialize_dynamics(simulation_parameters, fprec, rng_key):
 19    ### LOAD MODEL
 20    model = load_model(simulation_parameters)
 21    model_energy_unit = us.get_multiplier(model.energy_unit)
 22
 23    ### Get the coordinates and species from the xyz file
 24    system_data, conformation = load_system_data(simulation_parameters, fprec)
 25    system_data["model_energy_unit"] = model_energy_unit
 26    system_data["model_energy_unit_str"] = model.energy_unit
 27
 28    ### FINISH BUILDING conformation
 29    do_restart = os.path.exists(get_restart_file(system_data))
 30    if do_restart:
 31        ### RESTART FROM PREVIOUS DYNAMICS
 32        restart_data = load_dynamics_restart(system_data)
 33        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 34        model.preproc_state = restart_data["preproc_state"]
 35        conformation["coordinates"] = restart_data["coordinates"]
 36    else:
 37        restart_data = {}
 38
 39    ### INITIALIZE PREPROCESSING
 40    preproc_state, conformation = initialize_preprocessing(
 41        simulation_parameters, model, conformation, system_data
 42    )
 43
 44    minimize = simulation_parameters.get("xyz_input/minimize", False)
 45    """@keyword[fennol_md] xyz_input/minimize
 46    Perform energy minimization before dynamics.
 47    Default: False
 48    """
 49    if minimize and not do_restart:
 50        assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems"
 51        model.preproc_state = preproc_state
 52        convert = us.KCALPERMOL / model_energy_unit
 53        nat = system_data["nat"]
 54        def energy_force_fn(coordinates):
 55            inputs = {**conformation, "coordinates": coordinates}
 56            e, f, _ = model.energy_and_forces(
 57                **inputs, gpu_preprocessing=True
 58            )
 59            e = float(e[0]) * convert / nat
 60            f = np.array(f) * convert
 61            return e, f
 62        tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL
 63        """@keyword[fennol_md] xyz_input/minimize_ftol
 64        Force tolerance for minimization.
 65        Default: 0.1 kcal/mol/Å
 66        """
 67        print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A")
 68        conformation["coordinates"], success = optimize_fire2(
 69            conformation["coordinates"],
 70            energy_force_fn,
 71            atol=tol,
 72            max_disp=0.02,
 73        )
 74        if success:
 75            print("# Minimization successful")
 76        else:
 77            print("# Warning: Minimization failed, continuing with last configuration")
 78        # write the minimized coordinates as an xyz file
 79        from ..utils.io import write_xyz_frame
 80        with open(system_data["name"]+".opt.xyz", "w") as f:
 81            write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None))
 82        print("# Minimized configuration written to", system_data["name"]+".opt.xyz")
 83        preproc_state = model.preproc_state
 84        conformation = model.preprocessing.process(preproc_state, conformation)
 85        system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy()
 86
 87    ### get dynamics parameters
 88    dt = simulation_parameters.get("dt")
 89    """@keyword[fennol_md] dt
 90    Integration time step. Required parameter.
 91    Type: float, Required
 92    """
 93    dt2 = 0.5 * dt
 94    mass = system_data["mass"]
 95    densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3)
 96    nat = system_data["nat"]
 97    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 98    ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3)
 99
100    nreplicas = system_data.get("nreplicas", 1)
101    nbeads = system_data.get("nbeads", None)
102    if nbeads is not None:
103        nreplicas = nbeads
104        dtm = dtm[None, :, :]
105    
106    fix_com = simulation_parameters.get("fix_com", False)
107    if fix_com:
108        massprop = system_data["mass_Da"]/system_data["totmass_Da"]
109        print("# Reference frame fixed to center of mass")
110        x = conformation["coordinates"].reshape(-1, nat, 3)
111        if nbeads is None:
112            xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True)
113        else:
114            xcom = jnp.sum(x*massprop[None,:,None], axis=(0,1), keepdims=True)/nbeads
115        conformation["coordinates"] = (x - xcom).reshape(-1, 3)
116
117    ### INITIALIZE DYNAMICS STATE
118    system = {"coordinates": conformation["coordinates"]}
119    dyn_state = {
120        "istep": 0,
121        "dt": dt,
122        "pimd": nbeads is not None,
123        "preproc_state": preproc_state,
124        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
125    }
126    gradient_keys = ["coordinates"]
127    thermo_updates = []
128
129    ### INITIALIZE THERMOSTAT
130    thermostat_rng, rng_key = jax.random.split(rng_key)
131    (
132        thermostat,
133        thermostat_post,
134        thermostat_state,
135        initial_vel,
136        dyn_state["thermostat_name"],
137        ekin_instant,
138    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
139    do_thermostat_post = thermostat_post is not None
140    if do_thermostat_post:
141        thermostat_post, post_state = thermostat_post
142        dyn_state["thermostat_post_state"] = post_state
143
144    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"
145
146    system["thermostat"] = thermostat_state
147
148    if "initial_velocities" in simulation_parameters:
149        initial_vel_file = simulation_parameters["initial_velocities"]
150        print(f"# Using initial velocities from {initial_vel_file}")
151        v0 = jnp.array(np.loadtxt(initial_vel_file).reshape(-1,nat, 3).astype(fprec))
152        if v0.shape[0] == 1:
153            if nbeads is None:
154                v0 = v0.repeat(nreplicas, axis=0)
155            else:
156                v0 = initial_vel.at[0].set(v0[0])
157        assert v0.shape == (nreplicas, nat, 3), f"initial_velocities file must have shape (nat, 3) or (nreplicas, nat, 3), but got {v0.shape}"
158        initial_vel = v0
159        if nbeads is None:
160            initial_vel = initial_vel.reshape(-1, 3)
161
162    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
163
164    ### PBC
165    pbc_data = system_data.get("pbc", None)
166    if pbc_data is not None:
167        ### INITIALIZE BAROSTAT
168        barostat_key, rng_key = jax.random.split(rng_key)
169        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
170            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
171        )
172        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
173        system["barostat"] = barostat_state
174        system["cell"] = conformation["cells"][0]
175        if estimate_pressure:
176            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0)
177            """@keyword[fennol_md] pressure_o_weight
178            Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator.
179            Default: 1.0
180            """
181            assert (
182                0.0 <= pressure_o_weight <= 1.0
183            ), "pressure_o_weight must be between 0 and 1"
184            gradient_keys.append("strain")
185        print("# Estimate pressure: ", estimate_pressure)
186    else:
187        estimate_pressure = False
188        variable_cell = False
189
190        def thermo_update_ensemble(x, v, system):
191            v, thermostat_state = thermostat(v, system["thermostat"])
192            return x, v, {**system, "thermostat": thermostat_state}
193
194    dyn_state["estimate_pressure"] = estimate_pressure
195    dyn_state["variable_cell"] = variable_cell
196    thermo_updates.append(thermo_update_ensemble)
197
198    if estimate_pressure:
199        use_average_Pkin = simulation_parameters.get("use_average_Pkin", False)
200        """@keyword[fennol_md] use_average_Pkin
201        Use time-averaged kinetic energy for pressure estimation instead of instantaneous values.
202        Default: False
203        """
204        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
205        if is_qtb and use_average_Pkin:
206            raise ValueError(
207                "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False"
208            )
209
210
211    ### ENERGY ENSEMBLE
212    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
213    """@keyword[fennol_md] etot_ensemble_key
214    Key for energy ensemble calculation. Enables computation of ensemble weights.
215    Default: None
216    """
217
218    ### COLVARS
219    colvars_definitions = simulation_parameters.get("colvars", None)
220    """@keyword[fennol_md] colvars
221    Collective variables definitions for enhanced sampling or monitoring.
222    Default: None
223    """
224    use_colvars = colvars_definitions is not None
225    if use_colvars:
226        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
227        dyn_state["colvars"] = colvars_names
228
229    ### IR SPECTRUM
230    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
231    """@keyword[fennol_md] ir_spectrum
232    Calculate infrared spectrum from molecular dipole moment time series.
233    Default: False
234    """
235    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
236    if do_ir_spectrum:
237        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
238        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
239            simulation_parameters, system_data, fprec, dt, is_qtb
240        )
241        dyn_state["ir_spectrum"] = ir_state
242
243    ### BUILD GRADIENT FUNCTION
244    energy_and_gradient = model.get_gradient_function(
245        *gradient_keys, jit=True, variables_as_input=True
246    )
247
248    ### COLLECT THERMO UPDATES
249    if len(thermo_updates) == 1:
250        thermo_update = thermo_updates[0]
251    else:
252
253        def thermo_update(x, v, system):
254            for update in thermo_updates:
255                x, v, system = update(x, v, system)
256            return x, v, system
257
258    ### RING POLYMER INITIALIZATION
259    if nbeads is not None:
260        cay_correction = simulation_parameters.get("cay_correction", True)
261        """@keyword[fennol_md] cay_correction
262        Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation.
263        Default: True
264        """
265        omk = system_data["omk"]
266        eigmat = system_data["eigmat"]
267        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
268        if cay_correction:
269            axx = jnp.asarray(2 * cayfact)
270            axv = jnp.asarray(dt * cayfact)
271            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
272        else:
273            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
274            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
275            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
276
277        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
278        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
279        system["coordinates"] = eigx
280
281    ###############################################
282    ### DEFINE UPDATE FUNCTION
283    @jax.jit
284    def update_conformation(conformation, system):
285        x = system["coordinates"]
286        if nbeads is not None:
287            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
288                nbeads**0.5
289            )
290        conformation = {**conformation, "coordinates": x}
291        if variable_cell:
292            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
293
294        
295
296        return conformation
297
298    ###############################################
299    ### DEFINE INTEGRATION FUNCTIONS
300    def integrate_A_half(x0, v0):
301        if nbeads is None:
302            return x0 + dt2 * v0, v0
303
304        # update coordinates and velocities of a free ring polymer for a half time step
305        eigx_c = x0[0] + dt2 * v0[0]
306        eigv_c = v0[0]
307        eigx = x0[1:] * axx + v0[1:] * axv
308        eigv = x0[1:] * avx + v0[1:] * axx
309
310        return (
311            jnp.concatenate((eigx_c[None], eigx), axis=0),
312            jnp.concatenate((eigv_c[None], eigv), axis=0),
313        )
314
315    @jax.jit
316    def integrate(system):
317        x = system["coordinates"]
318        v = system["vel"] + dtm * system["forces"]
319        x, v = integrate_A_half(x, v)
320        x, v, system = thermo_update(x, v, system)
321        x, v = integrate_A_half(x, v)
322
323        if fix_com:
324            if nbeads is None:
325                x = x.reshape(-1, nat, 3)
326                xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True)
327                x = (x - xcom).reshape(-1, 3)
328            else:
329                xcom = jnp.sum(x[0]*massprop[:,None], axis=0, keepdims=True)
330                x = x.at[0].set(x[0] - xcom)
331
332        return {**system, "coordinates": x, "vel": v}
333
334    ###############################################
335    ### DEFINE OBSERVABLE FUNCTION
336    @jax.jit
337    def update_observables(system, conformation):
338        ### POTENTIAL ENERGY AND FORCES
339        epot, de, out = energy_and_gradient(model.variables, conformation)
340        out["forces"] = -de["coordinates"]
341        epot = epot / model_energy_unit
342        de = {k: v / model_energy_unit for k, v in de.items()}
343        forces = -de["coordinates"]
344
345        if nbeads is not None:
346            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
347            forces = jnp.einsum(
348                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
349            ) * (1.0 / nbeads**0.5)
350
351        
352        system = {
353            **system,
354            "epot": jnp.mean(epot),
355            "forces": forces,
356            "energy_gradients": de,
357        }
358
359        ### KINETIC ENERGY
360        v = system["vel"]
361        if ekin_instant == "B":
362            v = v + 0.5 * dtm * forces
363                    
364        if nbeads is None:
365            corr_kin = system["thermostat"].get("corr_kin", 1.0)
366            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
367            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
368                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
369            )
370        else:
371            ek_c = 0.5 * jnp.sum(
372                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
373            )
374            ek = ek_c - 0.5 * jnp.sum(
375                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
376                axis=(0, 1),
377            )
378            system["ek_c"] = jnp.trace(ek_c)
379
380        system["ek"] = jnp.trace(ek)
381        system["ek_tensor"] = ek
382
383        if estimate_pressure:
384            if use_average_Pkin:
385                ek = ek_avg
386            elif pressure_o_weight != 1.0:
387                v = system["vel"] + 0.5 * dtm * system["forces"]
388                if nbeads is None:
389                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
390                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
391                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
392                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
393                    )
394                else:
395                    ek_c = 0.5 * jnp.sum(
396                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
397                        axis=0,
398                    )
399                    ek = ek_c - 0.5 * jnp.sum(
400                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
401                        axis=(0, 1),
402                    )
403                b = pressure_o_weight
404                ek = (1.0 - b) * ek + b * system["ek_tensor"]
405
406            vir = jnp.mean(de["strain"], axis=0)
407            system["virial"] = vir
408            out["virial_tensor"] = vir * model_energy_unit
409            
410            volume = jnp.abs(jnp.linalg.det(system["cell"]))
411            Pres =  ek*(2./volume)  - vir/volume
412            system["pressure_tensor"] = Pres
413            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
414            if variable_cell:
415                density = densmass / volume
416                system["density"] = density
417                system["volume"] = volume
418
419        if ensemble_key is not None:
420            kT = system_data["kT"]
421            dE = (
422                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
423            )
424            system["ensemble_weights"] = -dE / kT
425
426        if "total_dipole" in out:
427            if nbeads is None:
428                system["total_dipole"] = out["total_dipole"][0]
429            else:
430                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
431
432        if use_colvars:
433            coords = system["coordinates"].reshape(-1, nat, 3)[0]
434            colvars = {}
435            for colvar_name, colvar_calc in colvars_calculators.items():
436                colvars[colvar_name] = colvar_calc(coords)
437            system["colvars"] = colvars
438
439        return system, out
440
441    ###############################################
442    ### IR SPECTRUM
443    if do_ir_spectrum:
444        # @jax.jit
445        # def update_dipole(ir_state,system,conformation):
446        #     def mumodel(coords):
447        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
448        #         if nbeads is None:
449        #             return out["total_dipole"][0]
450        #         return out["total_dipole"].sum(axis=0)
451        #     dmudqmodel = jax.jacobian(mumodel)
452
453        #     dmudq = dmudqmodel(conformation["coordinates"])
454        #     # print(dmudq.shape)
455        #     if nbeads is None:
456        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
457        #         mudot = (vel*dmudq).sum(axis=(1,2))
458        #     else:
459        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
460        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
461        #         )
462        #         # vel = system["vel"][0].reshape(1,nat,3)
463        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
464
465        #     ir_state = save_dipole(mudot,ir_state)
466        #     return ir_state
467        @jax.jit
468        def update_conformation_ir(conformation, system):
469            conformation = {
470                **conformation,
471                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
472                "natoms": jnp.asarray([nat]),
473                "batch_index": jnp.asarray([0] * nat),
474                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
475            }
476            if variable_cell:
477                conformation["cells"] = system["cell"][None, :, :]
478                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
479                    None, :, :
480                ]
481            return conformation
482
483        @jax.jit
484        def update_dipole(ir_state, system, conformation):
485            if model_ir is not None:
486                out = model_ir._apply(model_ir.variables, conformation)
487                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
488                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
489            else:
490                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
491                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
492            if nbeads is not None:
493                q = jnp.mean(q, axis=0)
494                dip = jnp.mean(dip, axis=0)
495                vel = system["vel"][0]
496                pos = system["coordinates"][0]
497            else:
498                q = q[0]
499                dip = dip[0]
500                vel = system["vel"].reshape(-1, nat, 3)[0]
501                pos = system["coordinates"].reshape(-1, nat, 3)[0]
502
503            if pbc_data is not None:
504                cell_reciprocal = (
505                    conformation["cells"][0],
506                    conformation["reciprocal_cells"][0],
507                )
508            else:
509                cell_reciprocal = None
510
511            ir_state = save_dipole(
512                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
513            )
514            return ir_state
515
516    ###############################################
517    ### GRAPH UPDATES
518
519    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
520    """@keyword[fennol_md] nblist_verbose
521    Print verbose information about neighbor list updates and reallocations.
522    Default: False
523    """
524    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
525    """@keyword[fennol_md] nblist_stride
526    Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0.
527    Default: -1
528    """
529    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0)
530    """@keyword[fennol_md] nblist_warmup_time
531    Time period for neighbor list warmup before using skin updates.
532    Default: -1.0
533    """
534    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
535    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
536    """@keyword[fennol_md] nblist_skin
537    Neighbor list skin distance for efficient updates (in Angstroms).
538    Default: -1.0
539    """
540    if nblist_skin > 0:
541        if nblist_stride <= 0:
542            ## reference skin parameters at 300K (from Tinker-HP)
543            ##   => skin of 2 A gives you 40 fs without complete rebuild
544            t_ref = 40.0 /us.FS # FS
545            nblist_skin_ref = 2.0  # A
546            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
547        print(
548            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
549        )
550
551    if nblist_skin <= 0:
552        nblist_stride = 1
553
554    dyn_state["nblist_countdown"] = 0
555    dyn_state["print_skin_activation"] = nblist_warmup > 0
556
557    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
558        nblist_countdown = dyn_state["nblist_countdown"]
559        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
560            ### FULL NBLIST REBUILD
561            dyn_state["nblist_countdown"] = nblist_stride - 1
562            preproc_state = dyn_state["preproc_state"]
563            conformation = model.preprocessing.process(
564                preproc_state, update_conformation(conformation, system)
565            )
566            preproc_state, state_up, conformation, overflow = (
567                model.preprocessing.check_reallocate(preproc_state, conformation)
568            )
569            dyn_state["preproc_state"] = preproc_state
570            if nblist_verbose and overflow:
571                print("step", istep, ", nblist overflow => reallocating nblist")
572                print("size updates:", state_up)
573
574            if do_ir_spectrum and model_ir is not None:
575                conformation_ir = model_ir.preprocessing.process(
576                    dyn_state["preproc_state_ir"],
577                    update_conformation_ir(dyn_state["conformation_ir"], system),
578                )
579                (
580                    dyn_state["preproc_state_ir"],
581                    _,
582                    dyn_state["conformation_ir"],
583                    overflow,
584                ) = model_ir.preprocessing.check_reallocate(
585                    dyn_state["preproc_state_ir"], conformation_ir
586                )
587
588        else:
589            ### SKIN UPDATE
590            if dyn_state["print_skin_activation"]:
591                if nblist_verbose:
592                    print(
593                        "step",
594                        istep,
595                        ", end of nblist warmup phase => activating skin updates",
596                    )
597                dyn_state["print_skin_activation"] = False
598
599            dyn_state["nblist_countdown"] = nblist_countdown - 1
600            conformation = model.preprocessing.update_skin(
601                update_conformation(conformation, system)
602            )
603            if do_ir_spectrum and model_ir is not None:
604                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
605                    update_conformation_ir(dyn_state["conformation_ir"], system)
606                )
607
608        return conformation, dyn_state
609
610    ################################################
611    ### DEFINE STEP FUNCTION
612    def step(istep, dyn_state, system, conformation, force_preprocess=False):
613
614        dyn_state = {
615            **dyn_state,
616            "istep": dyn_state["istep"] + 1,
617        }
618
619        ### INTEGRATE EQUATIONS OF MOTION
620        system = integrate(system)
621
622        ### UPDATE CONFORMATION AND GRAPHS
623        conformation, dyn_state = update_graphs(
624            istep, dyn_state, system, conformation, force_preprocess
625        )
626
627        ## COMPUTE FORCES AND OBSERVABLES
628        system, out = update_observables(system, conformation)
629
630        ## END OF STEP UPDATES
631        if do_thermostat_post:
632            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
633                system["thermostat"], dyn_state["thermostat_post_state"]
634            )
635        
636        if do_ir_spectrum:
637            ir_state = update_dipole(
638                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
639            )
640            dyn_state["ir_spectrum"] = ir_post(ir_state)
641
642        return dyn_state, system, conformation, out
643
644    ###########################################################
645
646    print("# Computing initial energy and forces")
647
648    conformation = update_conformation(conformation, system)
649    # initialize IR conformation
650    if do_ir_spectrum and model_ir is not None:
651        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
652            model_ir.preprocessing(
653                model_ir.preproc_state,
654                update_conformation_ir(conformation, system),
655            )
656        )
657
658    system, _ = update_observables(system, conformation)
659
660    return step, update_conformation, system_data, dyn_state, conformation, system
def initialize_dynamics(simulation_parameters, fprec, rng_key):
 19def initialize_dynamics(simulation_parameters, fprec, rng_key):
 20    ### LOAD MODEL
 21    model = load_model(simulation_parameters)
 22    model_energy_unit = us.get_multiplier(model.energy_unit)
 23
 24    ### Get the coordinates and species from the xyz file
 25    system_data, conformation = load_system_data(simulation_parameters, fprec)
 26    system_data["model_energy_unit"] = model_energy_unit
 27    system_data["model_energy_unit_str"] = model.energy_unit
 28
 29    ### FINISH BUILDING conformation
 30    do_restart = os.path.exists(get_restart_file(system_data))
 31    if do_restart:
 32        ### RESTART FROM PREVIOUS DYNAMICS
 33        restart_data = load_dynamics_restart(system_data)
 34        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 35        model.preproc_state = restart_data["preproc_state"]
 36        conformation["coordinates"] = restart_data["coordinates"]
 37    else:
 38        restart_data = {}
 39
 40    ### INITIALIZE PREPROCESSING
 41    preproc_state, conformation = initialize_preprocessing(
 42        simulation_parameters, model, conformation, system_data
 43    )
 44
 45    minimize = simulation_parameters.get("xyz_input/minimize", False)
 46    """@keyword[fennol_md] xyz_input/minimize
 47    Perform energy minimization before dynamics.
 48    Default: False
 49    """
 50    if minimize and not do_restart:
 51        assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems"
 52        model.preproc_state = preproc_state
 53        convert = us.KCALPERMOL / model_energy_unit
 54        nat = system_data["nat"]
 55        def energy_force_fn(coordinates):
 56            inputs = {**conformation, "coordinates": coordinates}
 57            e, f, _ = model.energy_and_forces(
 58                **inputs, gpu_preprocessing=True
 59            )
 60            e = float(e[0]) * convert / nat
 61            f = np.array(f) * convert
 62            return e, f
 63        tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL
 64        """@keyword[fennol_md] xyz_input/minimize_ftol
 65        Force tolerance for minimization.
 66        Default: 0.1 kcal/mol/Å
 67        """
 68        print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A")
 69        conformation["coordinates"], success = optimize_fire2(
 70            conformation["coordinates"],
 71            energy_force_fn,
 72            atol=tol,
 73            max_disp=0.02,
 74        )
 75        if success:
 76            print("# Minimization successful")
 77        else:
 78            print("# Warning: Minimization failed, continuing with last configuration")
 79        # write the minimized coordinates as an xyz file
 80        from ..utils.io import write_xyz_frame
 81        with open(system_data["name"]+".opt.xyz", "w") as f:
 82            write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None))
 83        print("# Minimized configuration written to", system_data["name"]+".opt.xyz")
 84        preproc_state = model.preproc_state
 85        conformation = model.preprocessing.process(preproc_state, conformation)
 86        system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy()
 87
 88    ### get dynamics parameters
 89    dt = simulation_parameters.get("dt")
 90    """@keyword[fennol_md] dt
 91    Integration time step. Required parameter.
 92    Type: float, Required
 93    """
 94    dt2 = 0.5 * dt
 95    mass = system_data["mass"]
 96    densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3)
 97    nat = system_data["nat"]
 98    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 99    ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3)
100
101    nreplicas = system_data.get("nreplicas", 1)
102    nbeads = system_data.get("nbeads", None)
103    if nbeads is not None:
104        nreplicas = nbeads
105        dtm = dtm[None, :, :]
106    
107    fix_com = simulation_parameters.get("fix_com", False)
108    if fix_com:
109        massprop = system_data["mass_Da"]/system_data["totmass_Da"]
110        print("# Reference frame fixed to center of mass")
111        x = conformation["coordinates"].reshape(-1, nat, 3)
112        if nbeads is None:
113            xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True)
114        else:
115            xcom = jnp.sum(x*massprop[None,:,None], axis=(0,1), keepdims=True)/nbeads
116        conformation["coordinates"] = (x - xcom).reshape(-1, 3)
117
118    ### INITIALIZE DYNAMICS STATE
119    system = {"coordinates": conformation["coordinates"]}
120    dyn_state = {
121        "istep": 0,
122        "dt": dt,
123        "pimd": nbeads is not None,
124        "preproc_state": preproc_state,
125        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
126    }
127    gradient_keys = ["coordinates"]
128    thermo_updates = []
129
130    ### INITIALIZE THERMOSTAT
131    thermostat_rng, rng_key = jax.random.split(rng_key)
132    (
133        thermostat,
134        thermostat_post,
135        thermostat_state,
136        initial_vel,
137        dyn_state["thermostat_name"],
138        ekin_instant,
139    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
140    do_thermostat_post = thermostat_post is not None
141    if do_thermostat_post:
142        thermostat_post, post_state = thermostat_post
143        dyn_state["thermostat_post_state"] = post_state
144
145    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"
146
147    system["thermostat"] = thermostat_state
148
149    if "initial_velocities" in simulation_parameters:
150        initial_vel_file = simulation_parameters["initial_velocities"]
151        print(f"# Using initial velocities from {initial_vel_file}")
152        v0 = jnp.array(np.loadtxt(initial_vel_file).reshape(-1,nat, 3).astype(fprec))
153        if v0.shape[0] == 1:
154            if nbeads is None:
155                v0 = v0.repeat(nreplicas, axis=0)
156            else:
157                v0 = initial_vel.at[0].set(v0[0])
158        assert v0.shape == (nreplicas, nat, 3), f"initial_velocities file must have shape (nat, 3) or (nreplicas, nat, 3), but got {v0.shape}"
159        initial_vel = v0
160        if nbeads is None:
161            initial_vel = initial_vel.reshape(-1, 3)
162
163    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
164
165    ### PBC
166    pbc_data = system_data.get("pbc", None)
167    if pbc_data is not None:
168        ### INITIALIZE BAROSTAT
169        barostat_key, rng_key = jax.random.split(rng_key)
170        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
171            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
172        )
173        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
174        system["barostat"] = barostat_state
175        system["cell"] = conformation["cells"][0]
176        if estimate_pressure:
177            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0)
178            """@keyword[fennol_md] pressure_o_weight
179            Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator.
180            Default: 1.0
181            """
182            assert (
183                0.0 <= pressure_o_weight <= 1.0
184            ), "pressure_o_weight must be between 0 and 1"
185            gradient_keys.append("strain")
186        print("# Estimate pressure: ", estimate_pressure)
187    else:
188        estimate_pressure = False
189        variable_cell = False
190
191        def thermo_update_ensemble(x, v, system):
192            v, thermostat_state = thermostat(v, system["thermostat"])
193            return x, v, {**system, "thermostat": thermostat_state}
194
195    dyn_state["estimate_pressure"] = estimate_pressure
196    dyn_state["variable_cell"] = variable_cell
197    thermo_updates.append(thermo_update_ensemble)
198
199    if estimate_pressure:
200        use_average_Pkin = simulation_parameters.get("use_average_Pkin", False)
201        """@keyword[fennol_md] use_average_Pkin
202        Use time-averaged kinetic energy for pressure estimation instead of instantaneous values.
203        Default: False
204        """
205        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
206        if is_qtb and use_average_Pkin:
207            raise ValueError(
208                "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False"
209            )
210
211
212    ### ENERGY ENSEMBLE
213    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
214    """@keyword[fennol_md] etot_ensemble_key
215    Key for energy ensemble calculation. Enables computation of ensemble weights.
216    Default: None
217    """
218
219    ### COLVARS
220    colvars_definitions = simulation_parameters.get("colvars", None)
221    """@keyword[fennol_md] colvars
222    Collective variables definitions for enhanced sampling or monitoring.
223    Default: None
224    """
225    use_colvars = colvars_definitions is not None
226    if use_colvars:
227        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
228        dyn_state["colvars"] = colvars_names
229
230    ### IR SPECTRUM
231    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
232    """@keyword[fennol_md] ir_spectrum
233    Calculate infrared spectrum from molecular dipole moment time series.
234    Default: False
235    """
236    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
237    if do_ir_spectrum:
238        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
239        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
240            simulation_parameters, system_data, fprec, dt, is_qtb
241        )
242        dyn_state["ir_spectrum"] = ir_state
243
244    ### BUILD GRADIENT FUNCTION
245    energy_and_gradient = model.get_gradient_function(
246        *gradient_keys, jit=True, variables_as_input=True
247    )
248
249    ### COLLECT THERMO UPDATES
250    if len(thermo_updates) == 1:
251        thermo_update = thermo_updates[0]
252    else:
253
254        def thermo_update(x, v, system):
255            for update in thermo_updates:
256                x, v, system = update(x, v, system)
257            return x, v, system
258
259    ### RING POLYMER INITIALIZATION
260    if nbeads is not None:
261        cay_correction = simulation_parameters.get("cay_correction", True)
262        """@keyword[fennol_md] cay_correction
263        Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation.
264        Default: True
265        """
266        omk = system_data["omk"]
267        eigmat = system_data["eigmat"]
268        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
269        if cay_correction:
270            axx = jnp.asarray(2 * cayfact)
271            axv = jnp.asarray(dt * cayfact)
272            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
273        else:
274            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
275            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
276            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
277
278        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
279        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
280        system["coordinates"] = eigx
281
282    ###############################################
283    ### DEFINE UPDATE FUNCTION
284    @jax.jit
285    def update_conformation(conformation, system):
286        x = system["coordinates"]
287        if nbeads is not None:
288            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
289                nbeads**0.5
290            )
291        conformation = {**conformation, "coordinates": x}
292        if variable_cell:
293            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
294
295        
296
297        return conformation
298
299    ###############################################
300    ### DEFINE INTEGRATION FUNCTIONS
301    def integrate_A_half(x0, v0):
302        if nbeads is None:
303            return x0 + dt2 * v0, v0
304
305        # update coordinates and velocities of a free ring polymer for a half time step
306        eigx_c = x0[0] + dt2 * v0[0]
307        eigv_c = v0[0]
308        eigx = x0[1:] * axx + v0[1:] * axv
309        eigv = x0[1:] * avx + v0[1:] * axx
310
311        return (
312            jnp.concatenate((eigx_c[None], eigx), axis=0),
313            jnp.concatenate((eigv_c[None], eigv), axis=0),
314        )
315
316    @jax.jit
317    def integrate(system):
318        x = system["coordinates"]
319        v = system["vel"] + dtm * system["forces"]
320        x, v = integrate_A_half(x, v)
321        x, v, system = thermo_update(x, v, system)
322        x, v = integrate_A_half(x, v)
323
324        if fix_com:
325            if nbeads is None:
326                x = x.reshape(-1, nat, 3)
327                xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True)
328                x = (x - xcom).reshape(-1, 3)
329            else:
330                xcom = jnp.sum(x[0]*massprop[:,None], axis=0, keepdims=True)
331                x = x.at[0].set(x[0] - xcom)
332
333        return {**system, "coordinates": x, "vel": v}
334
335    ###############################################
336    ### DEFINE OBSERVABLE FUNCTION
337    @jax.jit
338    def update_observables(system, conformation):
339        ### POTENTIAL ENERGY AND FORCES
340        epot, de, out = energy_and_gradient(model.variables, conformation)
341        out["forces"] = -de["coordinates"]
342        epot = epot / model_energy_unit
343        de = {k: v / model_energy_unit for k, v in de.items()}
344        forces = -de["coordinates"]
345
346        if nbeads is not None:
347            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
348            forces = jnp.einsum(
349                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
350            ) * (1.0 / nbeads**0.5)
351
352        
353        system = {
354            **system,
355            "epot": jnp.mean(epot),
356            "forces": forces,
357            "energy_gradients": de,
358        }
359
360        ### KINETIC ENERGY
361        v = system["vel"]
362        if ekin_instant == "B":
363            v = v + 0.5 * dtm * forces
364                    
365        if nbeads is None:
366            corr_kin = system["thermostat"].get("corr_kin", 1.0)
367            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
368            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
369                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
370            )
371        else:
372            ek_c = 0.5 * jnp.sum(
373                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
374            )
375            ek = ek_c - 0.5 * jnp.sum(
376                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
377                axis=(0, 1),
378            )
379            system["ek_c"] = jnp.trace(ek_c)
380
381        system["ek"] = jnp.trace(ek)
382        system["ek_tensor"] = ek
383
384        if estimate_pressure:
385            if use_average_Pkin:
386                ek = ek_avg
387            elif pressure_o_weight != 1.0:
388                v = system["vel"] + 0.5 * dtm * system["forces"]
389                if nbeads is None:
390                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
391                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
392                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
393                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
394                    )
395                else:
396                    ek_c = 0.5 * jnp.sum(
397                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
398                        axis=0,
399                    )
400                    ek = ek_c - 0.5 * jnp.sum(
401                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
402                        axis=(0, 1),
403                    )
404                b = pressure_o_weight
405                ek = (1.0 - b) * ek + b * system["ek_tensor"]
406
407            vir = jnp.mean(de["strain"], axis=0)
408            system["virial"] = vir
409            out["virial_tensor"] = vir * model_energy_unit
410            
411            volume = jnp.abs(jnp.linalg.det(system["cell"]))
412            Pres =  ek*(2./volume)  - vir/volume
413            system["pressure_tensor"] = Pres
414            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
415            if variable_cell:
416                density = densmass / volume
417                system["density"] = density
418                system["volume"] = volume
419
420        if ensemble_key is not None:
421            kT = system_data["kT"]
422            dE = (
423                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
424            )
425            system["ensemble_weights"] = -dE / kT
426
427        if "total_dipole" in out:
428            if nbeads is None:
429                system["total_dipole"] = out["total_dipole"][0]
430            else:
431                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
432
433        if use_colvars:
434            coords = system["coordinates"].reshape(-1, nat, 3)[0]
435            colvars = {}
436            for colvar_name, colvar_calc in colvars_calculators.items():
437                colvars[colvar_name] = colvar_calc(coords)
438            system["colvars"] = colvars
439
440        return system, out
441
442    ###############################################
443    ### IR SPECTRUM
444    if do_ir_spectrum:
445        # @jax.jit
446        # def update_dipole(ir_state,system,conformation):
447        #     def mumodel(coords):
448        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
449        #         if nbeads is None:
450        #             return out["total_dipole"][0]
451        #         return out["total_dipole"].sum(axis=0)
452        #     dmudqmodel = jax.jacobian(mumodel)
453
454        #     dmudq = dmudqmodel(conformation["coordinates"])
455        #     # print(dmudq.shape)
456        #     if nbeads is None:
457        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
458        #         mudot = (vel*dmudq).sum(axis=(1,2))
459        #     else:
460        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
461        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
462        #         )
463        #         # vel = system["vel"][0].reshape(1,nat,3)
464        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
465
466        #     ir_state = save_dipole(mudot,ir_state)
467        #     return ir_state
468        @jax.jit
469        def update_conformation_ir(conformation, system):
470            conformation = {
471                **conformation,
472                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
473                "natoms": jnp.asarray([nat]),
474                "batch_index": jnp.asarray([0] * nat),
475                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
476            }
477            if variable_cell:
478                conformation["cells"] = system["cell"][None, :, :]
479                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
480                    None, :, :
481                ]
482            return conformation
483
484        @jax.jit
485        def update_dipole(ir_state, system, conformation):
486            if model_ir is not None:
487                out = model_ir._apply(model_ir.variables, conformation)
488                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
489                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
490            else:
491                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
492                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
493            if nbeads is not None:
494                q = jnp.mean(q, axis=0)
495                dip = jnp.mean(dip, axis=0)
496                vel = system["vel"][0]
497                pos = system["coordinates"][0]
498            else:
499                q = q[0]
500                dip = dip[0]
501                vel = system["vel"].reshape(-1, nat, 3)[0]
502                pos = system["coordinates"].reshape(-1, nat, 3)[0]
503
504            if pbc_data is not None:
505                cell_reciprocal = (
506                    conformation["cells"][0],
507                    conformation["reciprocal_cells"][0],
508                )
509            else:
510                cell_reciprocal = None
511
512            ir_state = save_dipole(
513                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
514            )
515            return ir_state
516
517    ###############################################
518    ### GRAPH UPDATES
519
520    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
521    """@keyword[fennol_md] nblist_verbose
522    Print verbose information about neighbor list updates and reallocations.
523    Default: False
524    """
525    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
526    """@keyword[fennol_md] nblist_stride
527    Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0.
528    Default: -1
529    """
530    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0)
531    """@keyword[fennol_md] nblist_warmup_time
532    Time period for neighbor list warmup before using skin updates.
533    Default: -1.0
534    """
535    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
536    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
537    """@keyword[fennol_md] nblist_skin
538    Neighbor list skin distance for efficient updates (in Angstroms).
539    Default: -1.0
540    """
541    if nblist_skin > 0:
542        if nblist_stride <= 0:
543            ## reference skin parameters at 300K (from Tinker-HP)
544            ##   => skin of 2 A gives you 40 fs without complete rebuild
545            t_ref = 40.0 /us.FS # FS
546            nblist_skin_ref = 2.0  # A
547            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
548        print(
549            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
550        )
551
552    if nblist_skin <= 0:
553        nblist_stride = 1
554
555    dyn_state["nblist_countdown"] = 0
556    dyn_state["print_skin_activation"] = nblist_warmup > 0
557
558    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
559        nblist_countdown = dyn_state["nblist_countdown"]
560        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
561            ### FULL NBLIST REBUILD
562            dyn_state["nblist_countdown"] = nblist_stride - 1
563            preproc_state = dyn_state["preproc_state"]
564            conformation = model.preprocessing.process(
565                preproc_state, update_conformation(conformation, system)
566            )
567            preproc_state, state_up, conformation, overflow = (
568                model.preprocessing.check_reallocate(preproc_state, conformation)
569            )
570            dyn_state["preproc_state"] = preproc_state
571            if nblist_verbose and overflow:
572                print("step", istep, ", nblist overflow => reallocating nblist")
573                print("size updates:", state_up)
574
575            if do_ir_spectrum and model_ir is not None:
576                conformation_ir = model_ir.preprocessing.process(
577                    dyn_state["preproc_state_ir"],
578                    update_conformation_ir(dyn_state["conformation_ir"], system),
579                )
580                (
581                    dyn_state["preproc_state_ir"],
582                    _,
583                    dyn_state["conformation_ir"],
584                    overflow,
585                ) = model_ir.preprocessing.check_reallocate(
586                    dyn_state["preproc_state_ir"], conformation_ir
587                )
588
589        else:
590            ### SKIN UPDATE
591            if dyn_state["print_skin_activation"]:
592                if nblist_verbose:
593                    print(
594                        "step",
595                        istep,
596                        ", end of nblist warmup phase => activating skin updates",
597                    )
598                dyn_state["print_skin_activation"] = False
599
600            dyn_state["nblist_countdown"] = nblist_countdown - 1
601            conformation = model.preprocessing.update_skin(
602                update_conformation(conformation, system)
603            )
604            if do_ir_spectrum and model_ir is not None:
605                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
606                    update_conformation_ir(dyn_state["conformation_ir"], system)
607                )
608
609        return conformation, dyn_state
610
611    ################################################
612    ### DEFINE STEP FUNCTION
613    def step(istep, dyn_state, system, conformation, force_preprocess=False):
614
615        dyn_state = {
616            **dyn_state,
617            "istep": dyn_state["istep"] + 1,
618        }
619
620        ### INTEGRATE EQUATIONS OF MOTION
621        system = integrate(system)
622
623        ### UPDATE CONFORMATION AND GRAPHS
624        conformation, dyn_state = update_graphs(
625            istep, dyn_state, system, conformation, force_preprocess
626        )
627
628        ## COMPUTE FORCES AND OBSERVABLES
629        system, out = update_observables(system, conformation)
630
631        ## END OF STEP UPDATES
632        if do_thermostat_post:
633            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
634                system["thermostat"], dyn_state["thermostat_post_state"]
635            )
636        
637        if do_ir_spectrum:
638            ir_state = update_dipole(
639                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
640            )
641            dyn_state["ir_spectrum"] = ir_post(ir_state)
642
643        return dyn_state, system, conformation, out
644
645    ###########################################################
646
647    print("# Computing initial energy and forces")
648
649    conformation = update_conformation(conformation, system)
650    # initialize IR conformation
651    if do_ir_spectrum and model_ir is not None:
652        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
653            model_ir.preprocessing(
654                model_ir.preproc_state,
655                update_conformation_ir(conformation, system),
656            )
657        )
658
659    system, _ = update_observables(system, conformation)
660
661    return step, update_conformation, system_data, dyn_state, conformation, system