fennol.md.thermostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
import os
import pickle

from .utils import us
from ..utils import Counter, read_tinker_interval
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)
from ..utils.periodic_table import PERIODIC_TABLE


def get_thermostat(
    simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None
):
    """Build the thermostat selected by ``simulation_parameters["thermostat"]``.

    Args:
        simulation_parameters: dict of user parameters (keys documented by the
            ``@keyword[fennol_md]`` strings below).
        dt: integration time step (internal units).
        system_data: dict describing the system. This function reads ``"kT"``,
            ``"nbeads"``, ``"mass"``, ``"species"``, ``"omk"`` (PIMD only),
            ``"mass_Da"``, ``"totmass_Da"`` and ``"nat"``; QTB also reads ``"name"``.
        fprec: floating-point dtype of the generated arrays.
        rng_key: JAX PRNG key, required by the stochastic thermostats and by the
            random velocity initialization.
        restart_data: currently unused; kept for interface compatibility.

    Returns:
        ``(thermostat, postprocess, state, vel, thermostat_name, ekin_instant)``
        where ``thermostat(vel, state) -> (vel, state)`` applies one thermostat
        step, ``postprocess`` is ``None`` except for QTB/ADQTB where it is the
        pair ``(postprocess_fn, post_state)``, ``state`` is the initial
        thermostat state, ``vel`` holds the initial velocities with the
        center-of-mass velocity removed (for PIMD only from ``vel[0]``) and
        ``ekin_instant`` is ``"B"`` or ``"O"``.

    Raises:
        ValueError: unknown thermostat or annealing schedule type.
        AssertionError: a required input is missing or invalid.
    """
    state = {}
    postprocess = None

    default_ekin_instant = "B"

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'NVE' (or 'NONE'), 'LGV' (or 'LANGEVIN'), 'FFLGV',
    'BUSSI', 'GD' (or 'MIN'), 'NOSE', 'QTB', 'ADQTB', 'ANNEAL'.
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma0 = simulation_parameters.get("gamma", 1.0 / us.THZ)
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    gamma = gamma0
    if gamma <= 0.0:
        gamma = 0.0
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # per-bead friction: at least trpmd_lambda times the normal-mode frequency
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        default_ekin_instant = "O"
        if gamma0 <= 1.0e-5:
            default_ekin_instant = "B"
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                # keep the velocity direction; avoid 0/0 = NaN for atoms at rest
                dirvel = vel / jnp.where(norm_vel > 0, norm_vel, 1.0)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )

                return vel, new_state

        else:

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            # stochastic velocity rescaling: scale^2 = a1 + c*R2 + 2*sqrt(a1*c)*R1
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                # accumulate K_old - K_new (same convention as the Langevin
                # branches) so that E_total + thermostat_energy is conserved
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (1 - scale**2)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            # damped dynamics: velocities decay by exp(-gamma*dt) each call
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        if kT is None:
            kT = 0.0

        if nbeads is None:
            vel_shape = (mass.shape[0], 3)
        else:
            vel_shape = (nbeads, mass.shape[0], 3)

        if kT <= 0.0:
            # Start from rest: rescaling random velocities to kT=0 would
            # evaluate 0/0 and produce NaN velocities.
            vel = jnp.zeros(vel_shape, dtype=fprec)
        elif nbeads is None:
            assert rng_key is not None, "rng_key must be provided to initialize velocities"
            vel = (
                jax.random.normal(rng_key, vel_shape, dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # rescale so that the initial kinetic temperature is exactly kT
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            assert rng_key is not None, "rng_key must be provided to initialize velocities"
            vel = (
                jax.random.normal(rng_key, vel_shape, dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5

        def thermostat(vel, state):
            return vel, state

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma > 0.0, "gamma must be positive for Nose-Hoover thermostat"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            # half step of the thermostat velocity, full-step velocity scaling,
            # then a second half step with the updated kinetic energy
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        default_ekin_instant = "O"
        assert nbeads is None, "QTB is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            simulation_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # 1.0 (the default) means "anneal over the whole run" and must be accepted
        assert 0.0 < anneal_steps <= 1.0, "anneal_steps must be in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be between 0 and 1"

        anneal_type = str(anneal_parameters.get("type", "cosine")).lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine or its alias cosine_onecycle).
        Default: cosine
        """
        transition_steps = int(anneal_steps * nsteps)
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ("cosine", "cosine_onecycle"):
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split before storing: otherwise the first noise key drawn by the
        # thermostat would equal the key used for the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # Langevin step at the scheduled temperature kT*schedule(istep)
            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    # remove center of mass velocity
    massprop = system_data["mass_Da"] / system_data["totmass_Da"]
    if nbeads is None:
        vel = vel.reshape(-1, system_data["nat"], 3)
        vcom = jnp.sum(vel * massprop[None, :, None], axis=1, keepdims=True)
        vel = (vel - vcom).reshape(-1, 3)
    else:
        vcom = jnp.sum(vel[0] * massprop[:, None], axis=0, keepdims=True)
        vel = vel.at[0].set(vel[0] - vcom)

    ekin_instant = str(
        simulation_parameters.get("ekin_instant", default_ekin_instant)
    ).upper()
    """@keyword[fennol_md] ekin_instant
    Where in the time step to compute the instantaneous kinetic energy.
    Options: 'B' (end of step), 'O' (after thermostat).
    Default: 'O' for LGV and QTB thermostats, 'B' otherwise.
    """
    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"

    return thermostat, postprocess, state, vel, thermostat_name, ekin_instant


def initialize_qtb(
    simulation_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) Quantum Thermal Bath thermostat.

    Args:
        simulation_parameters: user parameters; QTB options live under ``"qtb"``.
        system_data: system description; ``"name"`` is used for the restart file.
        fprec: floating-point dtype of the generated arrays.
        dt: integration time step.
        mass: per-atom masses.
        gamma: friction coefficient.
        kT: target thermal energy.
        species: per-atom atomic numbers (array).
        rng_key: JAX PRNG key for the colored noise.
        adaptive: if True, adapt the per-type friction ratios ``gammar``.
        compute_thermostat_energy: track the energy exchanged with the bath.

    Returns:
        ``(thermostat, (postprocess, post_state), state)`` where
        ``thermostat(vel, state)`` applies one QTB step using pre-generated
        forces and ``postprocess(state, post_state)`` refreshes the random
        forces (and spectra/adaptation) every ``nseg`` steps.
    """
    state = {}
    post_state = {}
    qtb_parameters = simulation_parameters.get("qtb", {})
    verbose = qtb_parameters.get("verbose", False)
    """@keyword[fennol_md] qtb/verbose
    Print verbose QTB thermostat information.
    Default: False
    """
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices (sorted so that the type ordering is deterministic)
    species_set = sorted(set(species))
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]

    adapt_groups = qtb_parameters.get("adapt_groups", {})
    """@keyword[fennol_md] qtb/adapt_groups
    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
    Default: {}
    """
    ntypes = len(type_labels)
    allgroups = set()
    for groupname, interval in adapt_groups.items():
        indices = read_tinker_interval(interval)
        assert allgroups.isdisjoint(
            indices
        ), f"Indices in group {groupname} overlap with other groups"
        allgroups.update(indices)
        species_in_group = sorted(set(species[indices]))
        group_idx = {sp: i for i, sp in enumerate(species_in_group)}
        for i in indices:
            type_idx[i] = ntypes + group_idx[species[i]]
        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
        ntypes = len(type_labels)

    type_counts = np.array(
        [np.count_nonzero(type_idx == i) for i in range(ntypes)], dtype=np.int32
    )
    for label, count in zip(type_labels, type_counts):
        print(f"# QTB: n({label}) =", count)
    n_of_type = jnp.asarray(type_counts, dtype=fprec)
    # A base species whose atoms were all moved into adapt groups has no atom
    # left: divide by at least 1 so its averages are 0 instead of 0/0 = NaN.
    n_of_type_safe = jnp.maximum(n_of_type, 1.0)
    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type_safe

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    """@keyword[fennol_md] qtb/niter_deconv_kin
    Number of iterations for kinetic energy deconvolution.
    Default: 7
    """
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    """@keyword[fennol_md] qtb/niter_deconv_pot
    Number of iterations for potential energy deconvolution.
    Default: 20
    """
    corr_kin = qtb_parameters.get("corr_kin", -1)
    """@keyword[fennol_md] qtb/corr_kin
    Kinetic energy correction factor for QTB (-1 for automatic).
    Default: -1
    """
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE: pickle.load must only be used on trusted files; this one is
        # produced by write_qtb_restart below.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
        state["corr_kin"] = data["corr_kin"]
        post_state["corr_kin_prev"] = data["corr_kin"]
        post_state["isame_kin"] = data["isame_kin"]
        post_state["do_corr_kin"] = data["do_corr_kin"]
        print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Save the kinetic-correction bookkeeping to ``restart_file``."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] qtb/tseg
    Time segment length for QTB spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] qtb/omegacut
    Cutoff frequency for QTB spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    """@keyword[fennol_md] qtb/gammar_min
    Minimum value for QTB gamma ratio coefficients.
    Default: 0.1
    """
    gammar = np.ones((ntypes, nom), dtype=float)
    try:
        # best effort: reuse gammar written by a previous run (column 4 holds gammar*gamma in THz)
        for i, sp in enumerate(type_labels):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * us.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(
            f"# Could not restore gammar for all species with exception {e}. Starting from scratch."
        )
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    """@keyword[fennol_md] qtb/classical_kernel
    Use classical instead of quantum kernel for QTB.
    Default: False
    """
    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
    """@keyword[fennol_md] qtb/hbar
    Reduced Planck constant scaling factor for quantum effects.
    Default: 1.0 a.u.
    """
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    """@keyword[fennol_md] qtb/startsave
    Start saving QTB statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    """@keyword[fennol_md] qtb/write_spectra
    Write QTB spectral analysis output files.
    Default: True
    """
    # The automatic corr_kin estimate is computed from the averaged spectra, so
    # they are also required when it is active (otherwise it would operate on
    # never-filled, all-zero spectra and yield a meaningless value).
    do_compute_spectra = bool(write_spectra or adaptive or post_state["do_corr_kin"])

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)
        """@keyword[fennol_md] qtb/skipseg
        Number of segments to skip before starting adaptive QTB.
        Default: 1
        """

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
        )
        """@keyword[fennol_md] qtb/adaptation_method
        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
        Default: SIMPLE
        """
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 0.1)
            """@keyword[fennol_md] qtb/agamma
            Learning rate for adaptive QTB gamma update.
            Default: 0.1
            """
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")

            def update_gammar(post_state):
                # gradient-like step against the fluctuation-dissipation deviation
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS)
            """@keyword[fennol_md] qtb/tau_ad
            Adaptation time constant for momentum averaging.
            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
            """
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
            """@keyword[fennol_md] qtb/tau_s
            Second moment time constant for variance averaging.
            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
            """
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            # one row per type (not per species): adapt groups add extra types
            post_state["mCvv_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                # bias-corrected moving averages
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

        else:
            raise ValueError(f"Unknown adaptation method {adaptation_method}.")

    def compute_corr_pot(niter=20, verbose=False):
        """Deconvolution correction of the potential-energy spectrum (1 if classical)."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Update the kinetic-energy correction from the averaged velocity spectrum."""
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        # written so that a NaN ratio (e.g. 0/0) is also rejected
        if not (0.95 <= rec_ratio <= 1.05):
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Power spectrum of the random force for each type, shape (n_freq, ntypes)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the colored random forces for the next segment of nseg steps."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        # sliding window of 3 segments of white noise; drop the oldest segment
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # keep only the central segment to limit wrap-around effects
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Accumulate per-type velocity/force spectra and their running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type_safe[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type_safe[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type_safe[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # cumulative averages over the nsample segments collected so far
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<type>.out`` file per (non-empty) atom type."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(type_labels):
            if type_counts[i] == 0:
                # no atom of this type (mass_idx is 0): nothing meaningful to write
                continue
            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * us.CM1,
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * us.THZ,
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """One QTB step: friction plus the pre-generated random force of this step."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Called every step; at the end of each segment refresh forces and spectra."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(
            post_state, niter=niter_deconv_kin
        )
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(
        compute_corr_pot(niter=niter_deconv_pot), dtype=fprec
    )

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def get_thermostat(
    simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None
):
    """Build the thermostat requested by ``simulation_parameters["thermostat"]``.

    Parameters
    ----------
    simulation_parameters : dict
        Simulation input. The keys read here are documented by the
        ``@keyword[fennol_md]`` strings in the body.
    dt : float
        Integration time step.
    system_data : dict
        System description. Always read: ``"mass"``, ``"species"``, ``"nat"``,
        ``"mass_Da"``, ``"totmass_Da"``. Optional: ``"kT"``, ``"nbeads"``
        (PIMD) and ``"omk"`` (required when ``"nbeads"`` is set).
    fprec : dtype
        Floating-point dtype of the velocities and noise amplitudes.
    rng_key : jax PRNG key, optional
        Needed by every stochastic thermostat, and by NVE/NONE when ``kT > 0``
        (to draw the initial velocities).
    restart_data : dict, optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(thermostat, postprocess, state, vel, thermostat_name, ekin_instant)``:
        ``thermostat(vel, state) -> (vel, state)`` is the per-step update;
        ``postprocess`` is ``None`` except for QTB/ADQTB, where it is
        ``(postprocess_fn, post_state)``; ``state`` is the initial thermostat
        state; ``vel`` are the initial velocities with the centre-of-mass
        velocity removed; ``thermostat_name`` is upper-cased; ``ekin_instant``
        is ``"B"`` or ``"O"``.

    Raises
    ------
    ValueError
        Unknown thermostat name or annealing schedule type.
    AssertionError
        Missing or inconsistent inputs.
    """
    state = {}
    postprocess = None

    default_ekin_instant = "B"

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'NVE', 'LGV', 'FFLGV', 'BUSSI', 'GD', 'NOSE', 'QTB', 'ADQTB', 'ANNEAL'.
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma0 = simulation_parameters.get("gamma", 1.0 / us.THZ)
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    # A non-positive friction is treated as "no friction".
    gamma = gamma0
    if gamma <= 0.0:
        gamma = 0.0
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # One friction per bead/mode: at least trpmd_lambda * omk
        # (omk presumably holds the ring-polymer normal-mode frequencies).
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        default_ekin_instant = "O"
        if gamma0 <= 1.0e-5:
            default_ekin_instant = "B"
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(gamma, np.ndarray), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                # Keep each atom's velocity direction: the Ornstein-Uhlenbeck
                # update only changes the velocity magnitude.
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                # Avoid 0/0 for atoms at rest (their velocity then stays zero).
                dirvel = vel / jnp.where(norm_vel > 0, norm_vel, 1.0)
                vel_ou = a1 * vel + a2 * noise
                new_vel = dirvel * jnp.linalg.norm(vel_ou, axis=-1, keepdims=True)
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                    v2new = (new_vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return new_vel, new_state

        else:

            def thermostat(vel, state):
                # Exact Ornstein-Uhlenbeck step: v <- a1 v + a2 xi.
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                new_vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                    v2new = (new_vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return new_vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            # Stochastic velocity rescaling: a single global factor whose
            # square follows the canonical kinetic-energy update.
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            # Deterministic damping only (no noise).
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        if kT is None:
            kT = 0.0

        if nbeads is None:
            vel_shape = (mass.shape[0], 3)
        else:
            vel_shape = (nbeads, mass.shape[0], 3)

        if kT > 0:
            assert rng_key is not None, "rng_key must be provided to initialize velocities"
            if nbeads is None:
                vel = (
                    jax.random.normal(rng_key, vel_shape, dtype=fprec)
                    * (kT / mass[:, None]) ** 0.5
                )
                # Rescale so the initial kinetic temperature is exactly kT.
                kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
                vel = vel * (kT / kTsys) ** 0.5
            else:
                vel = (
                    jax.random.normal(rng_key, vel_shape, dtype=fprec)
                    * (kT / mass[None, :, None]) ** 0.5
                )
                kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                    mass.shape[0] * 3
                )
                vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        else:
            # Zero temperature: start at rest (rescaling would compute 0/0 = NaN).
            vel = jnp.zeros(vel_shape, dtype=fprec)

        def thermostat(vel, state):
            return vel, state

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma is not None and gamma > 0, (
            "gamma must be positive for Nose-Hoover thermostat"
        )
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        # Thermostat "mass" chosen so that its characteristic frequency is gamma.
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # Symmetric split: half-kick the friction variable, scale the
            # velocities, then half-kick again with the updated kinetic energy.
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        default_ekin_instant = "O"
        assert nbeads is None, "QTB is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            simulation_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # The default (1.0 = anneal over the whole run) must be accepted.
        assert 0.0 < anneal_steps <= 1.0, "anneal_steps must be in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be between 0 and 1"

        anneal_type = anneal_parameters.get("type", "cosine").lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine or cosine_onecycle).
        Default: cosine
        """
        transition_steps = int(anneal_steps * nsteps)
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ("cosine", "cosine_onecycle"):
            # "cosine" is the documented default and an alias of cosine_onecycle.
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split BEFORE storing the key so the thermostat noise never reuses
        # the key that drew the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # Langevin step whose target temperature follows the schedule.
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    # remove center of mass velocity
    massprop = system_data["mass_Da"] / system_data["totmass_Da"]
    if nbeads is None:
        # Reshape to (n_copies, nat, 3) so each copy of the system is
        # corrected independently.
        vel = vel.reshape(-1, system_data["nat"], 3)
        vcom = jnp.sum(vel * massprop[None, :, None], axis=1, keepdims=True)
        vel = (vel - vcom).reshape(-1, 3)
    else:
        # Only vel[0] is corrected (presumably the centroid mode - confirm).
        vcom = jnp.sum(vel[0] * massprop[:, None], axis=0, keepdims=True)
        vel = vel.at[0].set(vel[0] - vcom)

    ekin_instant = str(
        simulation_parameters.get("ekin_instant", default_ekin_instant)
    ).upper()
    """@keyword[fennol_md] ekin_instant
    Where in the time step to compute the instantaneous kinetic energy.
    Options: 'B' (end of step), 'O' (after thermostat).
    Default: 'O' for LGV and QTB thermostats, 'B' otherwise.
    """
    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"

    return thermostat, postprocess, state, vel, thermostat_name, ekin_instant
def initialize_qtb(
    simulation_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the Quantum Thermal Bath (QTB) thermostat, optionally adaptive (adQTB).

    A colored random force is generated one segment (``nseg`` steps) at a
    time by filtering white noise in Fourier space with ``ff_kernel``. At the
    end of each segment ``postprocess`` accumulates velocity/force spectra,
    optionally adapts the per-type friction ratios ``gammar`` and updates the
    kinetic-energy correction factor ``corr_kin``.

    Parameters
    ----------
    simulation_parameters : dict
        Simulation input; QTB options are read from its ``"qtb"`` sub-dict.
    system_data : dict
        Only ``"name"`` is read (prefix of the ``.qtb.restart`` file).
    fprec : dtype
        Floating-point dtype for the thermostat arrays.
    dt : float
        Integration time step.
    mass : array
        Per-atom masses (indexed as ``mass[:, None]`` against ``(nat, 3)`` arrays).
    gamma : float
        Friction coefficient.
    kT : float
        Target thermal energy.
    species : np.ndarray
        Per-atom species, used as keys of ``PERIODIC_TABLE``.
    rng_key : jax PRNG key
        Seeds the white-noise generator.
    adaptive : bool
        Enable adQTB adaptation of ``gammar``.
    compute_thermostat_energy : bool
        Track the energy exchanged with the bath in ``state["thermostat_energy"]``.

    Returns
    -------
    tuple
        ``(thermostat, (postprocess, post_state), state)``. ``thermostat(vel,
        state) -> (vel, state)`` applies one step. ``postprocess(state,
        post_state) -> (state, post_state)`` increments a ``Counter(nseg)``
        and only does work on its reset steps (presumably called once per MD
        step).
    """
    state = {}
    post_state = {}
    qtb_parameters = simulation_parameters.get("qtb", {})
    verbose = qtb_parameters.get("verbose", False)
    """@keyword[fennol_md] qtb/verbose
    Print verbose QTB thermostat information.
    Default: False
    """
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # One type per chemical species (sorted for a reproducible ordering).
    species_list = sorted(set(species))
    species_to_type = {sp: i for i, sp in enumerate(species_list)}
    type_idx = np.array([species_to_type[sp] for sp in species], dtype=np.int32)
    type_labels = [PERIODIC_TABLE[sp] for sp in species_list]

    adapt_groups = qtb_parameters.get("adapt_groups", {})
    """@keyword[fennol_md] qtb/adapt_groups
    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
    Default: {}
    """
    ntypes = len(type_labels)
    allgroups = set()
    for groupname, interval in adapt_groups.items():
        indices = read_tinker_interval(interval)
        assert allgroups.isdisjoint(indices), f"Indices in group {groupname} overlap with other groups"
        allgroups.update(indices)
        # Each species inside the group gets its own new type index.
        species_in_group = sorted(set(species[indices]))
        group_to_type = {sp: ntypes + i for i, sp in enumerate(species_in_group)}
        for i in indices:
            type_idx[i] = group_to_type[species[i]]
        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
        ntypes = len(type_labels)

    counts = np.bincount(type_idx, minlength=ntypes)
    for label, count in zip(type_labels, counts):
        print(f"# QTB: n({label}) =", count)
    n_of_type = jnp.asarray(counts, dtype=fprec)
    # A species whose atoms were all moved to groups has zero atoms: divide
    # by max(n, 1) so its averages are 0 instead of NaN (0/0).
    n_of_type_safe = jnp.maximum(n_of_type, 1.0)
    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type_safe

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    """@keyword[fennol_md] qtb/niter_deconv_kin
    Number of iterations for kinetic energy deconvolution.
    Default: 7
    """
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    """@keyword[fennol_md] qtb/niter_deconv_pot
    Number of iterations for potential energy deconvolution.
    Default: 20
    """
    corr_kin = qtb_parameters.get("corr_kin", -1)
    """@keyword[fennol_md] qtb/corr_kin
    Kinetic energy correction factor for QTB (-1 for automatic).
    Default: -1
    """
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE: pickle executes arbitrary code; only load trusted restart files.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
        state["corr_kin"] = data["corr_kin"]
        post_state["corr_kin_prev"] = data["corr_kin"]
        post_state["isame_kin"] = data["isame_kin"]
        post_state["do_corr_kin"] = data["do_corr_kin"]
        print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Persist the corr_kin bookkeeping so a restarted run can resume it."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )
    ######################

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] qtb/tseg
    Time segment length for QTB spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # Noise is generated over 3 segments (previous, current, next) to avoid
    # discontinuities at segment boundaries; hence the 3*Tseg resolution.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] qtb/omegacut
    Cutoff frequency for QTB spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    """@keyword[fennol_md] qtb/gammar_min
    Minimum value for QTB gamma ratio coefficients.
    Default: 0.1
    """
    gammar = np.ones((ntypes, nom), dtype=float)
    # Best effort: resume gammar from previously written spectra files.
    try:
        for i, sp in enumerate(type_labels):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * us.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    """@keyword[fennol_md] qtb/classical_kernel
    Use classical instead of quantum kernel for QTB.
    Default: False
    """
    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
    """@keyword[fennol_md] qtb/hbar
    Reduced Planck constant scaling factor for quantum effects.
    Default: 1.0 a.u.
    """
    # theta(omega) = (hbar*omega/2) * coth(hbar*omega/(2kT)); theta(0) = kT.
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    """@keyword[fennol_md] qtb/startsave
    Start saving QTB statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    """@keyword[fennol_md] qtb/write_spectra
    Write QTB spectral analysis output files.
    Default: True
    """
    # corr_kin is estimated from the averaged mCvv spectrum, so spectra are
    # also required while the quantum kinetic correction is being refined.
    needs_corr_kin_spectra = bool(post_state["do_corr_kin"]) and not (
        classical_kernel or hbar == 0
    )
    do_compute_spectra = write_spectra or adaptive or needs_corr_kin_spectra

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)
        """@keyword[fennol_md] qtb/skipseg
        Number of segments to skip before starting adaptive QTB.
        Default: 1
        """

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
        )
        """@keyword[fennol_md] qtb/adaptation_method
        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
        Default: SIMPLE
        """
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 0.1)
            """@keyword[fennol_md] qtb/agamma
            Learning rate for adaptive QTB gamma update.
            Default: 0.1
            """
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")

            def update_gammar(post_state):
                # Gradient step that drives the FDT deviation dFDT to zero.
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS)
            """@keyword[fennol_md] qtb/tau_ad
            Adaptation time constant for momentum averaging.
            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
            """
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
            """@keyword[fennol_md] qtb/tau_s
            Second moment time constant for variance averaging.
            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
            """
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            # Shape must match mCvv/Cvf, i.e. include group subtypes (ntypes).
            post_state["mCvv_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <mCvv> with bias-corrected running averages.
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-style step on dFDT (first moment / belief variance).
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }
        else:
            raise ValueError(
                f"Unknown adaptation method {adaptation_method}."
            )

    def compute_corr_pot(niter=20, verbose=False):
        """Deconvolution correction for the potential-energy spectrum (1 if classical)."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate corr_kin from the averaged mCvv spectrum.

        Returns ``(corr_kin, post_state)``. The previous value is kept when the
        correction is disabled/converged or when the reconvolution check fails.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        # Atom-weighted average spectrum over all types (empty types weigh 0).
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        # Written as a range check so that a NaN ratio is rejected as well.
        if not (0.95 <= rec_ratio <= 1.05):
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Random-force power spectrum, shape (n_omega, ntypes)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # Adapted ratios below the cutoff index nom, 1 above it.
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the 3-segment noise window by one segment and filter it.

        Returns the colored force for the middle segment, shape (nseg, nat, 3).
        """
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Per-type velocity/force spectra of the last segment and their running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        # Sum per type, then average over the atoms of each type.
        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type_safe[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type_safe[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type_safe[:, None]
        )
        # Deviation from the fluctuation-dissipation theorem.
        dFDT = mCvv * post_state["gammar"] - Cvf

        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<type>.out`` file per non-empty atom type."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(type_labels):
            if counts[i] == 0:
                # No atoms of this type: nothing meaningful to write.
                continue
            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * us.CM1,
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * us.THZ,
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        # Friction + precomputed colored random force for the current step.
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        # Spectra -> (optional) gammar adaptation -> force for the next segment.
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state