fennol.md.thermostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7import os
  8import pickle
  9
 10from .utils import us 
 11from ..utils import Counter,read_tinker_interval
 12from ..utils.deconvolution import (
 13    deconvolute_spectrum,
 14    kernel_lorentz_pot,
 15    kernel_lorentz,
 16)
 17from ..utils.periodic_table import PERIODIC_TABLE
 18
 19
 20def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 21    state = {}
 22    postprocess = None
 23    
 24    default_ekin_instant = "B"
 25
 26    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 27    """@keyword[fennol_md] thermostat
 28    Thermostat type. Options: 'NVE', 'LGV', 'NOSE', 'ADQTB'.
 29    Default: "LGV"
 30    """
 31    compute_thermostat_energy = simulation_parameters.get(
 32        "include_thermostat_energy", False
 33    )
 34    """@keyword[fennol_md] include_thermostat_energy
 35    Include thermostat energy in total energy calculations.
 36    Default: False
 37    """
 38
 39    kT = system_data.get("kT", None)
 40    nbeads = system_data.get("nbeads", None)
 41    mass = system_data["mass"]
 42    gamma0 = simulation_parameters.get("gamma", 1.0 / us.THZ)
 43    """@keyword[fennol_md] gamma
 44    Friction coefficient for Langevin thermostat.
 45    Default: 1.0 ps^-1
 46    """
 47    gamma = gamma0
 48    if gamma <= 0.0:
 49        gamma = 0.0
 50    species = system_data["species"]
 51
 52    if nbeads is not None:
 53        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 54        """@keyword[fennol_md] trpmd_lambda
 55        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
 56        Default: 1.0
 57        """
 58        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 59
 60    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 61        default_ekin_instant = "O"
 62        if gamma0 <= 1.e-5:
 63            default_ekin_instant = "B"
 64        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 65        assert kT is not None, "kT must be specified for QTB thermostat"
 66        assert gamma is not None, "gamma must be specified for QTB thermostat"
 67        rng_key, v_key = jax.random.split(rng_key)
 68        if nbeads is None:
 69            a1 = math.exp(-gamma * dt)
 70            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 71            vel = (
 72                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 73                * (kT / mass[:, None]) ** 0.5
 74            )
 75        else:
 76            if isinstance(gamma, float):
 77                gamma = np.array([gamma] * nbeads)
 78            assert isinstance(
 79                gamma, np.ndarray
 80            ), "gamma must be a float or a numpy array"
 81            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 82            a1 = np.exp(-gamma * dt)[:, None, None]
 83            a2 = jnp.asarray(
 84                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 85            )
 86            vel = (
 87                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 88                * (kT / mass[:, None]) ** 0.5
 89            )
 90
 91        state["rng_key"] = rng_key
 92        if compute_thermostat_energy:
 93            state["thermostat_energy"] = 0.0
 94        if thermostat_name == "FFLGV":
 95            def thermostat(vel, state):
 96                rng_key, noise_key = jax.random.split(state["rng_key"])
 97                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 98                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 99                dirvel = vel / norm_vel
100                if compute_thermostat_energy:
101                    v2 = (vel**2).sum(axis=-1)
102                vel = a1 * vel + a2 * noise
103                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
104                vel = dirvel * new_norm_vel
105                new_state = {**state, "rng_key": rng_key}
106                if compute_thermostat_energy:
107                    v2new = (vel**2).sum(axis=-1)
108                    new_state["thermostat_energy"] = (
109                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
110                    )
111
112                return vel, new_state
113
114        else:
115            def thermostat(vel, state):
116                rng_key, noise_key = jax.random.split(state["rng_key"])
117                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
118                if compute_thermostat_energy:
119                    v2 = (vel**2).sum(axis=-1)
120                vel = a1 * vel + a2 * noise
121                new_state = {**state, "rng_key": rng_key}
122                if compute_thermostat_energy:
123                    v2new = (vel**2).sum(axis=-1)
124                    new_state["thermostat_energy"] = (
125                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
126                    )
127                return vel, new_state
128
129    elif thermostat_name in ["BUSSI"]:
130        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
131        assert kT is not None, "kT must be specified for QTB thermostat"
132        assert gamma is not None, "gamma must be specified for QTB thermostat"
133        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
134        rng_key, v_key = jax.random.split(rng_key)
135
136        a1 = math.exp(-gamma * dt)
137        a2 = (1 - a1) * kT
138        vel = (
139            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
140            * (kT / mass[:, None]) ** 0.5
141        )
142
143        state["rng_key"] = rng_key
144        if compute_thermostat_energy:
145            state["thermostat_energy"] = 0.0
146
147        def thermostat(vel, state):
148            rng_key, noise_key = jax.random.split(state["rng_key"])
149            new_state = {**state, "rng_key": rng_key}
150            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
151            R2 = jnp.sum(noise**2)
152            R1 = noise[0, 0]
153            c = a2 / (mass[:, None] * vel**2).sum()
154            d = (a1 * c) ** 0.5
155            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
156            if compute_thermostat_energy:
157                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
158                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
159            return scale * vel, new_state
160
161    elif thermostat_name in [
162        "GD",
163        "DESCENT",
164        "GRADIENT_DESCENT",
165        "MIN",
166        "MINIMIZE",
167    ]:
168        assert nbeads is None, "Gradient descent is not compatible with PIMD"
169        a1 = math.exp(-gamma * dt)
170
171        if nbeads is None:
172            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
173        else:
174            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
175
176        def thermostat(vel, state):
177            return a1 * vel, state
178
179    elif thermostat_name in ["NVE", "NONE"]:
180        if kT is None:
181            kT = 0.
182        
183        if nbeads is None:
184            vel = (
185                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
186                * (kT / mass[:, None]) ** 0.5
187            )
188            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
189            vel = vel * (kT / kTsys) ** 0.5
190        else:
191            vel = (
192                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
193                * (kT / mass[None, :, None]) ** 0.5
194            )
195            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
196                mass.shape[0] * 3
197            )
198            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
199        thermostat = lambda x, s: (x, s)
200
201    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
202        assert gamma is not None, "gamma must be specified for QTB thermostat"
203        ndof = mass.shape[0] * 3
204        nkT = ndof * kT
205        nose_mass = nkT / gamma**2
206        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
207        state["nose_s"] = 0.0
208        state["nose_v"] = 0.0
209        if compute_thermostat_energy:
210            state["thermostat_energy"] = 0.0
211        print(
212            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
213        )
214        vel = (
215            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
216            * (kT / mass[:, None]) ** 0.5
217        )
218
219        def thermostat(vel, state):
220            nose_s = state["nose_s"]
221            nose_v = state["nose_v"]
222            kTsys = jnp.sum(mass[:, None] * vel**2)
223            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
224            nose_s = nose_s + dt * nose_v
225            vel = jnp.exp(-nose_v * dt) * vel
226            kTsys = jnp.sum(mass[:, None] * vel**2)
227            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
228            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
229
230            if compute_thermostat_energy:
231                new_state["thermostat_energy"] = (
232                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
233                )
234            return vel, new_state
235
236    elif thermostat_name in ["QTB", "ADQTB"]:
237        default_ekin_instant = "O"
238        assert nbeads is None, "QTB is not compatible with PIMD"
239        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
240        assert kT is not None, "kT must be specified for QTB thermostat"
241        assert gamma is not None, "gamma must be specified for QTB thermostat"
242        assert species is not None, "species must be provided for QTB thermostat"
243        rng_key, v_key = jax.random.split(rng_key)
244        vel = (
245            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
246            * (kT / mass[:, None]) ** 0.5
247        )
248
249        thermostat, postprocess, qtb_state = initialize_qtb(
250            simulation_parameters,
251            system_data,
252            fprec=fprec,
253            dt=dt,
254            mass=mass,
255            gamma=gamma,
256            kT=kT,
257            species=species,
258            rng_key=rng_key,
259            adaptive=thermostat_name.startswith("AD"),
260            compute_thermostat_energy=compute_thermostat_energy,
261        )
262        state = {**state, **qtb_state}
263
264    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
265        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
266        assert kT is not None, "kT must be specified for QTB thermostat"
267        assert gamma is not None, "gamma must be specified for QTB thermostat"
268        assert nbeads is None, "ANNEAL is not compatible with PIMD"
269        a1 = math.exp(-gamma * dt)
270        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
271
272        anneal_parameters = simulation_parameters.get("annealing", {})
273        """@keyword[fennol_md] annealing
274        Parameters for simulated annealing schedule configuration.
275        Required for ANNEAL/ANNEALING thermostat
276        """
277        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
278        """@keyword[fennol_md] annealing/init_factor
279        Initial temperature factor for annealing schedule.
280        Default: 0.04 (1/25)
281        """
282        assert init_factor > 0.0, "init_factor must be positive"
283        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
284        """@keyword[fennol_md] annealing/final_factor
285        Final temperature factor for annealing schedule.
286        Default: 0.0001 (1/10000)
287        """
288        assert final_factor > 0.0, "final_factor must be positive"
289        nsteps = simulation_parameters.get("nsteps")
290        """@keyword[fennol_md] nsteps
291        Total number of simulation steps for annealing schedule calculation.
292        Required parameter
293        """
294        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
295        """@keyword[fennol_md] annealing/anneal_steps
296        Fraction of total steps for annealing process.
297        Default: 1.0
298        """
299        assert (
300            anneal_steps < 1.0 and anneal_steps > 0.0
301        ), "warmup_steps must be between 0 and nsteps"
302        pct_start = anneal_parameters.get("warmup_steps", 0.3)
303        """@keyword[fennol_md] annealing/warmup_steps
304        Fraction of annealing steps for warmup phase.
305        Default: 0.3
306        """
307        assert (
308            pct_start < 1.0 and pct_start > 0.0
309        ), "warmup_steps must be between 0 and nsteps"
310
311        anneal_type = anneal_parameters.get("type", "cosine").lower()
312        """@keyword[fennol_md] annealing/type
313        Type of annealing schedule (linear, cosine_onecycle).
314        Default: cosine
315        """
316        if anneal_type == "linear":
317            schedule = optax.linear_onecycle_schedule(
318                peak_value=1.0,
319                div_factor=1.0 / init_factor,
320                final_div_factor=1.0 / final_factor,
321                transition_steps=int(anneal_steps * nsteps),
322                pct_start=pct_start,
323                pct_final=1.0,
324            )
325        elif anneal_type == "cosine_onecycle":
326            schedule = optax.cosine_onecycle_schedule(
327                peak_value=1.0,
328                div_factor=1.0 / init_factor,
329                final_div_factor=1.0 / final_factor,
330                transition_steps=int(anneal_steps * nsteps),
331                pct_start=pct_start,
332            )
333        else:
334            raise ValueError(f"Unknown anneal_type {anneal_type}")
335
336        state["rng_key"] = rng_key
337        state["istep_anneal"] = 0
338
339        rng_key, v_key = jax.random.split(rng_key)
340        Tscale = schedule(0)
341        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
342        vel = (
343            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
344            * (kT * Tscale / mass[:, None]) ** 0.5
345        )
346
347        def thermostat(vel, state):
348            rng_key, noise_key = jax.random.split(state["rng_key"])
349            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
350
351            Tscale = schedule(state["istep_anneal"]) ** 0.5
352            vel = a1 * vel + a2 * Tscale * noise
353            return vel, {
354                **state,
355                "rng_key": rng_key,
356                "istep_anneal": state["istep_anneal"] + 1,
357            }
358
359    else:
360        raise ValueError(f"Unknown thermostat {thermostat_name}")
361    
362    # remove center of mass velocity
363    massprop = system_data["mass_Da"]/system_data["totmass_Da"]
364    if nbeads is None:
365        vel = vel.reshape(-1,system_data["nat"], 3)
366        vcom = jnp.sum(vel * massprop[None,:, None], axis=1, keepdims=True)
367        vel = (vel - vcom).reshape(-1,3)
368    else:
369        vcom = jnp.sum(vel[0]*massprop[:,None], axis=0, keepdims=True)
370        vel = vel.at[0].set(vel[0] - vcom)
371
372    ekin_instant = str(simulation_parameters.get("ekin_instant", default_ekin_instant)).upper()
373    """@keyword[fennol_md] ekin_instant
374    Where in the time step to compute the instantaneous kinetic energy. 
375    Options: 'B' (end of step), 'O' (after thermostat).
376    Default: 'O' for LGV and QTB thermostats, 'B' otherwise.
377    """
378    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"
379
380    return thermostat, postprocess, state, vel,thermostat_name, ekin_instant
381
382
383def initialize_qtb(
384    simulation_parameters,
385    system_data,
386    fprec,
387    dt,
388    mass,
389    gamma,
390    kT,
391    species,
392    rng_key,
393    adaptive,
394    compute_thermostat_energy=False,
395):
396    state = {}
397    post_state = {}
398    qtb_parameters = simulation_parameters.get("qtb", {})
399    verbose = qtb_parameters.get("verbose", False)
400    """@keyword[fennol_md] qtb/verbose
401    Print verbose QTB thermostat information.
402    Default: False
403    """
404    if compute_thermostat_energy:
405        state["thermostat_energy"] = 0.0
406
407    mass = jnp.asarray(mass, dtype=fprec)
408
409    nat = species.shape[0]
410    # define type indices
411    species_set = set(species)
412    nspecies = len(species_set)
413    idx = {sp: i for i, sp in enumerate(species_set)}
414    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
415    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]
416
417    adapt_groups = qtb_parameters.get("adapt_groups", {})
418    """@keyword[fennol_md] qtb/adapt_groups
419    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
420    Default: {}
421    """
422    ntypes = nspecies
423    allgroups = set()
424    for groupname, interval in adapt_groups.items():
425        indices = read_tinker_interval(interval)
426        assert allgroups.isdisjoint(indices), f"Indices in group {groupname} overlap with other groups"
427        allgroups.update(indices)
428        species_in_group = set(species[indices])
429        idx = {sp: i for i, sp in enumerate(species_in_group)}
430        for i in indices:
431            type_idx[i] = ntypes + idx[species[i]]
432        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
433        ntypes = len(type_labels)
434
435    n_of_type = np.zeros(ntypes, dtype=np.int32)
436    for i in range(ntypes):
437        n_of_type[i] = (type_idx == i).nonzero()[0].shape[0]
438        print(f"# QTB: n({type_labels[i]}) =", n_of_type[i])
439    n_of_type = jnp.asarray(n_of_type, dtype=fprec)
440    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type
441
442    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
443    """@keyword[fennol_md] qtb/niter_deconv_kin
444    Number of iterations for kinetic energy deconvolution.
445    Default: 7
446    """
447    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
448    """@keyword[fennol_md] qtb/niter_deconv_pot
449    Number of iterations for potential energy deconvolution.
450    Default: 20
451    """
452    corr_kin = qtb_parameters.get("corr_kin", -1)
453    """@keyword[fennol_md] qtb/corr_kin
454    Kinetic energy correction factor for QTB (-1 for automatic).
455    Default: -1
456    """
457    do_corr_kin = corr_kin <= 0
458    if do_corr_kin:
459        corr_kin = 1.0
460    state["corr_kin"] = corr_kin
461    post_state["corr_kin_prev"] = corr_kin
462    post_state["do_corr_kin"] = do_corr_kin
463    post_state["isame_kin"] = 0
464
465    # spectra parameters
466    omegasmear = np.pi / dt / 100.0
467    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
468    """@keyword[fennol_md] qtb/tseg
469    Time segment length for QTB spectrum calculation.
470    Default: 1.0 ps
471    """
472    nseg = int(Tseg / dt)
473    Tseg = nseg * dt
474    dom = 2 * np.pi / (3 * Tseg)
475    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
476    """@keyword[fennol_md] qtb/omegacut
477    Cutoff frequency for QTB spectrum.
478    Default: 15000.0 cm⁻¹
479    """
480    nom = int(omegacut / dom)
481    omega = dom * np.arange((3 * nseg) // 2 + 1)
482    cutoff = jnp.asarray(
483        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
484    )
485    assert (
486        omegacut < omega[-1]
487    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
488
489    # initialize gammar
490    assert (
491        gamma < 0.5 * omegacut
492    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
493    gammar_min = qtb_parameters.get("gammar_min", 0.1)
494    """@keyword[fennol_md] qtb/gammar_min
495    Minimum value for QTB gamma ratio coefficients.
496    Default: 0.1
497    """
498    # post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)
499    gammar = np.ones((ntypes, nom), dtype=float)
500    try:
501        for i, sp in enumerate(type_labels):
502            if not os.path.exists(f"QTB_spectra_{sp}.out"): continue
503            data = np.loadtxt(f"QTB_spectra_{sp}.out")
504            gammar[i] = data[:, 4]/(gamma*us.THZ)
505            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
506    except Exception as e:
507        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
508        gammar[:,:] = 1.0
509    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)
510
511    # Ornstein-Uhlenbeck correction for colored noise
512    a1 = np.exp(-gamma * dt)
513    OUcorr = jnp.asarray(
514        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
515        dtype=fprec,
516    )
517
518    # hbar schedule
519    classical_kernel = qtb_parameters.get("classical_kernel", False)
520    """@keyword[fennol_md] qtb/classical_kernel
521    Use classical instead of quantum kernel for QTB.
522    Default: False
523    """
524    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
525    """@keyword[fennol_md] qtb/hbar
526    Reduced Planck constant scaling factor for quantum effects.
527    Default: 1.0 a.u.
528    """
529    u = 0.5 * hbar * np.abs(omega) / kT
530    theta = kT * np.ones_like(omega)
531    if hbar > 0:
532        theta[1:] *= u[1:] / np.tanh(u[1:])
533    theta = jnp.asarray(theta, dtype=fprec)
534
535    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
536    del rng_key
537    post_state["white_noise"] = jax.random.normal(
538        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
539    )
540
541    startsave = qtb_parameters.get("startsave", 1)
542    """@keyword[fennol_md] qtb/startsave
543    Start saving QTB statistics after this many segments.
544    Default: 1
545    """
546    counter = Counter(nseg, startsave=startsave)
547    state["istep"] = 0
548    post_state["nadapt"] = 0
549    post_state["nsample"] = 0
550
551    write_spectra = qtb_parameters.get("write_spectra", True)
552    """@keyword[fennol_md] qtb/write_spectra
553    Write QTB spectral analysis output files.
554    Default: True
555    """
556    do_compute_spectra = write_spectra or adaptive
557
558    if do_compute_spectra:
559        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
560
561        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
562        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
563        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
564        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
565        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
566        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
567        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
568        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
569
570    if not adaptive:
571        update_gammar = lambda x: x
572    else:
573        # adaptation parameters
574        skipseg = qtb_parameters.get("skipseg", 1)
575        """@keyword[fennol_md] qtb/skipseg
576        Number of segments to skip before starting adaptive QTB.
577        Default: 1
578        """
579
580        adaptation_method = (
581            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
582        )
583        """@keyword[fennol_md] qtb/adaptation_method
584        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
585        Default: SIMPLE
586        """
587        if adaptation_method == "SIMPLE":
588            agamma = qtb_parameters.get("agamma", 0.1)
589            """@keyword[fennol_md] qtb/agamma
590            Learning rate for adaptive QTB gamma update.
591            Default: 0.1
592            """
593            assert agamma > 0, "agamma must be positive"
594            a1_ad = agamma * Tseg  #  * gamma
595            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")
596
597            def update_gammar(post_state):
598                g = post_state["dFDT"]
599                gammar = post_state["gammar"] - a1_ad * g
600                gammar = jnp.maximum(gammar_min, gammar)
601                return {**post_state, "gammar": gammar}
602
603        elif adaptation_method == "RATIO":
604            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS) 
605            """@keyword[fennol_md] qtb/tau_ad
606            Adaptation time constant for momentum averaging.
607            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
608            """
609            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
610            """@keyword[fennol_md] qtb/tau_s
611            Second moment time constant for variance averaging.
612            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
613            """
614            assert tau_ad > 0, "tau_ad must be positive"
615            print(
616                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
617            )
618            b1 = np.exp(-Tseg / tau_ad)
619            b2 = np.exp(-Tseg / tau_s)
620            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
621            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
622            post_state["n_adabelief"] = 0
623
624            def update_gammar(post_state):
625                n_adabelief = post_state["n_adabelief"] + 1
626                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
627                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
628                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
629                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
630                # g = Cvf/(mCvv+1.e-8)-post_state["gammar"]
631                gammar = Cvf / (mCvv + 1.0e-8)
632                gammar = jnp.maximum(gammar_min, gammar)
633                return {
634                    **post_state,
635                    "gammar": gammar,
636                    "mCvv_m": mCvv_m,
637                    "Cvf_m": Cvf_m,
638                    "n_adabelief": n_adabelief,
639                }
640
641        elif adaptation_method == "ADABELIEF":
642            agamma = qtb_parameters.get("agamma", 0.1)
643            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
644            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
645            assert tau_ad > 0, "tau_ad must be positive"
646            assert tau_s > 0, "tau_s must be positive"
647            assert agamma > 0, "agamma must be positive"
648            print(
649                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
650            )
651
652            a1_ad = agamma * gamma  # * Tseg #* gamma
653            b1 = np.exp(-Tseg / tau_ad)
654            b2 = np.exp(-Tseg / tau_s)
655            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
656            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
657            post_state["n_adabelief"] = 0
658
659            def update_gammar(post_state):
660                n_adabelief = post_state["n_adabelief"] + 1
661                dFDT = post_state["dFDT"]
662                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
663                dFDT_s = (
664                    post_state["dFDT_s"] * b2
665                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
666                    + 1.0e-8
667                )
668                # bias correction
669                mt = dFDT_m / (1.0 - b1**n_adabelief)
670                st = dFDT_s / (1.0 - b2**n_adabelief)
671                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
672                gammar = jnp.maximum(gammar_min, gammar)
673                return {
674                    **post_state,
675                    "gammar": gammar,
676                    "dFDT_m": dFDT_m,
677                    "n_adabelief": n_adabelief,
678                    "dFDT_s": dFDT_s,
679                }
680        else: 
681            raise ValueError(
682                f"Unknown adaptation method {adaptation_method}."
683            )
684    
685    #####################
686    # RESTART
687    restart_file = system_data["name"]+".qtb.restart"
688    if os.path.exists(restart_file):
689        with open(restart_file, "rb") as f:
690            data = pickle.load(f)
691            state["corr_kin"] = data["corr_kin"]
692            post_state["corr_kin_prev"] = data["corr_kin"]
693            post_state["isame_kin"] = data["isame_kin"]
694            post_state["do_corr_kin"] = data["do_corr_kin"]
695            print(f"# Restored QTB state from {restart_file}")
696
697    def write_qtb_restart(state, post_state):
698        with open(restart_file, "wb") as f:
699            pickle.dump(
700                {
701                    "corr_kin": state["corr_kin"],
702                    "corr_kin_prev": post_state["corr_kin_prev"],
703                    "isame_kin": post_state["isame_kin"],
704                    "do_corr_kin": post_state["do_corr_kin"],
705                },
706                f,
707            )
708    ######################
709
710    def compute_corr_pot(niter=20, verbose=False):
711        if classical_kernel or hbar == 0:
712            return np.ones(nom)
713
714        s_0 = np.array((theta / kT * cutoff)[:nom])
715        s_out, s_rec, _ = deconvolute_spectrum(
716            s_0,
717            omega[:nom],
718            gamma,
719            niter,
720            kernel=kernel_lorentz_pot,
721            trans=True,
722            symmetrize=True,
723            verbose=verbose,
724        )
725        corr_pot = 1.0 + (s_out - s_0) / s_0
726        columns = np.column_stack(
727            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
728        )
729        np.savetxt(
730            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
731        )
732        return corr_pot
733
734    def compute_corr_kin(post_state, niter=7, verbose=False):
735        if not post_state["do_corr_kin"]:
736            return post_state["corr_kin_prev"], post_state
737        if classical_kernel or hbar == 0:
738            return 1.0, post_state
739
740        K_D = post_state.get("K_D", None)
741        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
742        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
743        s_out, s_rec, K_D = deconvolute_spectrum(
744            s_0,
745            omega[:nom],
746            gamma,
747            niter,
748            kernel=kernel_lorentz,
749            trans=False,
750            symmetrize=True,
751            verbose=verbose,
752            K_D=K_D,
753        )
754        s_out = s_out * theta[:nom] / kT
755        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
756        mCvvsum = mCvv.sum()
757        rec_ratio = mCvvsum / s_rec.sum()
758        if rec_ratio < 0.95 or rec_ratio > 1.05:
759            print(
760                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
761            )
762            return post_state["corr_kin_prev"], post_state
763
764        corr_kin = mCvvsum / s_out.sum()
765        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
766            isame_kin = post_state["isame_kin"] + 1
767        else:
768            isame_kin = 0
769
770        # print("# corr_kin: ", corr_kin)
771        do_corr_kin = post_state["do_corr_kin"]
772        if isame_kin > 10:
773            print(
774                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
775            )
776            do_corr_kin = False
777
778        return corr_kin, {
779            **post_state,
780            "corr_kin_prev": corr_kin,
781            "isame_kin": isame_kin,
782            "do_corr_kin": do_corr_kin,
783            "K_D": K_D,
784        }
785
786    @jax.jit
787    def ff_kernel(post_state):
788        if classical_kernel:
789            kernel = cutoff * (2 * gamma * kT / dt)
790        else:
791            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
792        gamma_ratio = jnp.concatenate(
793            (
794                post_state["gammar"].T * post_state["corr_pot"][:, None],
795                jnp.ones(
796                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
797                ),
798            ),
799            axis=0,
800        )
801        return kernel[:, None] * gamma_ratio * mass_idx[None, :]
802
803    @jax.jit
804    def refresh_force(post_state):
805        rng_key, noise_key = jax.random.split(post_state["rng_key"])
806        white_noise = jnp.concatenate(
807            (
808                post_state["white_noise"][nseg:],
809                jax.random.normal(
810                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
811                ),
812            ),
813            axis=0,
814        )
815        amplitude = ff_kernel(post_state) ** 0.5
816        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
817        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
818        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}
819
820    @jax.jit
821    def compute_spectra(force, vel, post_state):
822        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
823        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
824        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
825        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
826        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T
827
828        mCvv = (
829            (dt / 3.0)
830            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
831            * mass_idx[:, None]
832            / n_of_type[:, None]
833        )
834        Cvf = (
835            (dt / 3.0)
836            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
837            / n_of_type[:, None]
838        )
839        Cff = (
840            (dt / 3.0)
841            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
842            / n_of_type[:, None]
843        )
844        dFDT = mCvv * post_state["gammar"] - Cvf
845
846        nsinv = 1.0 / post_state["nsample"]
847        b1 = 1.0 - nsinv
848        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
849        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
850        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
851        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv
852
853        return {
854            **post_state,
855            "mCvv": mCvv,
856            "Cvf": Cvf,
857            "Cff": Cff,
858            "dFDT": dFDT,
859            "dFDT_avg": dFDT_avg,
860            "mCvv_avg": mCvv_avg,
861            "Cvfg_avg": Cvfg_avg,
862            "Cff_avg": Cff_avg,
863        }
864
865    def write_spectra_to_file(post_state):
866        mCvv_avg = np.array(post_state["mCvv_avg"])
867        Cvfg_avg = np.array(post_state["Cvfg_avg"])
868        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
869        dFDT_avg = np.array(post_state["dFDT_avg"])
870        gammar = np.array(post_state["gammar"])
871        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
872        for i, sp in enumerate(type_labels):
873            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
874            columns = np.column_stack(
875                (
876                    omega[:nom] * us.CM1,
877                    mCvv_avg[i],
878                    Cvfg_avg[i],
879                    dFDT_avg[i],
880                    gammar[i] * gamma * us.THZ,
881                    Cff_avg[i] * ff_scale,
882                    Cff_theo[i] * ff_scale,
883                )
884            )
885            np.savetxt(
886                f"QTB_spectra_{sp}.out",
887                columns,
888                fmt="%12.6f",
889                header="#omega mCvv Cvf dFDT gammar Cff",
890            )
891        if verbose:
892            print("# QTB spectra written.")
893
894    if compute_thermostat_energy:
895        state["qtb_energy_flux"] = 0.0
896
897    @jax.jit
898    def thermostat(vel, state):
899        istep = state["istep"]
900        dvel = dt * state["force"][istep] / mass[:, None]
901        new_vel = vel * a1 + dvel
902        new_state = {**state, "istep": istep + 1}
903        if do_compute_spectra:
904            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
905            new_state["vel"] = vel2
906        if compute_thermostat_energy:
907            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
908            ekcorr = (
909                0.5
910                * (mass[:, None] * new_vel**2).sum()
911                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
912            )
913            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
914            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
915        return new_vel, new_state
916
917    @jax.jit
918    def postprocess_work(state, post_state):
919        if do_compute_spectra:
920            post_state = compute_spectra(state["force"], state["vel"], post_state)
921        if adaptive:
922            post_state = jax.lax.cond(
923                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
924            )
925        new_force, post_state = refresh_force(post_state)
926        return {**state, "force": new_force}, post_state
927
928    def postprocess(state, post_state):
929        counter.increment()
930        if not counter.is_reset_step:
931            return state, post_state
932        post_state["nadapt"] += 1
933        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
934        if verbose:
935            print("# Refreshing QTB forces.")
936        state, post_state = postprocess_work(state, post_state)
937        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
938        state["istep"] = 0
939        if write_spectra:
940            write_spectra_to_file(post_state)
941        write_qtb_restart(state, post_state)
942        return state, post_state
943
944    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)
945
946    state["force"], post_state = refresh_force(post_state)
947    return thermostat, (postprocess, post_state), state
def get_thermostat( simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 21def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 22    state = {}
 23    postprocess = None
 24    
 25    default_ekin_instant = "B"
 26
 27    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 28    """@keyword[fennol_md] thermostat
 29    Thermostat type. Options: 'NVE', 'LGV', 'NOSE', 'ADQTB'.
 30    Default: "LGV"
 31    """
 32    compute_thermostat_energy = simulation_parameters.get(
 33        "include_thermostat_energy", False
 34    )
 35    """@keyword[fennol_md] include_thermostat_energy
 36    Include thermostat energy in total energy calculations.
 37    Default: False
 38    """
 39
 40    kT = system_data.get("kT", None)
 41    nbeads = system_data.get("nbeads", None)
 42    mass = system_data["mass"]
 43    gamma0 = simulation_parameters.get("gamma", 1.0 / us.THZ)
 44    """@keyword[fennol_md] gamma
 45    Friction coefficient for Langevin thermostat.
 46    Default: 1.0 ps^-1
 47    """
 48    gamma = gamma0
 49    if gamma <= 0.0:
 50        gamma = 0.0
 51    species = system_data["species"]
 52
 53    if nbeads is not None:
 54        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 55        """@keyword[fennol_md] trpmd_lambda
 56        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
 57        Default: 1.0
 58        """
 59        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 60
 61    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 62        default_ekin_instant = "O"
 63        if gamma0 <= 1.e-5:
 64            default_ekin_instant = "B"
 65        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 66        assert kT is not None, "kT must be specified for QTB thermostat"
 67        assert gamma is not None, "gamma must be specified for QTB thermostat"
 68        rng_key, v_key = jax.random.split(rng_key)
 69        if nbeads is None:
 70            a1 = math.exp(-gamma * dt)
 71            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 72            vel = (
 73                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 74                * (kT / mass[:, None]) ** 0.5
 75            )
 76        else:
 77            if isinstance(gamma, float):
 78                gamma = np.array([gamma] * nbeads)
 79            assert isinstance(
 80                gamma, np.ndarray
 81            ), "gamma must be a float or a numpy array"
 82            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 83            a1 = np.exp(-gamma * dt)[:, None, None]
 84            a2 = jnp.asarray(
 85                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 86            )
 87            vel = (
 88                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 89                * (kT / mass[:, None]) ** 0.5
 90            )
 91
 92        state["rng_key"] = rng_key
 93        if compute_thermostat_energy:
 94            state["thermostat_energy"] = 0.0
 95        if thermostat_name == "FFLGV":
 96            def thermostat(vel, state):
 97                rng_key, noise_key = jax.random.split(state["rng_key"])
 98                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 99                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
100                dirvel = vel / norm_vel
101                if compute_thermostat_energy:
102                    v2 = (vel**2).sum(axis=-1)
103                vel = a1 * vel + a2 * noise
104                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
105                vel = dirvel * new_norm_vel
106                new_state = {**state, "rng_key": rng_key}
107                if compute_thermostat_energy:
108                    v2new = (vel**2).sum(axis=-1)
109                    new_state["thermostat_energy"] = (
110                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
111                    )
112
113                return vel, new_state
114
115        else:
116            def thermostat(vel, state):
117                rng_key, noise_key = jax.random.split(state["rng_key"])
118                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
119                if compute_thermostat_energy:
120                    v2 = (vel**2).sum(axis=-1)
121                vel = a1 * vel + a2 * noise
122                new_state = {**state, "rng_key": rng_key}
123                if compute_thermostat_energy:
124                    v2new = (vel**2).sum(axis=-1)
125                    new_state["thermostat_energy"] = (
126                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
127                    )
128                return vel, new_state
129
130    elif thermostat_name in ["BUSSI"]:
131        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
132        assert kT is not None, "kT must be specified for QTB thermostat"
133        assert gamma is not None, "gamma must be specified for QTB thermostat"
134        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
135        rng_key, v_key = jax.random.split(rng_key)
136
137        a1 = math.exp(-gamma * dt)
138        a2 = (1 - a1) * kT
139        vel = (
140            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
141            * (kT / mass[:, None]) ** 0.5
142        )
143
144        state["rng_key"] = rng_key
145        if compute_thermostat_energy:
146            state["thermostat_energy"] = 0.0
147
148        def thermostat(vel, state):
149            rng_key, noise_key = jax.random.split(state["rng_key"])
150            new_state = {**state, "rng_key": rng_key}
151            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
152            R2 = jnp.sum(noise**2)
153            R1 = noise[0, 0]
154            c = a2 / (mass[:, None] * vel**2).sum()
155            d = (a1 * c) ** 0.5
156            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
157            if compute_thermostat_energy:
158                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
159                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
160            return scale * vel, new_state
161
162    elif thermostat_name in [
163        "GD",
164        "DESCENT",
165        "GRADIENT_DESCENT",
166        "MIN",
167        "MINIMIZE",
168    ]:
169        assert nbeads is None, "Gradient descent is not compatible with PIMD"
170        a1 = math.exp(-gamma * dt)
171
172        if nbeads is None:
173            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
174        else:
175            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
176
177        def thermostat(vel, state):
178            return a1 * vel, state
179
180    elif thermostat_name in ["NVE", "NONE"]:
181        if kT is None:
182            kT = 0.
183        
184        if nbeads is None:
185            vel = (
186                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
187                * (kT / mass[:, None]) ** 0.5
188            )
189            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
190            vel = vel * (kT / kTsys) ** 0.5
191        else:
192            vel = (
193                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
194                * (kT / mass[None, :, None]) ** 0.5
195            )
196            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
197                mass.shape[0] * 3
198            )
199            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
200        thermostat = lambda x, s: (x, s)
201
202    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
203        assert gamma is not None, "gamma must be specified for QTB thermostat"
204        ndof = mass.shape[0] * 3
205        nkT = ndof * kT
206        nose_mass = nkT / gamma**2
207        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
208        state["nose_s"] = 0.0
209        state["nose_v"] = 0.0
210        if compute_thermostat_energy:
211            state["thermostat_energy"] = 0.0
212        print(
213            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
214        )
215        vel = (
216            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
217            * (kT / mass[:, None]) ** 0.5
218        )
219
220        def thermostat(vel, state):
221            nose_s = state["nose_s"]
222            nose_v = state["nose_v"]
223            kTsys = jnp.sum(mass[:, None] * vel**2)
224            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
225            nose_s = nose_s + dt * nose_v
226            vel = jnp.exp(-nose_v * dt) * vel
227            kTsys = jnp.sum(mass[:, None] * vel**2)
228            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
229            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
230
231            if compute_thermostat_energy:
232                new_state["thermostat_energy"] = (
233                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
234                )
235            return vel, new_state
236
237    elif thermostat_name in ["QTB", "ADQTB"]:
238        default_ekin_instant = "O"
239        assert nbeads is None, "QTB is not compatible with PIMD"
240        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
241        assert kT is not None, "kT must be specified for QTB thermostat"
242        assert gamma is not None, "gamma must be specified for QTB thermostat"
243        assert species is not None, "species must be provided for QTB thermostat"
244        rng_key, v_key = jax.random.split(rng_key)
245        vel = (
246            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
247            * (kT / mass[:, None]) ** 0.5
248        )
249
250        thermostat, postprocess, qtb_state = initialize_qtb(
251            simulation_parameters,
252            system_data,
253            fprec=fprec,
254            dt=dt,
255            mass=mass,
256            gamma=gamma,
257            kT=kT,
258            species=species,
259            rng_key=rng_key,
260            adaptive=thermostat_name.startswith("AD"),
261            compute_thermostat_energy=compute_thermostat_energy,
262        )
263        state = {**state, **qtb_state}
264
265    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
266        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
267        assert kT is not None, "kT must be specified for QTB thermostat"
268        assert gamma is not None, "gamma must be specified for QTB thermostat"
269        assert nbeads is None, "ANNEAL is not compatible with PIMD"
270        a1 = math.exp(-gamma * dt)
271        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
272
273        anneal_parameters = simulation_parameters.get("annealing", {})
274        """@keyword[fennol_md] annealing
275        Parameters for simulated annealing schedule configuration.
276        Required for ANNEAL/ANNEALING thermostat
277        """
278        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
279        """@keyword[fennol_md] annealing/init_factor
280        Initial temperature factor for annealing schedule.
281        Default: 0.04 (1/25)
282        """
283        assert init_factor > 0.0, "init_factor must be positive"
284        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
285        """@keyword[fennol_md] annealing/final_factor
286        Final temperature factor for annealing schedule.
287        Default: 0.0001 (1/10000)
288        """
289        assert final_factor > 0.0, "final_factor must be positive"
290        nsteps = simulation_parameters.get("nsteps")
291        """@keyword[fennol_md] nsteps
292        Total number of simulation steps for annealing schedule calculation.
293        Required parameter
294        """
295        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
296        """@keyword[fennol_md] annealing/anneal_steps
297        Fraction of total steps for annealing process.
298        Default: 1.0
299        """
300        assert (
301            anneal_steps < 1.0 and anneal_steps > 0.0
302        ), "warmup_steps must be between 0 and nsteps"
303        pct_start = anneal_parameters.get("warmup_steps", 0.3)
304        """@keyword[fennol_md] annealing/warmup_steps
305        Fraction of annealing steps for warmup phase.
306        Default: 0.3
307        """
308        assert (
309            pct_start < 1.0 and pct_start > 0.0
310        ), "warmup_steps must be between 0 and nsteps"
311
312        anneal_type = anneal_parameters.get("type", "cosine").lower()
313        """@keyword[fennol_md] annealing/type
314        Type of annealing schedule (linear, cosine_onecycle).
315        Default: cosine
316        """
317        if anneal_type == "linear":
318            schedule = optax.linear_onecycle_schedule(
319                peak_value=1.0,
320                div_factor=1.0 / init_factor,
321                final_div_factor=1.0 / final_factor,
322                transition_steps=int(anneal_steps * nsteps),
323                pct_start=pct_start,
324                pct_final=1.0,
325            )
326        elif anneal_type == "cosine_onecycle":
327            schedule = optax.cosine_onecycle_schedule(
328                peak_value=1.0,
329                div_factor=1.0 / init_factor,
330                final_div_factor=1.0 / final_factor,
331                transition_steps=int(anneal_steps * nsteps),
332                pct_start=pct_start,
333            )
334        else:
335            raise ValueError(f"Unknown anneal_type {anneal_type}")
336
337        state["rng_key"] = rng_key
338        state["istep_anneal"] = 0
339
340        rng_key, v_key = jax.random.split(rng_key)
341        Tscale = schedule(0)
342        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
343        vel = (
344            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
345            * (kT * Tscale / mass[:, None]) ** 0.5
346        )
347
348        def thermostat(vel, state):
349            rng_key, noise_key = jax.random.split(state["rng_key"])
350            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
351
352            Tscale = schedule(state["istep_anneal"]) ** 0.5
353            vel = a1 * vel + a2 * Tscale * noise
354            return vel, {
355                **state,
356                "rng_key": rng_key,
357                "istep_anneal": state["istep_anneal"] + 1,
358            }
359
360    else:
361        raise ValueError(f"Unknown thermostat {thermostat_name}")
362    
363    # remove center of mass velocity
364    massprop = system_data["mass_Da"]/system_data["totmass_Da"]
365    if nbeads is None:
366        vel = vel.reshape(-1,system_data["nat"], 3)
367        vcom = jnp.sum(vel * massprop[None,:, None], axis=1, keepdims=True)
368        vel = (vel - vcom).reshape(-1,3)
369    else:
370        vcom = jnp.sum(vel[0]*massprop[:,None], axis=0, keepdims=True)
371        vel = vel.at[0].set(vel[0] - vcom)
372
373    ekin_instant = str(simulation_parameters.get("ekin_instant", default_ekin_instant)).upper()
374    """@keyword[fennol_md] ekin_instant
375    Where in the time step to compute the instantaneous kinetic energy. 
376    Options: 'B' (end of step), 'O' (after thermostat).
377    Default: 'O' for LGV and QTB thermostats, 'B' otherwise.
378    """
379    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"
380
381    return thermostat, postprocess, state, vel,thermostat_name, ekin_instant
def initialize_qtb( simulation_parameters, system_data, fprec, dt, mass, gamma, kT, species, rng_key, adaptive, compute_thermostat_energy=False):
384def initialize_qtb(
385    simulation_parameters,
386    system_data,
387    fprec,
388    dt,
389    mass,
390    gamma,
391    kT,
392    species,
393    rng_key,
394    adaptive,
395    compute_thermostat_energy=False,
396):
397    state = {}
398    post_state = {}
399    qtb_parameters = simulation_parameters.get("qtb", {})
400    verbose = qtb_parameters.get("verbose", False)
401    """@keyword[fennol_md] qtb/verbose
402    Print verbose QTB thermostat information.
403    Default: False
404    """
405    if compute_thermostat_energy:
406        state["thermostat_energy"] = 0.0
407
408    mass = jnp.asarray(mass, dtype=fprec)
409
410    nat = species.shape[0]
411    # define type indices
412    species_set = set(species)
413    nspecies = len(species_set)
414    idx = {sp: i for i, sp in enumerate(species_set)}
415    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
416    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]
417
418    adapt_groups = qtb_parameters.get("adapt_groups", {})
419    """@keyword[fennol_md] qtb/adapt_groups
420    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
421    Default: {}
422    """
423    ntypes = nspecies
424    allgroups = set()
425    for groupname, interval in adapt_groups.items():
426        indices = read_tinker_interval(interval)
427        assert allgroups.isdisjoint(indices), f"Indices in group {groupname} overlap with other groups"
428        allgroups.update(indices)
429        species_in_group = set(species[indices])
430        idx = {sp: i for i, sp in enumerate(species_in_group)}
431        for i in indices:
432            type_idx[i] = ntypes + idx[species[i]]
433        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
434        ntypes = len(type_labels)
435
436    n_of_type = np.zeros(ntypes, dtype=np.int32)
437    for i in range(ntypes):
438        n_of_type[i] = (type_idx == i).nonzero()[0].shape[0]
439        print(f"# QTB: n({type_labels[i]}) =", n_of_type[i])
440    n_of_type = jnp.asarray(n_of_type, dtype=fprec)
441    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type
442
443    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
444    """@keyword[fennol_md] qtb/niter_deconv_kin
445    Number of iterations for kinetic energy deconvolution.
446    Default: 7
447    """
448    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
449    """@keyword[fennol_md] qtb/niter_deconv_pot
450    Number of iterations for potential energy deconvolution.
451    Default: 20
452    """
453    corr_kin = qtb_parameters.get("corr_kin", -1)
454    """@keyword[fennol_md] qtb/corr_kin
455    Kinetic energy correction factor for QTB (-1 for automatic).
456    Default: -1
457    """
458    do_corr_kin = corr_kin <= 0
459    if do_corr_kin:
460        corr_kin = 1.0
461    state["corr_kin"] = corr_kin
462    post_state["corr_kin_prev"] = corr_kin
463    post_state["do_corr_kin"] = do_corr_kin
464    post_state["isame_kin"] = 0
465
466    # spectra parameters
467    omegasmear = np.pi / dt / 100.0
468    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
469    """@keyword[fennol_md] qtb/tseg
470    Time segment length for QTB spectrum calculation.
471    Default: 1.0 ps
472    """
473    nseg = int(Tseg / dt)
474    Tseg = nseg * dt
475    dom = 2 * np.pi / (3 * Tseg)
476    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
477    """@keyword[fennol_md] qtb/omegacut
478    Cutoff frequency for QTB spectrum.
479    Default: 15000.0 cm⁻¹
480    """
481    nom = int(omegacut / dom)
482    omega = dom * np.arange((3 * nseg) // 2 + 1)
483    cutoff = jnp.asarray(
484        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
485    )
486    assert (
487        omegacut < omega[-1]
488    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
489
490    # initialize gammar
491    assert (
492        gamma < 0.5 * omegacut
493    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
494    gammar_min = qtb_parameters.get("gammar_min", 0.1)
495    """@keyword[fennol_md] qtb/gammar_min
496    Minimum value for QTB gamma ratio coefficients.
497    Default: 0.1
498    """
499    # post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)
500    gammar = np.ones((ntypes, nom), dtype=float)
501    try:
502        for i, sp in enumerate(type_labels):
503            if not os.path.exists(f"QTB_spectra_{sp}.out"): continue
504            data = np.loadtxt(f"QTB_spectra_{sp}.out")
505            gammar[i] = data[:, 4]/(gamma*us.THZ)
506            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
507    except Exception as e:
508        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
509        gammar[:,:] = 1.0
510    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)
511
512    # Ornstein-Uhlenbeck correction for colored noise
513    a1 = np.exp(-gamma * dt)
514    OUcorr = jnp.asarray(
515        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
516        dtype=fprec,
517    )
518
519    # hbar schedule
520    classical_kernel = qtb_parameters.get("classical_kernel", False)
521    """@keyword[fennol_md] qtb/classical_kernel
522    Use classical instead of quantum kernel for QTB.
523    Default: False
524    """
525    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
526    """@keyword[fennol_md] qtb/hbar
527    Reduced Planck constant scaling factor for quantum effects.
528    Default: 1.0 a.u.
529    """
530    u = 0.5 * hbar * np.abs(omega) / kT
531    theta = kT * np.ones_like(omega)
532    if hbar > 0:
533        theta[1:] *= u[1:] / np.tanh(u[1:])
534    theta = jnp.asarray(theta, dtype=fprec)
535
536    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
537    del rng_key
538    post_state["white_noise"] = jax.random.normal(
539        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
540    )
541
542    startsave = qtb_parameters.get("startsave", 1)
543    """@keyword[fennol_md] qtb/startsave
544    Start saving QTB statistics after this many segments.
545    Default: 1
546    """
547    counter = Counter(nseg, startsave=startsave)
548    state["istep"] = 0
549    post_state["nadapt"] = 0
550    post_state["nsample"] = 0
551
552    write_spectra = qtb_parameters.get("write_spectra", True)
553    """@keyword[fennol_md] qtb/write_spectra
554    Write QTB spectral analysis output files.
555    Default: True
556    """
557    do_compute_spectra = write_spectra or adaptive
558
559    if do_compute_spectra:
560        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
561
562        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
563        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
564        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
565        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
566        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
567        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
568        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
569        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
570
571    if not adaptive:
572        update_gammar = lambda x: x
573    else:
574        # adaptation parameters
575        skipseg = qtb_parameters.get("skipseg", 1)
576        """@keyword[fennol_md] qtb/skipseg
577        Number of segments to skip before starting adaptive QTB.
578        Default: 1
579        """
580
581        adaptation_method = (
582            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
583        )
584        """@keyword[fennol_md] qtb/adaptation_method
585        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
586        Default: SIMPLE
587        """
588        if adaptation_method == "SIMPLE":
589            agamma = qtb_parameters.get("agamma", 0.1)
590            """@keyword[fennol_md] qtb/agamma
591            Learning rate for adaptive QTB gamma update.
592            Default: 0.1
593            """
594            assert agamma > 0, "agamma must be positive"
595            a1_ad = agamma * Tseg  #  * gamma
596            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")
597
598            def update_gammar(post_state):
599                g = post_state["dFDT"]
600                gammar = post_state["gammar"] - a1_ad * g
601                gammar = jnp.maximum(gammar_min, gammar)
602                return {**post_state, "gammar": gammar}
603
604        elif adaptation_method == "RATIO":
605            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS) 
606            """@keyword[fennol_md] qtb/tau_ad
607            Adaptation time constant for momentum averaging.
608            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
609            """
610            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
611            """@keyword[fennol_md] qtb/tau_s
612            Second moment time constant for variance averaging.
613            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
614            """
615            assert tau_ad > 0, "tau_ad must be positive"
616            print(
617                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
618            )
619            b1 = np.exp(-Tseg / tau_ad)
620            b2 = np.exp(-Tseg / tau_s)
621            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
622            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
623            post_state["n_adabelief"] = 0
624
625            def update_gammar(post_state):
626                n_adabelief = post_state["n_adabelief"] + 1
627                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
628                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
629                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
630                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
631                # g = Cvf/(mCvv+1.e-8)-post_state["gammar"]
632                gammar = Cvf / (mCvv + 1.0e-8)
633                gammar = jnp.maximum(gammar_min, gammar)
634                return {
635                    **post_state,
636                    "gammar": gammar,
637                    "mCvv_m": mCvv_m,
638                    "Cvf_m": Cvf_m,
639                    "n_adabelief": n_adabelief,
640                }
641
642        elif adaptation_method == "ADABELIEF":
643            agamma = qtb_parameters.get("agamma", 0.1)
644            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
645            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
646            assert tau_ad > 0, "tau_ad must be positive"
647            assert tau_s > 0, "tau_s must be positive"
648            assert agamma > 0, "agamma must be positive"
649            print(
650                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
651            )
652
653            a1_ad = agamma * gamma  # * Tseg #* gamma
654            b1 = np.exp(-Tseg / tau_ad)
655            b2 = np.exp(-Tseg / tau_s)
656            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
657            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
658            post_state["n_adabelief"] = 0
659
660            def update_gammar(post_state):
661                n_adabelief = post_state["n_adabelief"] + 1
662                dFDT = post_state["dFDT"]
663                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
664                dFDT_s = (
665                    post_state["dFDT_s"] * b2
666                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
667                    + 1.0e-8
668                )
669                # bias correction
670                mt = dFDT_m / (1.0 - b1**n_adabelief)
671                st = dFDT_s / (1.0 - b2**n_adabelief)
672                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
673                gammar = jnp.maximum(gammar_min, gammar)
674                return {
675                    **post_state,
676                    "gammar": gammar,
677                    "dFDT_m": dFDT_m,
678                    "n_adabelief": n_adabelief,
679                    "dFDT_s": dFDT_s,
680                }
681        else: 
682            raise ValueError(
683                f"Unknown adaptation method {adaptation_method}."
684            )
685    
686    #####################
687    # RESTART
688    restart_file = system_data["name"]+".qtb.restart"
689    if os.path.exists(restart_file):
690        with open(restart_file, "rb") as f:
691            data = pickle.load(f)
692            state["corr_kin"] = data["corr_kin"]
693            post_state["corr_kin_prev"] = data["corr_kin"]
694            post_state["isame_kin"] = data["isame_kin"]
695            post_state["do_corr_kin"] = data["do_corr_kin"]
696            print(f"# Restored QTB state from {restart_file}")
697
698    def write_qtb_restart(state, post_state):
699        with open(restart_file, "wb") as f:
700            pickle.dump(
701                {
702                    "corr_kin": state["corr_kin"],
703                    "corr_kin_prev": post_state["corr_kin_prev"],
704                    "isame_kin": post_state["isame_kin"],
705                    "do_corr_kin": post_state["do_corr_kin"],
706                },
707                f,
708            )
709    ######################
710
711    def compute_corr_pot(niter=20, verbose=False):
712        if classical_kernel or hbar == 0:
713            return np.ones(nom)
714
715        s_0 = np.array((theta / kT * cutoff)[:nom])
716        s_out, s_rec, _ = deconvolute_spectrum(
717            s_0,
718            omega[:nom],
719            gamma,
720            niter,
721            kernel=kernel_lorentz_pot,
722            trans=True,
723            symmetrize=True,
724            verbose=verbose,
725        )
726        corr_pot = 1.0 + (s_out - s_0) / s_0
727        columns = np.column_stack(
728            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
729        )
730        np.savetxt(
731            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
732        )
733        return corr_pot
734
735    def compute_corr_kin(post_state, niter=7, verbose=False):
736        if not post_state["do_corr_kin"]:
737            return post_state["corr_kin_prev"], post_state
738        if classical_kernel or hbar == 0:
739            return 1.0, post_state
740
741        K_D = post_state.get("K_D", None)
742        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
743        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
744        s_out, s_rec, K_D = deconvolute_spectrum(
745            s_0,
746            omega[:nom],
747            gamma,
748            niter,
749            kernel=kernel_lorentz,
750            trans=False,
751            symmetrize=True,
752            verbose=verbose,
753            K_D=K_D,
754        )
755        s_out = s_out * theta[:nom] / kT
756        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
757        mCvvsum = mCvv.sum()
758        rec_ratio = mCvvsum / s_rec.sum()
759        if rec_ratio < 0.95 or rec_ratio > 1.05:
760            print(
761                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
762            )
763            return post_state["corr_kin_prev"], post_state
764
765        corr_kin = mCvvsum / s_out.sum()
766        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
767            isame_kin = post_state["isame_kin"] + 1
768        else:
769            isame_kin = 0
770
771        # print("# corr_kin: ", corr_kin)
772        do_corr_kin = post_state["do_corr_kin"]
773        if isame_kin > 10:
774            print(
775                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
776            )
777            do_corr_kin = False
778
779        return corr_kin, {
780            **post_state,
781            "corr_kin_prev": corr_kin,
782            "isame_kin": isame_kin,
783            "do_corr_kin": do_corr_kin,
784            "K_D": K_D,
785        }
786
787    @jax.jit
788    def ff_kernel(post_state):
789        if classical_kernel:
790            kernel = cutoff * (2 * gamma * kT / dt)
791        else:
792            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
793        gamma_ratio = jnp.concatenate(
794            (
795                post_state["gammar"].T * post_state["corr_pot"][:, None],
796                jnp.ones(
797                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
798                ),
799            ),
800            axis=0,
801        )
802        return kernel[:, None] * gamma_ratio * mass_idx[None, :]
803
804    @jax.jit
805    def refresh_force(post_state):
806        rng_key, noise_key = jax.random.split(post_state["rng_key"])
807        white_noise = jnp.concatenate(
808            (
809                post_state["white_noise"][nseg:],
810                jax.random.normal(
811                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
812                ),
813            ),
814            axis=0,
815        )
816        amplitude = ff_kernel(post_state) ** 0.5
817        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
818        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
819        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}
820
821    @jax.jit
822    def compute_spectra(force, vel, post_state):
823        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
824        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
825        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
826        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
827        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T
828
829        mCvv = (
830            (dt / 3.0)
831            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
832            * mass_idx[:, None]
833            / n_of_type[:, None]
834        )
835        Cvf = (
836            (dt / 3.0)
837            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
838            / n_of_type[:, None]
839        )
840        Cff = (
841            (dt / 3.0)
842            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
843            / n_of_type[:, None]
844        )
845        dFDT = mCvv * post_state["gammar"] - Cvf
846
847        nsinv = 1.0 / post_state["nsample"]
848        b1 = 1.0 - nsinv
849        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
850        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
851        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
852        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv
853
854        return {
855            **post_state,
856            "mCvv": mCvv,
857            "Cvf": Cvf,
858            "Cff": Cff,
859            "dFDT": dFDT,
860            "dFDT_avg": dFDT_avg,
861            "mCvv_avg": mCvv_avg,
862            "Cvfg_avg": Cvfg_avg,
863            "Cff_avg": Cff_avg,
864        }
865
866    def write_spectra_to_file(post_state):
867        mCvv_avg = np.array(post_state["mCvv_avg"])
868        Cvfg_avg = np.array(post_state["Cvfg_avg"])
869        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
870        dFDT_avg = np.array(post_state["dFDT_avg"])
871        gammar = np.array(post_state["gammar"])
872        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
873        for i, sp in enumerate(type_labels):
874            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
875            columns = np.column_stack(
876                (
877                    omega[:nom] * us.CM1,
878                    mCvv_avg[i],
879                    Cvfg_avg[i],
880                    dFDT_avg[i],
881                    gammar[i] * gamma * us.THZ,
882                    Cff_avg[i] * ff_scale,
883                    Cff_theo[i] * ff_scale,
884                )
885            )
886            np.savetxt(
887                f"QTB_spectra_{sp}.out",
888                columns,
889                fmt="%12.6f",
890                header="#omega mCvv Cvf dFDT gammar Cff",
891            )
892        if verbose:
893            print("# QTB spectra written.")
894
895    if compute_thermostat_energy:
896        state["qtb_energy_flux"] = 0.0
897
898    @jax.jit
899    def thermostat(vel, state):
900        istep = state["istep"]
901        dvel = dt * state["force"][istep] / mass[:, None]
902        new_vel = vel * a1 + dvel
903        new_state = {**state, "istep": istep + 1}
904        if do_compute_spectra:
905            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
906            new_state["vel"] = vel2
907        if compute_thermostat_energy:
908            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
909            ekcorr = (
910                0.5
911                * (mass[:, None] * new_vel**2).sum()
912                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
913            )
914            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
915            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
916        return new_vel, new_state
917
918    @jax.jit
919    def postprocess_work(state, post_state):
920        if do_compute_spectra:
921            post_state = compute_spectra(state["force"], state["vel"], post_state)
922        if adaptive:
923            post_state = jax.lax.cond(
924                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
925            )
926        new_force, post_state = refresh_force(post_state)
927        return {**state, "force": new_force}, post_state
928
929    def postprocess(state, post_state):
930        counter.increment()
931        if not counter.is_reset_step:
932            return state, post_state
933        post_state["nadapt"] += 1
934        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
935        if verbose:
936            print("# Refreshing QTB forces.")
937        state, post_state = postprocess_work(state, post_state)
938        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
939        state["istep"] = 0
940        if write_spectra:
941            write_spectra_to_file(post_state)
942        write_qtb_restart(state, post_state)
943        return state, post_state
944
945    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)
946
947    state["force"], post_state = refresh_force(post_state)
948    return thermostat, (postprocess, post_state), state