fennol.md.initial

  1from pathlib import Path
  2
  3import numpy as np
  4import jax.numpy as jnp
  5
  6from flax.core import freeze, unfreeze
  7
  8from ..utils.io import last_xyz_frame
  9
 10
 11from ..models import FENNIX
 12
 13from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 14from .utils import us
 15from ..utils import detect_topology,parse_cell,cell_is_triangular,tril_cell
 16
 17
 18def load_model(simulation_parameters):
 19    model_file = simulation_parameters.get("model_file")
 20    """@keyword[fennol_md] model_file
 21    Path to the machine learning model file (.fnx format). Required parameter.
 22    Type: str, Required
 23    """
 24    model_file = Path(str(model_file).strip())
 25    if not model_file.exists():
 26        raise FileNotFoundError(f"model file {model_file} not found")
 27    else:
 28        graph_config = simulation_parameters.get("graph_config", {})
 29        """@keyword[fennol_md] graph_config
 30        Advanced graph configuration for model initialization.
 31        Default: {}
 32        """
 33        model = FENNIX.load(model_file, graph_config=graph_config)  # \
 34        print(f"# model_file: {model_file}")
 35
 36    if "energy_terms" in simulation_parameters:
 37        energy_terms = simulation_parameters["energy_terms"]
 38        if isinstance(energy_terms, str):
 39            energy_terms = energy_terms.split()
 40        model.set_energy_terms(energy_terms)
 41        print("# energy terms:", model.energy_terms)
 42
 43    return model
 44
 45
 46def load_system_data(simulation_parameters, fprec):
 47    ## LOAD SYSTEM CONFORMATION FROM FILES
 48    system_name = str(simulation_parameters.get("system_name", "system")).strip()
 49    """@keyword[fennol_md] system_name
 50    Name prefix for output files. If not specified, uses the xyz filename stem.
 51    Default: "system"
 52    """
 53    indexed = simulation_parameters.get("xyz_input/indexed", False)
 54    """@keyword[fennol_md] xyz_input/indexed
 55    Whether first column contains atom indices (Tinker format).
 56    Default: False
 57    """
 58    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
 59    """@keyword[fennol_md] xyz_input/has_comment_line
 60    Whether file contains comment lines.
 61    Default: True
 62    """
 63    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
 64    """@keyword[fennol_md] xyz_input/file
 65    Path to xyz/arc coordinate file. Required parameter.
 66    Type: str, Required
 67    """
 68    if not xyzfile.exists():
 69        raise FileNotFoundError(f"xyz file {xyzfile} not found")
 70    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
 71    symbols, coordinates, _ = last_xyz_frame(
 72        xyzfile, indexed=indexed, has_comment_line=has_comment_line
 73    )
 74    coordinates = coordinates.astype(fprec)
 75    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
 76    nat = species.shape[0]
 77
 78    ## GET MASS
 79    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
 80    deuterate = simulation_parameters.get("deuterate", False)
 81    """@keyword[fennol_md] deuterate
 82    Replace hydrogen masses with deuterium masses.
 83    Default: False
 84    """
 85    if deuterate:
 86        print("# Replacing all hydrogens with deuteriums")
 87        mass_Da[species == 1] *= 2.0
 88
 89    totmass_Da = mass_Da.sum()
 90
 91    mass = mass_Da.copy()
 92    hmr = simulation_parameters.get("hmr", 0)
 93    """@keyword[fennol_md] hmr
 94    Hydrogen mass repartitioning factor. 0 = no repartitioning.
 95    Default: 0
 96    """
 97    if hmr > 0:
 98        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
 99        Hmask = species == 1
100        added_mass = hmr * Hmask.sum()
101        mass[Hmask] += hmr
102        wmass = mass[~Hmask]
103        mass[~Hmask] -= added_mass * wmass/ wmass.sum()
104
105        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"
106
107    # convert to internal units
108    mass = mass / us.DA
109
110    ### GET TEMPERATURE
111    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
112    """@keyword[fennol_md] temperature
113    Target temperature in Kelvin.
114    Default: 300.0
115    """
116    kT = us.K_B * temperature 
117
118    ### GET TOTAL CHARGE
119    total_charge = simulation_parameters.get("total_charge", None)
120    """@keyword[fennol_md] total_charge
121    Total system charge for charged systems.
122    Default: None (interpreted as 0)
123    """
124    if total_charge is None:
125        total_charge = 0
126    else:
127        total_charge = int(total_charge)
128        print("# total charge: ", total_charge,"e")
129
130    ### ENERGY UNIT
131    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
132    """@keyword[fennol_md] energy_unit
133    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
134    Default: "kcal/mol"
135    """
136    energy_unit = us.get_multiplier(energy_unit_str)
137
138    ## SYSTEM DATA
139    system_data = {
140        "name": system_name,
141        "nat": nat,
142        "symbols": symbols,
143        "species": species,
144        "mass": mass,
145        "mass_Da": mass_Da,
146        "totmass_Da": totmass_Da,
147        "temperature": temperature,
148        "kT": kT,
149        "total_charge": total_charge,
150        "energy_unit": energy_unit,
151        "energy_unit_str": energy_unit_str,
152    }
153    input_flags = simulation_parameters.get("model_flags", [])
154    """@keyword[fennol_md] model_flags
155    Additional flags to pass to the model.
156    Default: []
157    """
158    flags = {f:None for f in input_flags}
159
160    ### Set boundary conditions
161    cell = simulation_parameters.get("cell", None)
162    """@keyword[fennol_md] cell
163    Unit cell vectors. Required for PBC. It is a sequence of floats:
164    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
165    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
166    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
167    - 1 float: length of cell vectors (cubic cell)
168    Lengths are in Angstroms, angles in degrees.
169    Default: None
170    """
171    if cell is not None:
172        cell = parse_cell(cell).astype(fprec)
173        rotate_cell = not cell_is_triangular(cell)
174        if rotate_cell:
175            print("# Warning: provided cell is not lower triangular. Rotating to canonical cell orientation.")
176            cell, cell_rotation = tril_cell(cell)
177        # cell = np.array(cell, dtype=fprec).reshape(3, 3)
178        reciprocal_cell = np.linalg.inv(cell)
179        volume = np.abs(np.linalg.det(cell))
180        print("# cell matrix:")
181        for l in cell:
182            print("# ", l)
183        # print(cell)
184        dens = (totmass_Da/volume) * (us.MOL/us.CM**3)
185        print("# density: ", dens.item(), " g/cm^3")
186        minimum_image = simulation_parameters.get("minimum_image", True)
187        """@keyword[fennol_md] minimum_image
188        Use minimum image convention for neighbor lists in periodic systems.
189        Default: True
190        """
191        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
192        """@keyword[fennol_md] estimate_pressure
193        Calculate and print pressure during simulation.
194        Default: False
195        """
196        print("# minimum_image: ", minimum_image)
197
198        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
199        """@keyword[fennol_md] xyz_input/crystal
200        Use crystal coordinates.
201        Default: False
202        """
203        if crystal_input:
204            coordinates = coordinates @ cell
205        
206        if rotate_cell:
207            coordinates = coordinates @ cell_rotation
208
209        pbc_data = {
210            "cell": cell,
211            "reciprocal_cell": reciprocal_cell,
212            "volume": volume,
213            "minimum_image": minimum_image,
214            "estimate_pressure": estimate_pressure,
215        }
216        if minimum_image:
217            flags["minimum_image"] = None
218    else:
219        pbc_data = None
220    system_data["pbc"] = pbc_data
221    system_data["initial_coordinates"] = coordinates.copy()
222
223    ### TOPOLOGY
224    topology_key = simulation_parameters.get("topology", None)
225    """@keyword[fennol_md] topology
226    Topology specification for molecular systems. Use "detect" for automatic detection.
227    Default: None
228    """
229    if topology_key is not None:
230        topology_key = str(topology_key).strip()
231        if topology_key.lower() == "detect":
232            topology = detect_topology(species,coordinates,cell=cell)
233            np.savetxt(system_name +".topo", topology+1, fmt="%d")
234            print("# Detected topology saved to", system_name + ".topo")
235        else:
236            assert Path(topology_key).exists(), f"Topology file {topology_key} not found"
237            topology = np.loadtxt(topology_key, dtype=np.int32)-1
238            assert topology.shape[1] == 2, "Topology file must have two columns (source, target)"
239            print("# Topology loaded from", topology_key)
240    else:
241        topology = None
242    
243    system_data["topology"] = topology
244
245    ### PIMD
246    nbeads = simulation_parameters.get("nbeads", None)
247    """@keyword[fennol_md] nbeads
248    Number of beads for Path Integral MD.
249    Default: None
250    """
251    nreplicas = simulation_parameters.get("nreplicas", None)
252    """@keyword[fennol_md] nreplicas
253    Number of replicas for independent replica simulations.
254    Default: None
255    """
256    if nbeads is not None:
257        nbeads = int(nbeads)
258        print("# nbeads: ", nbeads)
259        system_data["nbeads"] = nbeads
260        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
261        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
262        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
263        natoms = np.array([nat] * nbeads, dtype=np.int32)
264
265        eigmat = np.zeros((nbeads, nbeads))
266        for i in range(nbeads - 1):
267            eigmat[i, i] = 2.0
268            eigmat[i, i + 1] = -1.0
269            eigmat[i + 1, i] = -1.0
270        eigmat[-1, -1] = 2.0
271        eigmat[0, -1] = -1.0
272        eigmat[-1, 0] = -1.0
273        omk, eigmat = np.linalg.eigh(eigmat)
274        omk[0] = 0.0
275        omk = (nbeads * kT / us.HBAR) * omk**0.5
276        for i in range(nbeads):
277            if eigmat[i, 0] < 0:
278                eigmat[i] *= -1.0
279        eigmat = jnp.asarray(eigmat, dtype=fprec)
280        system_data["omk"] = omk
281        system_data["eigmat"] = eigmat
282        nreplicas = None
283    elif nreplicas is not None:
284        nreplicas = int(nreplicas)
285        print("# nreplicas: ", nreplicas)
286        system_data["nreplicas"] = nreplicas
287        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
288        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
289            -1
290        )
291        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
292            -1, 3
293        )
294        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
295        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
296        natoms = np.array([nat] * nreplicas, dtype=np.int32)
297    else:
298        system_data["nreplicas"] = 1
299        bead_index = np.array([0] * nat, dtype=np.int32)
300        natoms = np.array([nat], dtype=np.int32)
301
302    conformation = {
303        "species": species,
304        "coordinates": coordinates,
305        "batch_index": bead_index,
306        "natoms": natoms,
307        "total_charge": total_charge,
308    }
309    if cell is not None:
310        cell = cell[None, :, :]
311        reciprocal_cell = reciprocal_cell[None, :, :]
312        if nbeads is not None:
313            cell = np.repeat(cell, nbeads, axis=0)
314            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
315        elif nreplicas is not None:
316            cell = np.repeat(cell, nreplicas, axis=0)
317            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
318        conformation["cells"] = cell
319        conformation["reciprocal_cells"] = reciprocal_cell
320
321    additional_keys = simulation_parameters.get("additional_keys", {})
322    """@keyword[fennol_md] additional_keys
323    Additional custom keys for model input.
324    Default: {}
325    """
326    for key, value in additional_keys.items():
327        conformation[key] = value
328    
329    conformation["flags"] = flags
330
331    return system_data, conformation
332
333
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's neighbor-list preprocessing and run it once.

    Per-layer preprocessing settings are overridden from the input parameters.
    A first preprocessing pass is run on ``conformation`` with input checking
    enabled, and the returned state has input checking disabled.

    Args:
        simulation_parameters: mapping of ``fennol_md`` keywords.
        model: loaded model; its ``preproc_state``, ``preprocessing`` and
            ``summarize`` are used.
        conformation: model input dict.
        system_data: not used by this function.

    Returns:
        (preproc_state, conformation) after the initial preprocessing pass.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    # Settings applied identically to every preprocessing layer.
    overrides = {}
    if nblist_skin > 0:
        overrides["nblist_skin"] = nblist_skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [
        freeze({**unfreeze(layer), **overrides}) for layer in state["layers_state"]
    ]
    state = freeze(state)

    ## initial preprocessing
    # Validate the input on this first pass only.
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    print_model = simulation_parameters.get("print_model", False)
    """@keyword[fennol_md] print_model
    Print detailed model information at startup.
    Default: False
    """
    if print_model:
        print(model.summarize(example_data=conformation))

    return state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model named by the ``model_file`` keyword.

    If ``energy_terms`` is given (a list, or a whitespace-separated string),
    the model is restricted to those energy terms.

    Args:
        simulation_parameters: mapping of ``fennol_md`` keywords.

    Returns:
        The loaded ``FENNIX`` model.

    Raises:
        FileNotFoundError: if ``model_file`` is missing/blank or the file
            does not exist.
    """
    model_file = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    # Without this check a missing keyword becomes Path("None") ("model file
    # None not found") and a blank one becomes Path("."), which exists.
    if model_file is None or not str(model_file).strip():
        raise FileNotFoundError(
            "model file not specified: the 'model_file' keyword is required"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def _repartition_hydrogen_mass(mass_Da, species, hmr):
    """Add ``hmr`` Da to every hydrogen while conserving the total mass.

    The mass added to the hydrogens is removed from the non-hydrogen atoms in
    proportion to their own mass.

    Args:
        mass_Da: per-atom masses in Da (not modified).
        species: per-atom atomic numbers as a numpy array (hydrogen is 1).
        hmr: mass in Da added to each hydrogen.

    Returns:
        A new array of repartitioned masses in Da.

    Raises:
        ValueError: if there are hydrogens but no other atom to take the mass
            from, or if a repartitioned mass is not positive.
    """
    mass = np.array(mass_Da, copy=True)
    is_h = species == 1
    heavy = ~is_h
    if np.any(is_h) and not np.any(heavy):
        raise ValueError(
            "hydrogen mass repartitioning needs at least one non-hydrogen atom"
        )
    added_mass = hmr * is_h.sum()
    mass[is_h] += hmr
    heavy_mass = mass[heavy]
    if heavy_mass.size > 0:
        mass[heavy] -= added_mass * heavy_mass / heavy_mass.sum()
    if np.any(mass <= 0):
        raise ValueError(
            f"hmr={hmr} Da is too large: repartitioning gives non-positive masses"
        )
    assert np.isclose(mass.sum(), np.sum(mass_Da)), "Mass conservation failed"
    return mass


def _ring_polymer_normal_modes(nbeads):
    """Normal modes of a free ring polymer with ``nbeads`` beads.

    Diagonalizes the cyclic second-difference matrix: 2 on the diagonal and
    -1 added for each spring between beads i and (i+1) mod nbeads. Springs
    are accumulated, so for ``nbeads == 2`` the two springs joining the beads
    give an off-diagonal of -2, and for ``nbeads == 1`` the matrix is 0.

    Args:
        nbeads: number of beads (>= 1).

    Returns:
        (sqrt_eigvals, eigvecs). ``sqrt_eigvals`` has shape (nbeads,) and holds
        the values 2*sin(pi*k/nbeads), k = 0..nbeads-1, in ascending order, with
        the centroid value set exactly to 0. ``eigvecs`` is the
        (nbeads, nbeads) orthogonal matrix whose columns are the matching
        eigenvectors, sign-flipped so that the centroid column (column 0) is
        non-negative.
    """
    ring = 2.0 * np.eye(nbeads)
    for i in range(nbeads):
        j = (i + 1) % nbeads
        # "-=" rather than "=" so that coinciding neighbours (nbeads <= 2) add up
        ring[i, j] -= 1.0
        ring[j, i] -= 1.0
    eigvals, eigvecs = np.linalg.eigh(ring)
    eigvals[0] = 0.0  # exact zero for the centroid mode
    # Flip each row whose centroid component is negative. The centroid
    # eigenvector has equal components, so this flips all rows or none.
    eigvecs = np.where(eigvecs[:, :1] < 0, -eigvecs, eigvecs)
    return eigvals ** 0.5, eigvecs


def _load_topology(topology_file):
    """Load a bond list from a text file of 1-based atom index pairs.

    Returns:
        int32 array of shape (nbonds, 2) with 0-based atom indices.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if the file does not have exactly two columns.
    """
    if not Path(topology_file).exists():
        raise FileNotFoundError(f"Topology file {topology_file} not found")
    # ndmin=2 keeps a single-bond file as shape (1, 2) instead of (2,)
    topology = np.loadtxt(topology_file, dtype=np.int32, ndmin=2) - 1
    if topology.shape[1] != 2:
        raise ValueError("Topology file must have two columns (source, target)")
    return topology


def load_system_data(simulation_parameters, fprec):
    """Build the system description and the initial model input.

    Reads the last frame of the xyz/arc input file and sets up masses (with
    optional deuteration and hydrogen mass repartitioning), temperature, total
    charge, energy unit, periodic boundary conditions, topology, and the
    PIMD-bead or independent-replica replication of the system.

    Args:
        simulation_parameters: mapping of ``fennol_md`` keywords.
        fprec: floating point dtype for coordinates, masses and cells.

    Returns:
        (system_data, conformation): ``system_data`` describes the physical
        system (name, masses, temperature, PBC data, topology, PIMD normal
        modes, ...); ``conformation`` holds the model inputs (species,
        coordinates, batch_index, natoms, total_charge, optional cells,
        additional keys and flags).

    Raises:
        FileNotFoundError: if the xyz file or the topology file is missing.
        ValueError: if hydrogen mass repartitioning would produce invalid
            masses, or if the topology file does not have two columns.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Re-read so that, when "system_name" is not given, outputs are named
    # after the xyz file rather than the generic "system".
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        mass = _repartition_hydrogen_mass(mass_Da, species, hmr)
    else:
        mass = mass_Da.copy()

    # convert to internal units
    mass = mass / us.DA

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = us.K_B * temperature

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge, "e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = us.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        rotate_cell = not cell_is_triangular(cell)
        if rotate_cell:
            print("# Warning: provided cell is not lower triangular. Rotating to canonical cell orientation.")
            cell, cell_rotation = tril_cell(cell)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for cell_row in cell:
            print("# ", cell_row)
        dens = (totmass_Da / volume) * (us.MOL / us.CM**3)
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            # `cell` has already been replaced by its rotated form above, so
            # fractional coordinates mapped through it are already expressed
            # in the rotated frame; applying cell_rotation again would rotate
            # them twice.
            coordinates = coordinates @ cell
        elif rotate_cell:
            # Cartesian input: bring it into the frame of the rotated cell.
            coordinates = coordinates @ cell_rotation

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species, coordinates, cell=cell)
            np.savetxt(system_name + ".topo", topology + 1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            topology = _load_topology(topology_key)
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        sqrt_eig, eigmat = _ring_polymer_normal_modes(nbeads)
        # omega_k = (nbeads * kT / hbar) * 2 sin(pi k / nbeads)
        omk = (nbeads * kT / us.HBAR) * sqrt_eig
        system_data["omk"] = omk
        system_data["eigmat"] = jnp.asarray(eigmat, dtype=fprec)
        # beads take precedence over replicas
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # one (reciprocal) cell per structure in the batch
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's neighbor-list preprocessing and run it once.

    Per-layer preprocessing settings are overridden from the input parameters.
    A first preprocessing pass is run on ``conformation`` with input checking
    enabled, and the returned state has input checking disabled.

    Args:
        simulation_parameters: mapping of ``fennol_md`` keywords.
        model: loaded model; its ``preproc_state``, ``preprocessing`` and
            ``summarize`` are used.
        conformation: model input dict.
        system_data: not used by this function.

    Returns:
        (preproc_state, conformation) after the initial preprocessing pass.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    # Settings applied identically to every preprocessing layer.
    overrides = {}
    if nblist_skin > 0:
        overrides["nblist_skin"] = nblist_skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [
        freeze({**unfreeze(layer), **overrides}) for layer in state["layers_state"]
    ]
    state = freeze(state)

    ## initial preprocessing
    # Validate the input on this first pass only.
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    print_model = simulation_parameters.get("print_model", False)
    """@keyword[fennol_md] print_model
    Print detailed model information at startup.
    Default: False
    """
    if print_model:
        print(model.summarize(example_data=conformation))

    return state, conformation