fennol.md.initial
from pathlib import Path

import numpy as np
import jax.numpy as jnp

from flax.core import freeze, unfreeze

from ..utils.io import last_xyz_frame


from ..models import FENNIX

from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from .utils import us
from ..utils import detect_topology, parse_cell, cell_is_triangular, tril_cell


def load_model(simulation_parameters):
    """Load the FENNIX model named by the ``model_file`` simulation parameter.

    Args:
        simulation_parameters: Mapping-like simulation input. Reads
            ``model_file`` (required), ``graph_config`` (optional) and
            ``energy_terms`` (optional; a list, or a whitespace-separated
            string that is split into a list).

    Returns:
        The object returned by ``FENNIX.load``, with its energy terms
        overridden when ``energy_terms`` is present in the input.

    Raises:
        FileNotFoundError: If ``model_file`` is not provided or the file
            does not exist.
    """
    model_file = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    if model_file is None:
        # Without this guard, str(None) would turn into the literal path "None".
        raise FileNotFoundError(
            "model_file is a required parameter but was not provided"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model


def load_system_data(simulation_parameters, fprec):
    """Read the simulation input and build the system data and model input.

    The last frame of the xyz file is loaded, masses are set up (optional
    deuteration and hydrogen mass repartitioning), periodic boundary
    conditions and topology are configured, and the system is replicated
    for path-integral beads or independent replicas when requested.

    Args:
        simulation_parameters: Mapping-like simulation input, queried with
            ``.get(key, default)``.
        fprec: Floating-point dtype used for coordinates, masses and cell.

    Returns:
        tuple: ``(system_data, conformation)``. ``system_data`` is a dict of
        system-level quantities (name, species, masses, temperature, pbc,
        topology, ...). ``conformation`` is the dict of model inputs
        (species, coordinates, batch_index, natoms, total_charge, optional
        cells/reciprocal_cells, flags and any ``additional_keys``).

    Raises:
        FileNotFoundError: If the xyz file or the topology file is missing.
        ValueError: If ``nbeads`` is smaller than 1 or the topology file does
            not have exactly two columns.
        AssertionError: If hydrogen mass repartitioning does not conserve
            the total mass.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Now that the file is known, an unspecified name falls back to its stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        Hmask = species == 1
        added_mass = hmr * Hmask.sum()
        mass[Hmask] += hmr
        # Remove the added mass from the other atoms, proportionally to their mass.
        wmass = mass[~Hmask]
        mass[~Hmask] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # convert to internal units
    mass = mass / us.DA

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = us.K_B * temperature

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge,"e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = us.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    flags = dict.fromkeys(input_flags)

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        rotate_cell = not cell_is_triangular(cell)
        if rotate_cell:
            print("# Warning: provided cell is not lower triangular. Rotating to canonical cell orientation.")
            cell, cell_rotation = tril_cell(cell)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for cell_row in cell:
            print("# ", cell_row)
        dens = (totmass_Da/volume) * (us.MOL/us.CM**3)
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            coordinates = coordinates @ cell

        # NOTE(review): when crystal input is combined with a rotated cell,
        # `cell` above is presumably already the rotated cell, so this may
        # rotate the coordinates twice -- confirm against tril_cell.
        if rotate_cell:
            coordinates = coordinates @ cell_rotation

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species, coordinates, cell=cell)
            # Indices are written 1-based to the file.
            np.savetxt(system_name + ".topo", topology + 1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a single-bond file two-dimensional; indices in the
            # file are 1-based and converted to 0-based here.
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        if nbeads < 1:
            raise ValueError(f"nbeads must be a positive integer, got {nbeads}")
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.tile(coordinates, (nbeads, 1))
        species = np.tile(species, nbeads)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.full(nbeads, nat, dtype=np.int32)

        # Cyclic nearest-neighbour coupling matrix of the beads. The two
        # neighbour contributions are accumulated (not assigned) so that the
        # wrap-around is counted correctly when nbeads is 1 or 2.
        identity = np.eye(nbeads)
        spring = (
            2.0 * identity
            - np.roll(identity, 1, axis=0)
            - np.roll(identity, -1, axis=0)
        )
        omk, eigmat = np.linalg.eigh(spring)
        # The lowest eigenvalue is zero up to round-off; force it to exactly
        # zero so the square root below cannot produce NaN.
        omk[0] = 0.0
        omk = (nbeads * kT / us.HBAR) * omk**0.5
        # Flip the sign of every row whose first-column entry is negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # Beads take precedence over replicas.
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.tile(mass, nreplicas)
        species = np.tile(species, nreplicas)
        system_data["species"] = species
        coordinates = np.tile(coordinates, (nreplicas, 1))
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.full(nreplicas, nat, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.zeros(nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell per bead/replica, with a leading batch axis.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        ncopies = nbeads if nbeads is not None else nreplicas
        if ncopies is not None:
            cell = np.repeat(cell, ncopies, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, ncopies, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation


def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model preprocessing state and run it once on the input.

    Args:
        simulation_parameters: Mapping-like simulation input. Reads
            ``nblist_verbose``, ``nblist_skin``, ``nblist_mult_size``,
            ``nblist_add_neigh`` and ``print_model``.
        model: Model providing ``preproc_state`` and ``preprocessing``.
        conformation: Model input dict to preprocess.
        system_data: Not used by this function.

    Returns:
        tuple: ``(preproc_state, conformation)`` as returned by
        ``model.preprocessing``, with ``check_input`` set back to False in
        the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    preproc_state = unfreeze(model.preproc_state)
    updated_layers = []
    for frozen_layer in preproc_state["layers_state"]:
        layer = unfreeze(frozen_layer)
        if nblist_skin > 0:
            layer["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
            """@keyword[fennol_md] nblist_mult_size
            Multiplier for neighbor list size.
            Default: None
            """
        if "nblist_add_neigh" in simulation_parameters:
            layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
            """@keyword[fennol_md] nblist_add_neigh
            Additional neighbors to include in lists.
            Default: None
            """
        updated_layers.append(freeze(layer))
    preproc_state["layers_state"] = updated_layers

    ## initial preprocessing
    # Input checking is enabled only for this first call.
    preproc_state = freeze(preproc_state).copy({"check_input": True})
    preproc_state, conformation = model.preprocessing(preproc_state, conformation)
    preproc_state = preproc_state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", preproc_state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return preproc_state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model named by the ``model_file`` simulation parameter.

    Args:
        simulation_parameters: Mapping-like simulation input. Reads
            ``model_file`` (required), ``graph_config`` (optional) and
            ``energy_terms`` (optional; a list, or a whitespace-separated
            string that is split into a list).

    Returns:
        The object returned by ``FENNIX.load``, with its energy terms
        overridden when ``energy_terms`` is present in the input.

    Raises:
        FileNotFoundError: If ``model_file`` is not provided or the file
            does not exist.
    """
    model_file = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    if model_file is None:
        # Without this guard, str(None) would turn into the literal path "None".
        raise FileNotFoundError(
            "model_file is a required parameter but was not provided"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def load_system_data(simulation_parameters, fprec):
    """Read the simulation input and build the system data and model input.

    The last frame of the xyz file is loaded, masses are set up (optional
    deuteration and hydrogen mass repartitioning), periodic boundary
    conditions and topology are configured, and the system is replicated
    for path-integral beads or independent replicas when requested.

    Args:
        simulation_parameters: Mapping-like simulation input, queried with
            ``.get(key, default)``.
        fprec: Floating-point dtype used for coordinates, masses and cell.

    Returns:
        tuple: ``(system_data, conformation)``. ``system_data`` is a dict of
        system-level quantities (name, species, masses, temperature, pbc,
        topology, ...). ``conformation`` is the dict of model inputs
        (species, coordinates, batch_index, natoms, total_charge, optional
        cells/reciprocal_cells, flags and any ``additional_keys``).

    Raises:
        FileNotFoundError: If the xyz file or the topology file is missing.
        ValueError: If ``nbeads`` is smaller than 1 or the topology file does
            not have exactly two columns.
        AssertionError: If hydrogen mass repartitioning does not conserve
            the total mass.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Now that the file is known, an unspecified name falls back to its stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        Hmask = species == 1
        added_mass = hmr * Hmask.sum()
        mass[Hmask] += hmr
        # Remove the added mass from the other atoms, proportionally to their mass.
        wmass = mass[~Hmask]
        mass[~Hmask] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # convert to internal units
    mass = mass / us.DA

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = us.K_B * temperature

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge,"e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = us.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    flags = dict.fromkeys(input_flags)

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        rotate_cell = not cell_is_triangular(cell)
        if rotate_cell:
            print("# Warning: provided cell is not lower triangular. Rotating to canonical cell orientation.")
            cell, cell_rotation = tril_cell(cell)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for cell_row in cell:
            print("# ", cell_row)
        dens = (totmass_Da/volume) * (us.MOL/us.CM**3)
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            coordinates = coordinates @ cell

        # NOTE(review): when crystal input is combined with a rotated cell,
        # `cell` above is presumably already the rotated cell, so this may
        # rotate the coordinates twice -- confirm against tril_cell.
        if rotate_cell:
            coordinates = coordinates @ cell_rotation

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species, coordinates, cell=cell)
            # Indices are written 1-based to the file.
            np.savetxt(system_name + ".topo", topology + 1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a single-bond file two-dimensional; indices in the
            # file are 1-based and converted to 0-based here.
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        if nbeads < 1:
            raise ValueError(f"nbeads must be a positive integer, got {nbeads}")
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.tile(coordinates, (nbeads, 1))
        species = np.tile(species, nbeads)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.full(nbeads, nat, dtype=np.int32)

        # Cyclic nearest-neighbour coupling matrix of the beads. The two
        # neighbour contributions are accumulated (not assigned) so that the
        # wrap-around is counted correctly when nbeads is 1 or 2.
        identity = np.eye(nbeads)
        spring = (
            2.0 * identity
            - np.roll(identity, 1, axis=0)
            - np.roll(identity, -1, axis=0)
        )
        omk, eigmat = np.linalg.eigh(spring)
        # The lowest eigenvalue is zero up to round-off; force it to exactly
        # zero so the square root below cannot produce NaN.
        omk[0] = 0.0
        omk = (nbeads * kT / us.HBAR) * omk**0.5
        # Flip the sign of every row whose first-column entry is negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # Beads take precedence over replicas.
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.tile(mass, nreplicas)
        species = np.tile(species, nreplicas)
        system_data["species"] = species
        coordinates = np.tile(coordinates, (nreplicas, 1))
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.full(nreplicas, nat, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.zeros(nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell per bead/replica, with a leading batch axis.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        ncopies = nbeads if nbeads is not None else nreplicas
        if ncopies is not None:
            cell = np.repeat(cell, ncopies, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, ncopies, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model preprocessing state and run it once on the input.

    Args:
        simulation_parameters: Mapping-like simulation input. Reads
            ``nblist_verbose``, ``nblist_skin``, ``nblist_mult_size``,
            ``nblist_add_neigh`` and ``print_model``.
        model: Model providing ``preproc_state`` and ``preprocessing``.
        conformation: Model input dict to preprocess.
        system_data: Not used by this function.

    Returns:
        tuple: ``(preproc_state, conformation)`` as returned by
        ``model.preprocessing``, with ``check_input`` set back to False in
        the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    preproc_state = unfreeze(model.preproc_state)
    updated_layers = []
    for frozen_layer in preproc_state["layers_state"]:
        layer = unfreeze(frozen_layer)
        if nblist_skin > 0:
            layer["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
            """@keyword[fennol_md] nblist_mult_size
            Multiplier for neighbor list size.
            Default: None
            """
        if "nblist_add_neigh" in simulation_parameters:
            layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
            """@keyword[fennol_md] nblist_add_neigh
            Additional neighbors to include in lists.
            Default: None
            """
        updated_layers.append(freeze(layer))
    preproc_state["layers_state"] = updated_layers

    ## initial preprocessing
    # Input checking is enabled only for this first call.
    preproc_state = freeze(preproc_state).copy({"check_input": True})
    preproc_state, conformation = model.preprocessing(preproc_state, conformation)
    preproc_state = preproc_state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", preproc_state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return preproc_state, conformation