fennol.md.integrate
import time
import math
import os

import numpy as np
import jax
import jax.numpy as jnp

from .thermostats import get_thermostat
from .barostats import get_barostat
from .colvars import setup_colvars
from .spectra import initialize_ir_spectrum

from .utils import load_dynamics_restart, get_restart_file, optimize_fire2, us
from .initial import load_model, load_system_data, initialize_preprocessing


def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Set up a (path-integral) molecular dynamics run and return its step function.

    The function loads the model and the system, restarts from a previous run
    if a restart file exists, optionally minimizes the initial geometry, builds
    the thermostat (and barostat when the system is periodic), optional
    collective variables and IR-spectrum machinery, defines the jitted
    integration / observable functions, and computes initial energy and forces.

    Args:
        simulation_parameters: parameter container queried with ``.get(key, default)``,
            ``in`` and ``[key]`` (keys such as ``"dt"``, ``"xyz_input/minimize"``).
        fprec: floating-point dtype used for dynamical arrays.
        rng_key: JAX PRNG key; it is split for the thermostat and the barostat.

    Returns:
        Tuple ``(step, update_conformation, system_data, dyn_state, conformation, system)``:
        ``step(istep, dyn_state, system, conformation, force_preprocess=False)``
        advances the dynamics by one time step and returns
        ``(dyn_state, system, conformation, out)``; ``update_conformation`` copies
        the dynamical coordinates/cell into a model input dict; the remaining
        items are the initial states.

    Raises:
        ValueError: if ``dt`` is missing, or if ``use_average_Pkin`` is combined
            with a QTB thermostat.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = us.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    system_data["model_energy_unit"] = model_energy_unit
    system_data["model_energy_unit_str"] = model.energy_unit

    ### FINISH BUILDING conformation
    do_restart = os.path.exists(get_restart_file(system_data))
    if do_restart:
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    minimize = simulation_parameters.get("xyz_input/minimize", False)
    """@keyword[fennol_md] xyz_input/minimize
    Perform energy minimization before dynamics.
    Default: False
    """
    if minimize and not do_restart:
        # read nreplicas with the same default as the rest of this function
        assert system_data.get("nreplicas", 1) == 1, "Minimization is only supported for single replica systems"
        model.preproc_state = preproc_state
        # the optimizer works in kcal/mol (energy per atom) and kcal/mol/A (forces)
        convert = us.KCALPERMOL / model_energy_unit
        nat = system_data["nat"]

        def energy_force_fn(coordinates):
            """Return (energy per atom, forces) in kcal/mol units for `coordinates`."""
            inputs = {**conformation, "coordinates": coordinates}
            e, f, _ = model.energy_and_forces(**inputs, gpu_preprocessing=True)
            e = float(e[0]) * convert / nat
            f = np.array(f) * convert
            return e, f

        tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL
        """@keyword[fennol_md] xyz_input/minimize_ftol
        Force tolerance for minimization.
        Default: 0.1 kcal/mol/Å
        """
        print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A")
        conformation["coordinates"], success = optimize_fire2(
            conformation["coordinates"],
            energy_force_fn,
            atol=tol,
            max_disp=0.02,
        )
        if success:
            print("# Minimization successful")
        else:
            print("# Warning: Minimization failed, continuing with last configuration")
        # write the minimized coordinates as an xyz file
        from ..utils.io import write_xyz_frame
        with open(system_data["name"]+".opt.xyz", "w") as xyz_file:
            write_xyz_frame(
                xyz_file,
                system_data["symbols"],
                np.array(conformation["coordinates"]),
                cell=conformation.get("cells", None),
            )
        print("# Minimized configuration written to", system_data["name"]+".opt.xyz")
        # energy_and_forces may have updated the model's preprocessing state:
        # re-run preprocessing on the minimized geometry with that state
        preproc_state = model.preproc_state
        conformation = model.preprocessing.process(preproc_state, conformation)
        system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy()

    ### get dynamics parameters
    dt = simulation_parameters.get("dt")
    """@keyword[fennol_md] dt
    Integration time step. Required parameter.
    Type: float, Required
    """
    if dt is None:
        raise ValueError("the 'dt' parameter (integration time step) is required")
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    # converts total mass / volume into a density (units from the `us` system)
    densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3)
    nat = system_data["nat"]
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
    # equipartition value of the kinetic-energy tensor: nat * kT / 2 per direction
    ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        # PIMD: arrays carry a leading bead / normal-mode axis
        nreplicas = nbeads
        dtm = dtm[None, :, :]

    fix_com = simulation_parameters.get("fix_com", False)
    if fix_com:
        massprop = system_data["mass_Da"]/system_data["totmass_Da"]
        print("# Reference frame fixed to center of mass")
        x = conformation["coordinates"].reshape(-1, nat, 3)
        if nbeads is None:
            # remove the center of mass of each replica separately
            xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True)
        else:
            # remove the center of mass averaged over all beads
            xcom = jnp.sum(x*massprop[None,:,None], axis=(0,1), keepdims=True)/nbeads
        conformation["coordinates"] = (x - xcom).reshape(-1, 3)

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
        ekin_instant,
    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng, restart_data)
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'"

    system["thermostat"] = thermostat_state

    if "initial_velocities" in simulation_parameters:
        initial_vel_file = simulation_parameters["initial_velocities"]
        print(f"# Using initial velocities from {initial_vel_file}")
        v0 = jnp.array(np.loadtxt(initial_vel_file).reshape(-1, nat, 3).astype(fprec))
        if v0.shape[0] == 1:
            if nbeads is None:
                # a single set of velocities is copied to every replica
                v0 = v0.repeat(nreplicas, axis=0)
            else:
                # PIMD: only the first normal-mode velocity is taken from the file
                v0 = initial_vel.at[0].set(v0[0])
        assert v0.shape == (nreplicas, nat, 3), f"initial_velocities file must have shape (nat, 3) or (nreplicas, nat, 3), but got {v0.shape}"
        initial_vel = v0
        if nbeads is None:
            initial_vel = initial_vel.reshape(-1, 3)

    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key, restart_data
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0)
            """@keyword[fennol_md] pressure_o_weight
            Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator.
            Default: 1.0
            """
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # the strain gradient gives the virial needed by the pressure estimator
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            """Apply only the thermostat (no barostat without PBC)."""
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    if estimate_pressure:
        use_average_Pkin = simulation_parameters.get("use_average_Pkin", False)
        """@keyword[fennol_md] use_average_Pkin
        Use time-averaged kinetic energy for pressure estimation instead of instantaneous values.
        Default: False
        """
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        if is_qtb and use_average_Pkin:
            raise ValueError(
                "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False"
            )

    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    """@keyword[fennol_md] etot_ensemble_key
    Key for energy ensemble calculation. Enables computation of ensemble weights.
    Default: None
    """

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    """@keyword[fennol_md] colvars
    Collective variables definitions for enhanced sampling or monitoring.
    Default: None
    """
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    """@keyword[fennol_md] ir_spectrum
    Calculate infrared spectrum from molecular dipole moment time series.
    Default: False
    """
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            """Apply every registered thermo update in registration order."""
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        """@keyword[fennol_md] cay_correction
        Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation.
        Default: True
        """
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # coefficients of the free ring-polymer half-step propagator for the
        # non-centroid modes (k >= 1); the centroid is propagated freely
        omk_nc = omk[1:, None, None]
        cayfact = 1.0 / (4.0 + (dt * omk_nc) ** 2) ** 0.5
        if cay_correction:
            # square-root of the Cayley approximation of the free propagator
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk_nc ** 2)
        else:
            # exact harmonic propagation over dt/2
            axx = jnp.asarray(np.cos(omk_nc * dt2))
            axv = jnp.asarray(np.sin(omk_nc * dt2) / omk_nc)
            avx = jnp.asarray(-omk_nc * np.sin(omk_nc * dt2))

        # dynamical coordinates are normal modes; start with all beads
        # collapsed on the first bead's geometry (only mode 0 is non-zero)
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Copy the current dynamical coordinates (and cell) into `conformation`.

        For PIMD the normal-mode coordinates are transformed back to bead
        Cartesian coordinates, flattened to shape (nbeads * nat, 3).
        """
        x = system["coordinates"]
        if nbeads is not None:
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            cell = system["cell"]
            conformation["cells"] = cell[None, :, :].repeat(nreplicas, axis=0)
            # keep the reciprocal cell consistent with the updated cell (as
            # update_conformation_ir does); otherwise the reciprocal cell of
            # the initial box would be passed on unchanged
            conformation["reciprocal_cells"] = jnp.linalg.inv(cell)[
                None, :, :
            ].repeat(nreplicas, axis=0)
        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Free propagation (A step) over dt/2.

        Classical: x += dt/2 * v. PIMD: free centroid motion plus the
        ring-polymer harmonic propagator on the non-centroid modes.
        """
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Advance positions/velocities by one step: kick, A/2, thermo update, A/2."""
        x = system["coordinates"]
        # full-dt kick (dtm = dt / mass); presumably merges the closing half-kick
        # of the previous step with the opening one of this step, consistent
        # with the 0.5 * dtm correction applied for 'B' kinetic energies
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        if fix_com:
            if nbeads is None:
                x = x.reshape(-1, nat, 3)
                xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True)
                x = (x - xcom).reshape(-1, 3)
            else:
                # only the centroid mode carries the center of mass
                xcom = jnp.sum(x[0]*massprop[:,None], axis=0, keepdims=True)
                x = x.at[0].set(x[0] - xcom)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    def kinetic_energy_tensor(v, system, forces):
        """Return ``(ek, ek_c)``: the 3x3 kinetic-energy tensor estimator.

        Classical: ``ek = 0.5 / nreplicas / corr_kin * sum_i m_i v_i v_i^T``
        and ``ek_c`` is None. PIMD: ``ek_c`` is the kinetic tensor of mode 0 and
        ``ek = ek_c - 0.5 * sum_{k>=1} q_k f_k^T`` (centroid-virial form).
        """
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
            return ek, None
        ek_c = 0.5 * jnp.sum(
            mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
        )
        ek = ek_c - 0.5 * jnp.sum(
            system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
            axis=(0, 1),
        )
        return ek, ek_c

    @jax.jit
    def update_observables(system, conformation):
        """Compute energy, forces and derived observables for `conformation`.

        Returns the updated `system` dict and the raw model output `out`
        (`out` keeps forces/virial in model energy units).
        """
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        out["forces"] = -de["coordinates"]
        epot = epot / model_energy_unit
        de = {k: v / model_energy_unit for k, v in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        v = system["vel"]
        if ekin_instant == "B":
            v = v + 0.5 * dtm * forces
        ek, ek_c = kinetic_energy_tensor(v, system, forces)
        if ek_c is not None:
            system["ek_c"] = jnp.trace(ek_c)

        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if use_average_Pkin:
                ek = ek_avg
            elif pressure_o_weight != 1.0:
                # mix the half-kicked ('B') kinetic tensor with the one above
                v = system["vel"] + 0.5 * dtm * system["forces"]
                ek, _ = kinetic_energy_tensor(v, system, forces)
                b = pressure_o_weight
                ek = (1.0 - b) * ek + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir
            out["virial_tensor"] = vir * model_energy_unit

            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres = ek*(2./volume) - vir/volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = densmass / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            # colvars are evaluated on the first replica (centroid for PIMD)
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build a single-system input (first replica / centroid) for the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Feed charges, velocities, positions and total dipole to `save_dipole`.

            Charges/dipoles come from `model_ir` when available, otherwise from
            `system` (zeros when absent).
            """
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print verbose information about neighbor list updates and reallocations.
    Default: False
    """
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    """@keyword[fennol_md] nblist_stride
    Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0.
    Default: -1
    """
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0)
    """@keyword[fennol_md] nblist_warmup_time
    Time period for neighbor list warmup before using skin updates.
    Default: -1.0
    """
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance for efficient updates (in Angstroms).
    Default: -1.0
    """
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ## => skin of 2 A gives you 40 fs without complete rebuild
            t_ref = 40.0 / us.FS  # FS
            nblist_skin_ref = 2.0  # A
            # at least 1: a stride of 0 would also rebuild every step, but
            # would be reported misleadingly as "0 steps"
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Refresh the model inputs: full preprocessing or a cheaper skin update.

        A full rebuild happens when the countdown reaches 0, when forced, or
        during the warmup period; otherwise only the skin update is applied.
        """
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    overflow,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Advance the simulation by one time step.

        Returns ``(dyn_state, system, conformation, out)``.
        """
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # "conformation_ir" only exists when a dedicated IR model is used;
            # without it update_dipole only needs the cell data of the main
            # conformation
            ir_state = update_dipole(
                dyn_state["ir_spectrum"],
                system,
                dyn_state.get("conformation_ir", conformation),
            )
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system
def
initialize_dynamics(simulation_parameters, fprec, rng_key):
19def initialize_dynamics(simulation_parameters, fprec, rng_key): 20 ### LOAD MODEL 21 model = load_model(simulation_parameters) 22 model_energy_unit = us.get_multiplier(model.energy_unit) 23 24 ### Get the coordinates and species from the xyz file 25 system_data, conformation = load_system_data(simulation_parameters, fprec) 26 system_data["model_energy_unit"] = model_energy_unit 27 system_data["model_energy_unit_str"] = model.energy_unit 28 29 ### FINISH BUILDING conformation 30 do_restart = os.path.exists(get_restart_file(system_data)) 31 if do_restart: 32 ### RESTART FROM PREVIOUS DYNAMICS 33 restart_data = load_dynamics_restart(system_data) 34 print("# RESTARTING FROM PREVIOUS DYNAMICS") 35 model.preproc_state = restart_data["preproc_state"] 36 conformation["coordinates"] = restart_data["coordinates"] 37 else: 38 restart_data = {} 39 40 ### INITIALIZE PREPROCESSING 41 preproc_state, conformation = initialize_preprocessing( 42 simulation_parameters, model, conformation, system_data 43 ) 44 45 minimize = simulation_parameters.get("xyz_input/minimize", False) 46 """@keyword[fennol_md] xyz_input/minimize 47 Perform energy minimization before dynamics. 48 Default: False 49 """ 50 if minimize and not do_restart: 51 assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems" 52 model.preproc_state = preproc_state 53 convert = us.KCALPERMOL / model_energy_unit 54 nat = system_data["nat"] 55 def energy_force_fn(coordinates): 56 inputs = {**conformation, "coordinates": coordinates} 57 e, f, _ = model.energy_and_forces( 58 **inputs, gpu_preprocessing=True 59 ) 60 e = float(e[0]) * convert / nat 61 f = np.array(f) * convert 62 return e, f 63 tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL 64 """@keyword[fennol_md] xyz_input/minimize_ftol 65 Force tolerance for minimization. 
66 Default: 0.1 kcal/mol/Å 67 """ 68 print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A") 69 conformation["coordinates"], success = optimize_fire2( 70 conformation["coordinates"], 71 energy_force_fn, 72 atol=tol, 73 max_disp=0.02, 74 ) 75 if success: 76 print("# Minimization successful") 77 else: 78 print("# Warning: Minimization failed, continuing with last configuration") 79 # write the minimized coordinates as an xyz file 80 from ..utils.io import write_xyz_frame 81 with open(system_data["name"]+".opt.xyz", "w") as f: 82 write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None)) 83 print("# Minimized configuration written to", system_data["name"]+".opt.xyz") 84 preproc_state = model.preproc_state 85 conformation = model.preprocessing.process(preproc_state, conformation) 86 system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy() 87 88 ### get dynamics parameters 89 dt = simulation_parameters.get("dt") 90 """@keyword[fennol_md] dt 91 Integration time step. Required parameter. 
92 Type: float, Required 93 """ 94 dt2 = 0.5 * dt 95 mass = system_data["mass"] 96 densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3) 97 nat = system_data["nat"] 98 dtm = jnp.asarray(dt / mass[:, None], dtype=fprec) 99 ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3) 100 101 nreplicas = system_data.get("nreplicas", 1) 102 nbeads = system_data.get("nbeads", None) 103 if nbeads is not None: 104 nreplicas = nbeads 105 dtm = dtm[None, :, :] 106 107 fix_com = simulation_parameters.get("fix_com", False) 108 if fix_com: 109 massprop = system_data["mass_Da"]/system_data["totmass_Da"] 110 print("# Reference frame fixed to center of mass") 111 x = conformation["coordinates"].reshape(-1, nat, 3) 112 if nbeads is None: 113 xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True) 114 else: 115 xcom = jnp.sum(x*massprop[None,:,None], axis=(0,1), keepdims=True)/nbeads 116 conformation["coordinates"] = (x - xcom).reshape(-1, 3) 117 118 ### INITIALIZE DYNAMICS STATE 119 system = {"coordinates": conformation["coordinates"]} 120 dyn_state = { 121 "istep": 0, 122 "dt": dt, 123 "pimd": nbeads is not None, 124 "preproc_state": preproc_state, 125 "start_time_ps": restart_data.get("simulation_time_ps", 0.), 126 } 127 gradient_keys = ["coordinates"] 128 thermo_updates = [] 129 130 ### INITIALIZE THERMOSTAT 131 thermostat_rng, rng_key = jax.random.split(rng_key) 132 ( 133 thermostat, 134 thermostat_post, 135 thermostat_state, 136 initial_vel, 137 dyn_state["thermostat_name"], 138 ekin_instant, 139 ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data) 140 do_thermostat_post = thermostat_post is not None 141 if do_thermostat_post: 142 thermostat_post, post_state = thermostat_post 143 dyn_state["thermostat_post_state"] = post_state 144 145 assert ekin_instant in ["B", "O"], "ekin_instant must be 'B' or 'O'" 146 147 system["thermostat"] = thermostat_state 148 149 if "initial_velocities" in simulation_parameters: 150 initial_vel_file = 
simulation_parameters["initial_velocities"] 151 print(f"# Using initial velocities from {initial_vel_file}") 152 v0 = jnp.array(np.loadtxt(initial_vel_file).reshape(-1,nat, 3).astype(fprec)) 153 if v0.shape[0] == 1: 154 if nbeads is None: 155 v0 = v0.repeat(nreplicas, axis=0) 156 else: 157 v0 = initial_vel.at[0].set(v0[0]) 158 assert v0.shape == (nreplicas, nat, 3), f"initial_velocities file must have shape (nat, 3) or (nreplicas, nat, 3), but got {v0.shape}" 159 initial_vel = v0 160 if nbeads is None: 161 initial_vel = initial_vel.reshape(-1, 3) 162 163 system["vel"] = restart_data.get("vel", initial_vel).astype(fprec) 164 165 ### PBC 166 pbc_data = system_data.get("pbc", None) 167 if pbc_data is not None: 168 ### INITIALIZE BAROSTAT 169 barostat_key, rng_key = jax.random.split(rng_key) 170 thermo_update_ensemble, variable_cell, barostat_state = get_barostat( 171 thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data 172 ) 173 estimate_pressure = variable_cell or pbc_data["estimate_pressure"] 174 system["barostat"] = barostat_state 175 system["cell"] = conformation["cells"][0] 176 if estimate_pressure: 177 pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0) 178 """@keyword[fennol_md] pressure_o_weight 179 Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator. 
180 Default: 1.0 181 """ 182 assert ( 183 0.0 <= pressure_o_weight <= 1.0 184 ), "pressure_o_weight must be between 0 and 1" 185 gradient_keys.append("strain") 186 print("# Estimate pressure: ", estimate_pressure) 187 else: 188 estimate_pressure = False 189 variable_cell = False 190 191 def thermo_update_ensemble(x, v, system): 192 v, thermostat_state = thermostat(v, system["thermostat"]) 193 return x, v, {**system, "thermostat": thermostat_state} 194 195 dyn_state["estimate_pressure"] = estimate_pressure 196 dyn_state["variable_cell"] = variable_cell 197 thermo_updates.append(thermo_update_ensemble) 198 199 if estimate_pressure: 200 use_average_Pkin = simulation_parameters.get("use_average_Pkin", False) 201 """@keyword[fennol_md] use_average_Pkin 202 Use time-averaged kinetic energy for pressure estimation instead of instantaneous values. 203 Default: False 204 """ 205 is_qtb = dyn_state["thermostat_name"].endswith("QTB") 206 if is_qtb and use_average_Pkin: 207 raise ValueError( 208 "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False" 209 ) 210 211 212 ### ENERGY ENSEMBLE 213 ensemble_key = simulation_parameters.get("etot_ensemble_key", None) 214 """@keyword[fennol_md] etot_ensemble_key 215 Key for energy ensemble calculation. Enables computation of ensemble weights. 216 Default: None 217 """ 218 219 ### COLVARS 220 colvars_definitions = simulation_parameters.get("colvars", None) 221 """@keyword[fennol_md] colvars 222 Collective variables definitions for enhanced sampling or monitoring. 223 Default: None 224 """ 225 use_colvars = colvars_definitions is not None 226 if use_colvars: 227 colvars_calculators, colvars_names = setup_colvars(colvars_definitions) 228 dyn_state["colvars"] = colvars_names 229 230 ### IR SPECTRUM 231 do_ir_spectrum = simulation_parameters.get("ir_spectrum", False) 232 """@keyword[fennol_md] ir_spectrum 233 Calculate infrared spectrum from molecular dipole moment time series. 
234 Default: False 235 """ 236 assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean" 237 if do_ir_spectrum: 238 is_qtb = dyn_state["thermostat_name"].endswith("QTB") 239 model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum( 240 simulation_parameters, system_data, fprec, dt, is_qtb 241 ) 242 dyn_state["ir_spectrum"] = ir_state 243 244 ### BUILD GRADIENT FUNCTION 245 energy_and_gradient = model.get_gradient_function( 246 *gradient_keys, jit=True, variables_as_input=True 247 ) 248 249 ### COLLECT THERMO UPDATES 250 if len(thermo_updates) == 1: 251 thermo_update = thermo_updates[0] 252 else: 253 254 def thermo_update(x, v, system): 255 for update in thermo_updates: 256 x, v, system = update(x, v, system) 257 return x, v, system 258 259 ### RING POLYMER INITIALIZATION 260 if nbeads is not None: 261 cay_correction = simulation_parameters.get("cay_correction", True) 262 """@keyword[fennol_md] cay_correction 263 Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation. 
264 Default: True 265 """ 266 omk = system_data["omk"] 267 eigmat = system_data["eigmat"] 268 cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5 269 if cay_correction: 270 axx = jnp.asarray(2 * cayfact) 271 axv = jnp.asarray(dt * cayfact) 272 avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2) 273 else: 274 axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2)) 275 axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None]) 276 avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2)) 277 278 coordinates = conformation["coordinates"].reshape(nbeads, -1, 3) 279 eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0]) 280 system["coordinates"] = eigx 281 282 ############################################### 283 ### DEFINE UPDATE FUNCTION 284 @jax.jit 285 def update_conformation(conformation, system): 286 x = system["coordinates"] 287 if nbeads is not None: 288 x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * ( 289 nbeads**0.5 290 ) 291 conformation = {**conformation, "coordinates": x} 292 if variable_cell: 293 conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0) 294 295 296 297 return conformation 298 299 ############################################### 300 ### DEFINE INTEGRATION FUNCTIONS 301 def integrate_A_half(x0, v0): 302 if nbeads is None: 303 return x0 + dt2 * v0, v0 304 305 # update coordinates and velocities of a free ring polymer for a half time step 306 eigx_c = x0[0] + dt2 * v0[0] 307 eigv_c = v0[0] 308 eigx = x0[1:] * axx + v0[1:] * axv 309 eigv = x0[1:] * avx + v0[1:] * axx 310 311 return ( 312 jnp.concatenate((eigx_c[None], eigx), axis=0), 313 jnp.concatenate((eigv_c[None], eigv), axis=0), 314 ) 315 316 @jax.jit 317 def integrate(system): 318 x = system["coordinates"] 319 v = system["vel"] + dtm * system["forces"] 320 x, v = integrate_A_half(x, v) 321 x, v, system = thermo_update(x, v, system) 322 x, v = integrate_A_half(x, v) 323 324 if fix_com: 325 if 
nbeads is None: 326 x = x.reshape(-1, nat, 3) 327 xcom = jnp.sum(x*massprop[None,:,None], axis=1, keepdims=True) 328 x = (x - xcom).reshape(-1, 3) 329 else: 330 xcom = jnp.sum(x[0]*massprop[:,None], axis=0, keepdims=True) 331 x = x.at[0].set(x[0] - xcom) 332 333 return {**system, "coordinates": x, "vel": v} 334 335 ############################################### 336 ### DEFINE OBSERVABLE FUNCTION 337 @jax.jit 338 def update_observables(system, conformation): 339 ### POTENTIAL ENERGY AND FORCES 340 epot, de, out = energy_and_gradient(model.variables, conformation) 341 out["forces"] = -de["coordinates"] 342 epot = epot / model_energy_unit 343 de = {k: v / model_energy_unit for k, v in de.items()} 344 forces = -de["coordinates"] 345 346 if nbeads is not None: 347 ### PROJECT FORCES ONTO POLYMER NORMAL MODES 348 forces = jnp.einsum( 349 "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3) 350 ) * (1.0 / nbeads**0.5) 351 352 353 system = { 354 **system, 355 "epot": jnp.mean(epot), 356 "forces": forces, 357 "energy_gradients": de, 358 } 359 360 ### KINETIC ENERGY 361 v = system["vel"] 362 if ekin_instant == "B": 363 v = v + 0.5 * dtm * forces 364 365 if nbeads is None: 366 corr_kin = system["thermostat"].get("corr_kin", 1.0) 367 # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0) 368 ek = (0.5 / nreplicas / corr_kin) * jnp.sum( 369 mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0 370 ) 371 else: 372 ek_c = 0.5 * jnp.sum( 373 mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0 374 ) 375 ek = ek_c - 0.5 * jnp.sum( 376 system["coordinates"][1:, :, :, None] * forces[1:, :, None, :], 377 axis=(0, 1), 378 ) 379 system["ek_c"] = jnp.trace(ek_c) 380 381 system["ek"] = jnp.trace(ek) 382 system["ek_tensor"] = ek 383 384 if estimate_pressure: 385 if use_average_Pkin: 386 ek = ek_avg 387 elif pressure_o_weight != 1.0: 388 v = system["vel"] + 0.5 * dtm * system["forces"] 389 if nbeads is None: 390 corr_kin = 
system["thermostat"].get("corr_kin", 1.0) 391 # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0) 392 ek = (0.5 / nreplicas / corr_kin) * jnp.sum( 393 mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0 394 ) 395 else: 396 ek_c = 0.5 * jnp.sum( 397 mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], 398 axis=0, 399 ) 400 ek = ek_c - 0.5 * jnp.sum( 401 system["coordinates"][1:, :, :, None] * forces[1:, :, None, :], 402 axis=(0, 1), 403 ) 404 b = pressure_o_weight 405 ek = (1.0 - b) * ek + b * system["ek_tensor"] 406 407 vir = jnp.mean(de["strain"], axis=0) 408 system["virial"] = vir 409 out["virial_tensor"] = vir * model_energy_unit 410 411 volume = jnp.abs(jnp.linalg.det(system["cell"])) 412 Pres = ek*(2./volume) - vir/volume 413 system["pressure_tensor"] = Pres 414 system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0) 415 if variable_cell: 416 density = densmass / volume 417 system["density"] = density 418 system["volume"] = volume 419 420 if ensemble_key is not None: 421 kT = system_data["kT"] 422 dE = ( 423 jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"] 424 ) 425 system["ensemble_weights"] = -dE / kT 426 427 if "total_dipole" in out: 428 if nbeads is None: 429 system["total_dipole"] = out["total_dipole"][0] 430 else: 431 system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0) 432 433 if use_colvars: 434 coords = system["coordinates"].reshape(-1, nat, 3)[0] 435 colvars = {} 436 for colvar_name, colvar_calc in colvars_calculators.items(): 437 colvars[colvar_name] = colvar_calc(coords) 438 system["colvars"] = colvars 439 440 return system, out 441 442 ############################################### 443 ### IR SPECTRUM 444 if do_ir_spectrum: 445 # @jax.jit 446 # def update_dipole(ir_state,system,conformation): 447 # def mumodel(coords): 448 # out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords}) 449 # if nbeads is None: 450 # return out["total_dipole"][0] 451 # return 
out["total_dipole"].sum(axis=0) 452 # dmudqmodel = jax.jacobian(mumodel) 453 454 # dmudq = dmudqmodel(conformation["coordinates"]) 455 # # print(dmudq.shape) 456 # if nbeads is None: 457 # vel = system["vel"].reshape(-1,1,nat,3)[0] 458 # mudot = (vel*dmudq).sum(axis=(1,2)) 459 # else: 460 # dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1) 461 # vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) * nbeads**0.5 462 # ) 463 # # vel = system["vel"][0].reshape(1,nat,3) 464 # mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads 465 466 # ir_state = save_dipole(mudot,ir_state) 467 # return ir_state 468 @jax.jit 469 def update_conformation_ir(conformation, system): 470 conformation = { 471 **conformation, 472 "coordinates": system["coordinates"].reshape(-1, nat, 3)[0], 473 "natoms": jnp.asarray([nat]), 474 "batch_index": jnp.asarray([0] * nat), 475 "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]), 476 } 477 if variable_cell: 478 conformation["cells"] = system["cell"][None, :, :] 479 conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[ 480 None, :, : 481 ] 482 return conformation 483 484 @jax.jit 485 def update_dipole(ir_state, system, conformation): 486 if model_ir is not None: 487 out = model_ir._apply(model_ir.variables, conformation) 488 q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat)) 489 dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3)) 490 else: 491 q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat)) 492 dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3)) 493 if nbeads is not None: 494 q = jnp.mean(q, axis=0) 495 dip = jnp.mean(dip, axis=0) 496 vel = system["vel"][0] 497 pos = system["coordinates"][0] 498 else: 499 q = q[0] 500 dip = dip[0] 501 vel = system["vel"].reshape(-1, nat, 3)[0] 502 pos = system["coordinates"].reshape(-1, nat, 3)[0] 503 504 if pbc_data is not None: 505 cell_reciprocal = ( 506 conformation["cells"][0], 507 conformation["reciprocal_cells"][0], 508 
) 509 else: 510 cell_reciprocal = None 511 512 ir_state = save_dipole( 513 q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state 514 ) 515 return ir_state 516 517 ############################################### 518 ### GRAPH UPDATES 519 520 nblist_verbose = simulation_parameters.get("nblist_verbose", False) 521 """@keyword[fennol_md] nblist_verbose 522 Print verbose information about neighbor list updates and reallocations. 523 Default: False 524 """ 525 nblist_stride = int(simulation_parameters.get("nblist_stride", -1)) 526 """@keyword[fennol_md] nblist_stride 527 Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0. 528 Default: -1 529 """ 530 nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) 531 """@keyword[fennol_md] nblist_warmup_time 532 Time period for neighbor list warmup before using skin updates. 533 Default: -1.0 534 """ 535 nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0 536 nblist_skin = simulation_parameters.get("nblist_skin", -1.0) 537 """@keyword[fennol_md] nblist_skin 538 Neighbor list skin distance for efficient updates (in Angstroms). 
539 Default: -1.0 540 """ 541 if nblist_skin > 0: 542 if nblist_stride <= 0: 543 ## reference skin parameters at 300K (from Tinker-HP) 544 ## => skin of 2 A gives you 40 fs without complete rebuild 545 t_ref = 40.0 /us.FS # FS 546 nblist_skin_ref = 2.0 # A 547 nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt)) 548 print( 549 f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps" 550 ) 551 552 if nblist_skin <= 0: 553 nblist_stride = 1 554 555 dyn_state["nblist_countdown"] = 0 556 dyn_state["print_skin_activation"] = nblist_warmup > 0 557 558 def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False): 559 nblist_countdown = dyn_state["nblist_countdown"] 560 if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup): 561 ### FULL NBLIST REBUILD 562 dyn_state["nblist_countdown"] = nblist_stride - 1 563 preproc_state = dyn_state["preproc_state"] 564 conformation = model.preprocessing.process( 565 preproc_state, update_conformation(conformation, system) 566 ) 567 preproc_state, state_up, conformation, overflow = ( 568 model.preprocessing.check_reallocate(preproc_state, conformation) 569 ) 570 dyn_state["preproc_state"] = preproc_state 571 if nblist_verbose and overflow: 572 print("step", istep, ", nblist overflow => reallocating nblist") 573 print("size updates:", state_up) 574 575 if do_ir_spectrum and model_ir is not None: 576 conformation_ir = model_ir.preprocessing.process( 577 dyn_state["preproc_state_ir"], 578 update_conformation_ir(dyn_state["conformation_ir"], system), 579 ) 580 ( 581 dyn_state["preproc_state_ir"], 582 _, 583 dyn_state["conformation_ir"], 584 overflow, 585 ) = model_ir.preprocessing.check_reallocate( 586 dyn_state["preproc_state_ir"], conformation_ir 587 ) 588 589 else: 590 ### SKIN UPDATE 591 if dyn_state["print_skin_activation"]: 592 if nblist_verbose: 593 print( 594 "step", 595 istep, 596 ", end of nblist warmup phase => 
activating skin updates", 597 ) 598 dyn_state["print_skin_activation"] = False 599 600 dyn_state["nblist_countdown"] = nblist_countdown - 1 601 conformation = model.preprocessing.update_skin( 602 update_conformation(conformation, system) 603 ) 604 if do_ir_spectrum and model_ir is not None: 605 dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin( 606 update_conformation_ir(dyn_state["conformation_ir"], system) 607 ) 608 609 return conformation, dyn_state 610 611 ################################################ 612 ### DEFINE STEP FUNCTION 613 def step(istep, dyn_state, system, conformation, force_preprocess=False): 614 615 dyn_state = { 616 **dyn_state, 617 "istep": dyn_state["istep"] + 1, 618 } 619 620 ### INTEGRATE EQUATIONS OF MOTION 621 system = integrate(system) 622 623 ### UPDATE CONFORMATION AND GRAPHS 624 conformation, dyn_state = update_graphs( 625 istep, dyn_state, system, conformation, force_preprocess 626 ) 627 628 ## COMPUTE FORCES AND OBSERVABLES 629 system, out = update_observables(system, conformation) 630 631 ## END OF STEP UPDATES 632 if do_thermostat_post: 633 system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post( 634 system["thermostat"], dyn_state["thermostat_post_state"] 635 ) 636 637 if do_ir_spectrum: 638 ir_state = update_dipole( 639 dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"] 640 ) 641 dyn_state["ir_spectrum"] = ir_post(ir_state) 642 643 return dyn_state, system, conformation, out 644 645 ########################################################### 646 647 print("# Computing initial energy and forces") 648 649 conformation = update_conformation(conformation, system) 650 # initialize IR conformation 651 if do_ir_spectrum and model_ir is not None: 652 dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = ( 653 model_ir.preprocessing( 654 model_ir.preproc_state, 655 update_conformation_ir(conformation, system), 656 ) 657 ) 658 659 system, _ = update_observables(system, conformation) 
660 661 return step, update_conformation, system_data, dyn_state, conformation, system