fennol.models.fennix

  1from typing import Any, Sequence, Callable, Union, Optional, Tuple, Dict
  2from copy import deepcopy
  3import dataclasses
  4from collections import OrderedDict
  5
  6import jax
  7import jax.numpy as jnp
  8import flax.linen as nn
  9import numpy as np
 10from flax import serialization
 11from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
 12from ..utils.atomic_units import au
 13
 14from .preprocessing import (
 15    GraphGenerator,
 16    PreprocessingChain,
 17    JaxConverter,
 18    atom_unpadding,
 19    check_input,
 20    convert_to_jax,
 21    convert_to_numpy,
 22)
 23from .modules import MODULES, PREPROCESSING, FENNIXModules
 24
 25
@dataclasses.dataclass
class FENNIX:
    """
    Static wrapper for FENNIX models

    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary
    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.

    Since the model is static and contains variables, it must be initialized right away with either
    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
    is provided, the model is initialized with `example_data` and the resulting variables are stored
    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the
    resulting variables are stored in the wrapper.
    """

    # cutoff radius of the main neighbor graph (None => no graph preprocessing)
    cutoff: Union[float, None]
    # differentiable/jittable chain of modules applied to preprocessed inputs
    modules: FENNIXModules
    # model parameters (weights, biases and other tunables) as a pytree
    variables: Dict
    # non-jittable preprocessing chain (graph generation, padding, ...)
    preprocessing: PreprocessingChain
    # jitted pure function: (variables, inputs) -> outputs
    _apply: Callable[[Dict, Dict], Dict]
    # pure function: (variables, inputs) -> (total_energy, outputs)
    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
    # pure function: (variables, inputs) -> (total_energy, forces, outputs)
    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
    # raw constructor arguments, kept for serialization (see to_dict/save)
    _input_args: Dict
    # frozen per-graph properties gathered from the preprocessing chain
    _graphs_properties: Dict
    # mutable state of the preprocessing chain (e.g. nblist sizes)
    preproc_state: Dict
    # output keys summed to form the total energy (None => zero-energy model)
    energy_terms: Optional[Sequence[str]] = None
    # while True, __setattr__ permits attribute mutation (init phase only)
    _initializing: bool = True
    # if True, outputs contain padding atoms (see atom_unpadding)
    use_atom_padding: bool = False
 53
    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters.
        example_data: dict
            Example data to initialize the model. If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random)
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data.
            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
        energy_unit: str
            Energy unit of the model output (default "Ha"); stored as a multiplier via au.get_multiplier.
        """
        # NOTE(review): `preprocessing` and `graph_config` are mutable default
        # arguments; safe here only because `preprocessing` is deepcopied before
        # mutation and `graph_config` is never mutated — consider None sentinels.
        # Keep the raw constructor arguments so the model can be re-created
        # identically later (see to_dict/save/load).
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        # multiplier converting Hartree to the model's energy unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            # no cutoff => no neighbor graph is built
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                # a leading "graph" entry overrides the default graph parameters
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        # instantiate the remaining preprocessing modules; the registry key comes
        # from an explicit "module_name"/"FID" entry if present, else the dict key
        for name, params in preprocessing.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            # inject graph properties / energy unit into modules that declare
            # the corresponding dataclass fields
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        # name-mangled un-jitted apply, captured by the closures built in
        # set_energy_terms; the public _apply is jitted
        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model
        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        # freeze the wrapper: from now on __setattr__ blocks mutation
        # (except for `variables` and `preproc_state`)
        self._initializing = False
179
180    def set_energy_terms(
181        self, energy_terms: Union[Sequence[str], None], jit: bool = True
182    ) -> None:
183        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
184        object.__setattr__(self, "energy_terms", energy_terms)
185        if isinstance(energy_terms, str):
186            energy_terms = [energy_terms]
187        
188        if energy_terms is None or len(energy_terms) == 0:
189
190            def total_energy(variables, data):
191                out = self.__apply(variables, data)
192                coords = out["coordinates"]
193                nsys = out["natoms"].shape[0]
194                nat = coords.shape[0]
195                dtype = coords.dtype
196                e = jnp.zeros(nsys, dtype=dtype)
197                eat = jnp.zeros(nat, dtype=dtype)
198                out["total_energy"] = e
199                out["atomic_energies"] = eat
200                return e, out
201
202            def energy_and_forces(variables, data):
203                e, out = total_energy(variables, data)
204                f = jnp.zeros_like(out["coordinates"])
205                out["forces"] = f
206                return e, f, out
207
208            def energy_and_forces_and_virial(variables, data):
209                e, f, out = energy_and_forces(variables, data)
210                v = jnp.zeros(
211                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
212                )
213                out["virial_tensor"] = v
214                return e, f, v, out
215
216        else:
217            # build the energy and force functions
218            def total_energy(variables, data):
219                out = self.__apply(variables, data)
220                atomic_energies = 0.0
221                system_energies = 0.0
222                species = out["species"]
223                nsys = out["natoms"].shape[0]
224                for term in self.energy_terms:
225                    e = out[term]
226                    if e.ndim > 1 and e.shape[-1] == 1:
227                        e = jnp.squeeze(e, axis=-1)
228                    if e.shape[0] == nsys and nsys != species.shape[0]:
229                        system_energies += e
230                        continue
231                    assert e.shape == species.shape
232                    atomic_energies += e
233                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
234                if isinstance(atomic_energies, jnp.ndarray):
235                    if "true_atoms" in out:
236                        atomic_energies = jnp.where(
237                            out["true_atoms"], atomic_energies, 0.0
238                        )
239                    atomic_energies = atomic_energies.flatten()
240                    assert atomic_energies.shape == out["species"].shape
241                    out["atomic_energies"] = atomic_energies
242                    energies = jax.ops.segment_sum(
243                        atomic_energies,
244                        data["batch_index"],
245                        num_segments=len(data["natoms"]),
246                    )
247                else:
248                    energies = 0.0
249
250                if isinstance(system_energies, jnp.ndarray):
251                    if "true_sys" in out:
252                        system_energies = jnp.where(
253                            out["true_sys"], system_energies, 0.0
254                        )
255                    out["system_energies"] = system_energies
256
257                out["total_energy"] = energies + system_energies
258                return out["total_energy"], out
259
260            def energy_and_forces(variables, data):
261                def _etot(variables, coordinates):
262                    energy, out = total_energy(
263                        variables, {**data, "coordinates": coordinates}
264                    )
265                    return energy.sum(), out
266
267                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
268                    variables, data["coordinates"]
269                )
270                out["forces"] = -de
271
272                return out["total_energy"], out["forces"], out
273
274            # def energy_and_forces_and_virial(variables, data):
275            #     x = data["coordinates"]
276            #     batch_index = data["batch_index"]
277            #     if "cells" in data:
278            #         cells = data["cells"]
279            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
280
281            #         def _etot(variables, coordinates, cells):
282            #             reciprocal_cells = jnp.linalg.inv(cells)
283            #             energy, out = total_energy(
284            #                 variables,
285            #                 {
286            #                     **data,
287            #                     "coordinates": coordinates,
288            #                     "cells": cells,
289            #                     "reciprocal_cells": reciprocal_cells,
290            #                 },
291            #             )
292            #             return energy.sum(), out
293
294            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
295            #             variables, x, cells
296            #         )
297            #         f= -dedx
298            #         out["forces"] = f
299            #     else:
300            #         _,f,out = energy_and_forces(variables, data)
301
302            #     vir = -jax.ops.segment_sum(
303            #         f[:, :, None] * x[:, None, :],
304            #         batch_index,
305            #         num_segments=len(data["natoms"]),
306            #     )
307
308            #     if "cells" in data:
309            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
310            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
311            #         nsys = data["natoms"].shape[0]
312            #         if cells.shape[0]==1 and nsys>1:
313            #             dvir = dvir / nsys
314            #         vir = vir + dvir
315
316            #     out["virial_tensor"] = vir
317
318            #     return out["total_energy"], f, vir, out
319
320            def energy_and_forces_and_virial(variables, data):
321                x = data["coordinates"]
322                scaling = jnp.asarray(
323                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
324                )
325                def _etot(variables, coordinates, scaling):
326                    batch_index = data["batch_index"]
327                    coordinates = jax.vmap(jnp.matmul)(
328                        coordinates, scaling[batch_index]
329                    )
330                    inputs = {**data, "coordinates": coordinates}
331                    if "cells" in data:
332                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
333                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
334                        reciprocal_cells = jnp.linalg.inv(cells)
335                        inputs["cells"] = cells
336                        inputs["reciprocal_cells"] = reciprocal_cells
337                    energy, out = total_energy(variables, inputs)
338                    return energy.sum(), out
339
340                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
341                    variables, x, scaling
342                )
343                f = -dedx
344                out["forces"] = f
345                out["virial_tensor"] = vir
346
347                return out["total_energy"], f, vir, out
348
349        object.__setattr__(self, "_total_energy_raw", total_energy)
350        if jit:
351            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
352            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
353            object.__setattr__(
354                self,
355                "_energy_and_forces_and_virial",
356                jax.jit(energy_and_forces_and_virial),
357            )
358        else:
359            object.__setattr__(self, "_total_energy", total_energy)
360            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
361            object.__setattr__(
362                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
363            )
364
    def get_gradient_function(
        self,
        *gradient_keys: Sequence[str],
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

        The returned function takes `data` (and `variables` first, if
        `variables_as_input` is True) and returns
        `(total_energy, gradients, outputs)` where `outputs` also contains one
        "dEd_<key>" entry per requested key. The special key "strain" is seeded
        with a per-system identity matrix and applied as a scaling of the
        coordinates (and cells), mirroring the virial construction in
        `set_energy_terms`.
        """

        def _energy_gradient(variables, data):
            # total energy as a function of only the selected inputs
            def _etot(variables, inputs):
                if "strain" in inputs:
                    # apply the strain scaling to coordinates (and cells) so
                    # that differentiating w.r.t. it gives a virial-like tensor
                    scaling = inputs["strain"]
                    batch_index = data["batch_index"]
                    coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**inputs, "coordinates": coordinates}
                    if "cells" in inputs or "cells" in data:
                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
                        cells = jax.vmap(jnp.matmul)(cells, scaling)
                        inputs["cells"] = cells
                if "cells" in inputs:
                    # keep reciprocal cells consistent with the (possibly
                    # strained) cells
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            if "strain" in gradient_keys and "strain" not in data:
                # seed strain with identity so the gradient is evaluated at the
                # undeformed geometry
                data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:
            # close over the model's own variables

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient
415
416    def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]:
417        """apply preprocessing to the input data
418
419        !!! This is not a pure function => do not apply jax transforms !!!"""
420        if self.preproc_state is None:
421            out, _ = self.reinitialize_preprocessing(example_data=inputs)
422        elif use_gpu:
423            do_check_input = self.preproc_state.get("check_input", True)
424            if do_check_input:
425                inputs = check_input(inputs)
426            preproc_state, inputs = self.preprocessing.atom_padding(
427                self.preproc_state, inputs
428            )
429            inputs = self.preprocessing.process(preproc_state, inputs)
430            preproc_state, state_up, out, overflow = (
431                self.preprocessing.check_reallocate(
432                    preproc_state, inputs
433                )
434            )
435            if verbose and overflow:
436                print("GPU preprocessing: nblist overflow => reallocating nblist")
437                print("size updates:", state_up)
438        else:
439            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
440
441        object.__setattr__(self, "preproc_state", preproc_state)
442        return out
443
    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict, Optional[jax.random.PRNGKey]]:
        """Re-initialize the preprocessing chain state.

        If `example_data` is None, a 10-atom dummy system is generated to probe
        the chain. Stores the new `preproc_state` on `self` and returns the
        preprocessed example inputs together with the (split) rng key — which
        stays None if none was provided.
        """
        if rng_key is None:
            # fixed key: only used to build the dummy system below
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        # bypass the immutability guard to store the fresh state
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key
460
461    def __call__(self, variables: Optional[dict] = None, gpu_preprocessing=False,**inputs) -> Dict[str, Any]:
462        """Apply the FENNIX model (preprocess + modules)
463
464        !!! This is not a pure function => do not apply jax transforms !!!
465        if you want to apply jax transforms, use  self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
466        """
467        if variables is None:
468            variables = self.variables
469        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
470        output = self._apply(variables, inputs)
471        if self.use_atom_padding:
472            output = atom_unpadding(output)
473        return output
474
475    def total_energy(
476        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False,**inputs
477    ) -> Tuple[jnp.ndarray, Dict]:
478        """compute the total energy of the system
479
480        !!! This is not a pure function => do not apply jax transforms !!!
481        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
482        """
483        if variables is None:
484            variables = self.variables
485        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
486        # def print_shape(path,value):
487        #     if isinstance(value,jnp.ndarray):
488        #         print(path,value.shape)
489        #     else:
490        #         print(path,value)
491        # jax.tree_util.tree_map_with_path(print_shape,inputs)
492        _, output = self._total_energy(variables, inputs)
493        if self.use_atom_padding:
494            output = atom_unpadding(output)
495        e = output["total_energy"]
496        if unit is not None:
497            model_energy_unit = self.Ha_to_model_energy
498            if isinstance(unit, str):
499                unit = au.get_multiplier(unit)
500            e = e * (unit / model_energy_unit)
501        return e, output
502
503    def energy_and_forces(
504        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False,**inputs
505    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
506        """compute the total energy and forces of the system
507
508        !!! This is not a pure function => do not apply jax transforms !!!
509        if you want to apply jax transforms, use  self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
510        """
511        if variables is None:
512            variables = self.variables
513        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
514        _, _, output = self._energy_and_forces(variables, inputs)
515        if self.use_atom_padding:
516            output = atom_unpadding(output)
517        e = output["total_energy"]
518        f = output["forces"]
519        if unit is not None:
520            model_energy_unit = self.Ha_to_model_energy
521            if isinstance(unit, str):
522                unit = au.get_multiplier(unit)
523            e = e * (unit / model_energy_unit)
524            f = f * (unit / model_energy_unit)
525        return e, f, output
526
    def energy_and_forces_and_virial(
        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy, forces and virial tensor of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use  self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            # convert all energy-like quantities to the requested unit
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output
552
553    def remove_atom_padding(self, output):
554        """remove atom padding from the output"""
555        return atom_unpadding(output)
556
557    def get_model(self) -> Tuple[FENNIXModules, Dict]:
558        """return the model and its variables"""
559        return self.modules, self.variables
560
561    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
562        """return the preprocessing chain and its state"""
563        return self.preprocessing, self.preproc_state
564
565    def __setattr__(self, __name: str, __value: Any) -> None:
566        if __name == "variables":
567            if __value is not None:
568                if not (
569                    isinstance(__value, dict)
570                    or isinstance(__value, OrderedDict)
571                    or isinstance(__value, FrozenDict)
572                ):
573                    raise ValueError(f"{__name} must be a dict")
574                object.__setattr__(self, __name, JaxConverter()(__value))
575            else:
576                raise ValueError(f"{__name} cannot be None")
577        elif __name == "preproc_state":
578            if __value is not None:
579                if not (
580                    isinstance(__value, dict)
581                    or isinstance(__value, OrderedDict)
582                    or isinstance(__value, FrozenDict)
583                ):
584                    raise ValueError(f"{__name} must be a FrozenDict")
585                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
586            else:
587                raise ValueError(f"{__name} cannot be None")
588
589        elif self._initializing:
590            object.__setattr__(self, __name, __value)
591        else:
592            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")
593
594    def generate_dummy_system(
595        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
596    ) -> Dict[str, Any]:
597        """
598        Generate dummy system for initialization
599        """
600        if box_size is None:
601            box_size = 2 * self.cutoff
602            for g in self._graphs_properties.values():
603                cutoff = g["cutoff"]
604                if cutoff is not None:
605                    box_size = min(box_size, 2 * g["cutoff"])
606        coordinates = np.array(
607            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
608        )
609        species = np.ones((n_atoms,), dtype=np.int32)
610        batch_index = np.zeros((n_atoms,), dtype=np.int32)
611        natoms = np.array([n_atoms], dtype=np.int32)
612        return {
613            "species": species,
614            "coordinates": coordinates,
615            # "graph": graph,
616            "batch_index": batch_index,
617            "natoms": natoms,
618        }
619
620    def summarize(
621        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
622    ) -> str:
623        """Summarize the model architecture and parameters"""
624        if rng_key is None:
625            head = "Summarizing with example data:\n"
626            rng_key = jax.random.PRNGKey(0)
627        if example_data is None:
628            head = "Summarizing with dummy 10 atoms system:\n"
629            rng_key, rng_key_sys = jax.random.split(rng_key)
630            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
631        rng_key, rng_key_pre = jax.random.split(rng_key)
632        _, inputs = self.preprocessing.init_with_output(example_data)
633        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
634
635    def to_dict(self,convert_numpy=False):
636        """return a dictionary representation of the model"""
637        if convert_numpy:
638            variables = convert_to_numpy(self.variables)
639        else:
640            variables = deepcopy(self.variables)
641        return {
642            **self._input_args,
643            "energy_terms": self.energy_terms,
644            "variables": variables,
645        }
646
647    def save(self, filename):
648        """save the model to a file"""
649        filename_str = str(filename)
650        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
651        state_dict = self.to_dict(convert_numpy=do_pickle)
652        state_dict["preprocessing"] = [
653            [k, v] for k, v in state_dict["preprocessing"].items()
654        ]
655        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
656        if do_pickle:
657            import pickle
658            with open(filename, "wb") as f:
659                pickle.dump(state_dict, f)
660        else:
661            with open(filename, "wb") as f:
662                f.write(serialization.msgpack_serialize(state_dict))
663
664    @classmethod
665    def load(
666        cls,
667        filename,
668        use_atom_padding=False,
669        graph_config={},
670    ):
671        """load a model from a file"""
672        filename_str = str(filename)
673        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
674        if do_pickle:
675            import pickle
676            with open(filename, "rb") as f:
677                state_dict = pickle.load(f)
678                state_dict["variables"] = convert_to_jax(state_dict["variables"])
679        else:
680            with open(filename, "rb") as f:
681                state_dict = serialization.msgpack_restore(f.read())
682        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
683        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
684        return cls(
685            **state_dict,
686            graph_config=graph_config,
687            use_atom_padding=use_atom_padding,
688        )
@dataclasses.dataclass
class FENNIX:
@dataclasses.dataclass
class FENNIX:
    """
    Static wrapper for FENNIX models.

    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary
    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.

    Since the model is static and contains variables, it must be initialized right away with either
    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
    is provided, the model is initialized with `example_data` and the resulting variables are stored
    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the
    resulting variables are stored in the wrapper.
    """

    # Neighbor-list cutoff radius; None disables graph generation.
    cutoff: Union[float, None]
    # Chain of differentiable/jittable modules (applied by _apply).
    modules: FENNIXModules
    # Model parameters (weights, biases and other tunable values).
    variables: Dict
    # Non-jittable preprocessing chain (graph generation, padding, ...).
    preprocessing: PreprocessingChain
    # Jitted pure apply function: (variables, inputs) -> outputs.
    _apply: Callable[[Dict, Dict], Dict]
    # Jitted pure function: (variables, inputs) -> (total_energy, outputs).
    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
    # Jitted pure function: (variables, inputs) -> (energy, forces, outputs).
    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
    # Constructor arguments needed to rebuild the model (see to_dict/save).
    _input_args: Dict
    # Static properties of the generated graphs (cutoffs, keys, ...).
    _graphs_properties: Dict
    # Mutable state of the preprocessing chain (nblist sizes, flags, ...).
    preproc_state: Dict
    # Output keys summed to form the total energy; None => zero energy.
    energy_terms: Optional[Sequence[str]] = None
    # True only during __init__; afterwards attributes are frozen (__setattr__).
    _initializing: bool = True
    # Whether inputs are padded to a fixed number of atoms.
    use_atom_padding: bool = False
 54
    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters.
        example_data: dict
            Example data to initialize the model. If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random)
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data.
            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
        energy_unit: str
            Unit in which the model expresses energies, as understood by
            `fennol.utils.atomic_units.au.get_multiplier` (default "Ha").

        """
        # arguments needed to rebuild an identical model (see to_dict/save)
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        # multiplier converting Hartree to the model energy unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        # deepcopy: the pops below must not mutate the caller's dict
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            # a leading "graph" entry in preprocessing overrides the defaults
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            # graph_config takes final precedence
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        # instantiate the remaining preprocessing modules by registry key
        # ("FID" overrides "module_name", which overrides the entry name)
        for name, params in preprocessing.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
        # mods = self.preprocessing.get_processors(return_list=True)

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            # inject private context fields when the module declares them
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        # keep both a raw and a jitted apply (the raw one is used inside
        # the energy closures built by set_energy_terms)
        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model

        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        # freeze the wrapper: __setattr__ now rejects new assignments
        self._initializing = False
180
    def set_energy_terms(
        self, energy_terms: Union[Sequence[str], None], jit: bool = True
    ) -> None:
        """Set the energy terms to be computed by the model and prepare the energy and force functions.

        Arguments:
        ----------
        energy_terms: Sequence[str] or str or None
            Output keys summed to form the total energy. A single string is
            promoted to a one-element list. None or an empty sequence yields
            constant-zero energy/force/virial functions.
        jit: bool
            If True, the generated functions are wrapped with jax.jit.
        """
        # store on the frozen dataclass (plain assignment is blocked by
        # __setattr__ once initialization is done)
        object.__setattr__(self, "energy_terms", energy_terms)
        if isinstance(energy_terms, str):
            energy_terms = [energy_terms]

        if energy_terms is None or len(energy_terms) == 0:
            # no energy terms: all energy-related outputs are zeros with the
            # dtype of the coordinates

            def total_energy(variables, data):
                out = self.__apply(variables, data)
                coords = out["coordinates"]
                nsys = out["natoms"].shape[0]
                nat = coords.shape[0]
                dtype = coords.dtype
                e = jnp.zeros(nsys, dtype=dtype)
                eat = jnp.zeros(nat, dtype=dtype)
                out["total_energy"] = e
                out["atomic_energies"] = eat
                return e, out

            def energy_and_forces(variables, data):
                e, out = total_energy(variables, data)
                f = jnp.zeros_like(out["coordinates"])
                out["forces"] = f
                return e, f, out

            def energy_and_forces_and_virial(variables, data):
                e, f, out = energy_and_forces(variables, data)
                v = jnp.zeros(
                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
                )
                out["virial_tensor"] = v
                return e, f, v, out

        else:
            # build the energy and force functions
            def total_energy(variables, data):
                out = self.__apply(variables, data)
                atomic_energies = 0.0
                system_energies = 0.0
                species = out["species"]
                nsys = out["natoms"].shape[0]
                # NOTE: iterates the stored attribute self.energy_terms,
                # not the locally normalized list
                for term in self.energy_terms:
                    e = out[term]
                    # drop a trailing singleton dimension if present
                    if e.ndim > 1 and e.shape[-1] == 1:
                        e = jnp.squeeze(e, axis=-1)
                    # per-system terms (leading dim == nsys) are accumulated
                    # separately from per-atom terms; ambiguous only when
                    # nsys == natoms, in which case per-atom is assumed
                    if e.shape[0] == nsys and nsys != species.shape[0]:
                        system_energies += e
                        continue
                    assert e.shape == species.shape
                    atomic_energies += e
                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
                if isinstance(atomic_energies, jnp.ndarray):
                    if "true_atoms" in out:
                        # zero the contributions of padding atoms
                        atomic_energies = jnp.where(
                            out["true_atoms"], atomic_energies, 0.0
                        )
                    atomic_energies = atomic_energies.flatten()
                    assert atomic_energies.shape == out["species"].shape
                    out["atomic_energies"] = atomic_energies
                    # reduce atomic contributions to one energy per system
                    energies = jax.ops.segment_sum(
                        atomic_energies,
                        data["batch_index"],
                        num_segments=len(data["natoms"]),
                    )
                else:
                    energies = 0.0

                if isinstance(system_energies, jnp.ndarray):
                    if "true_sys" in out:
                        # zero the contributions of padding systems
                        system_energies = jnp.where(
                            out["true_sys"], system_energies, 0.0
                        )
                    out["system_energies"] = system_energies

                out["total_energy"] = energies + system_energies
                return out["total_energy"], out

            def energy_and_forces(variables, data):
                # forces = -dE/dx via reverse-mode differentiation of the
                # summed total energy with respect to the coordinates
                def _etot(variables, coordinates):
                    energy, out = total_energy(
                        variables, {**data, "coordinates": coordinates}
                    )
                    return energy.sum(), out

                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                    variables, data["coordinates"]
                )
                out["forces"] = -de

                return out["total_energy"], out["forces"], out

            # NOTE: an earlier (commented-out) implementation computed the
            # virial from explicit position and cell gradients; it was
            # replaced by the scaling-based formulation below.

            def energy_and_forces_and_virial(variables, data):
                # virial via differentiation with respect to a per-system
                # identity scaling matrix applied to the coordinates (and to
                # the cells when present)
                x = data["coordinates"]
                scaling = jnp.asarray(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                )
                def _etot(variables, coordinates, scaling):
                    batch_index = data["batch_index"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**data, "coordinates": coordinates}
                    if "cells" in data:
                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                        reciprocal_cells = jnp.linalg.inv(cells)
                        inputs["cells"] = cells
                        inputs["reciprocal_cells"] = reciprocal_cells
                    energy, out = total_energy(variables, inputs)
                    return energy.sum(), out

                # gradient w.r.t. the scaling IS the virial tensor
                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                    variables, x, scaling
                )
                f = -dedx
                out["forces"] = f
                out["virial_tensor"] = vir

                return out["total_energy"], f, vir, out

        # keep the un-jitted total energy for use by get_gradient_function
        object.__setattr__(self, "_total_energy_raw", total_energy)
        if jit:
            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
            object.__setattr__(
                self,
                "_energy_and_forces_and_virial",
                jax.jit(energy_and_forces_and_virial),
            )
        else:
            object.__setattr__(self, "_total_energy", total_energy)
            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
            object.__setattr__(
                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
            )
365
    def get_gradient_function(
        self,
        *gradient_keys: Sequence[str],
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

        Arguments:
        ----------
        gradient_keys: str
            Input keys to differentiate with respect to. The special key
            "strain" differentiates w.r.t. a per-system 3x3 scaling matrix
            applied to coordinates and cells (identity is injected if absent).
        jit: bool
            If True, the returned function is wrapped with jax.jit.
        variables_as_input: bool
            If True, the returned function has signature (variables, data);
            otherwise it closes over self.variables and takes only data.

        Returns:
        --------
        Callable returning (total_energy, gradients, outputs), where outputs
        also contains one "dEd_<key>" entry per gradient key.
        """

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "strain" in inputs:
                    # apply the strain scaling to coordinates (per atom, via
                    # batch_index) and to the cells when present
                    scaling = inputs["strain"]
                    batch_index = data["batch_index"]
                    coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**inputs, "coordinates": coordinates}
                    if "cells" in inputs or "cells" in data:
                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
                        cells = jax.vmap(jnp.matmul)(cells, scaling)
                        inputs["cells"] = cells
                if "cells" in inputs:
                    # keep reciprocal cells consistent with the (possibly
                    # strained) cells
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            # inject an identity strain so jax.grad has a tensor to trace
            if "strain" in gradient_keys and "strain" not in data:
                data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient
416
417    def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]:
418        """apply preprocessing to the input data
419
420        !!! This is not a pure function => do not apply jax transforms !!!"""
421        if self.preproc_state is None:
422            out, _ = self.reinitialize_preprocessing(example_data=inputs)
423        elif use_gpu:
424            do_check_input = self.preproc_state.get("check_input", True)
425            if do_check_input:
426                inputs = check_input(inputs)
427            preproc_state, inputs = self.preprocessing.atom_padding(
428                self.preproc_state, inputs
429            )
430            inputs = self.preprocessing.process(preproc_state, inputs)
431            preproc_state, state_up, out, overflow = (
432                self.preprocessing.check_reallocate(
433                    preproc_state, inputs
434                )
435            )
436            if verbose and overflow:
437                print("GPU preprocessing: nblist overflow => reallocating nblist")
438                print("size updates:", state_up)
439        else:
440            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
441
442        object.__setattr__(self, "preproc_state", preproc_state)
443        return out
444
    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Rebuild the preprocessing state from scratch.

        If no example data is given, a dummy 10-atom system is generated.
        The new preproc_state is stored on self; the preprocessed example
        inputs and the (split) rng key are returned for model initialization.
        (Note: the previous `-> None` annotation was incorrect.)
        """
        if rng_key is None:
            # deterministic fallback key; the original rng_key stays None
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        # bypass the frozen-dataclass __setattr__ guard
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key
461
462    def __call__(self, variables: Optional[dict] = None, gpu_preprocessing=False,**inputs) -> Dict[str, Any]:
463        """Apply the FENNIX model (preprocess + modules)
464
465        !!! This is not a pure function => do not apply jax transforms !!!
466        if you want to apply jax transforms, use  self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
467        """
468        if variables is None:
469            variables = self.variables
470        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
471        output = self._apply(variables, inputs)
472        if self.use_atom_padding:
473            output = atom_unpadding(output)
474        return output
475
    def total_energy(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, Dict]:
        """compute the total energy of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess

        `unit` may be a unit name (str) or a multiplier relative to Hartree;
        when given, energies are rescaled from the model unit to that unit.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        # def print_shape(path,value):
        #     if isinstance(value,jnp.ndarray):
        #         print(path,value.shape)
        #     else:
        #         print(path,value)
        # jax.tree_util.tree_map_with_path(print_shape,inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            # both unit multipliers are expressed relative to Hartree
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output
503
    def energy_and_forces(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy and forces of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use  self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess

        `unit` may be a unit name (str) or a multiplier relative to Hartree;
        when given, energies and forces are rescaled to that unit.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            # both unit multipliers are expressed relative to Hartree
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output
527
    def energy_and_forces_and_virial(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy, forces and virial tensor of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use  self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess

        (Note: the previous return annotation listed only three elements;
        the function returns (energy, forces, virial, output).)
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            # both unit multipliers are expressed relative to Hartree
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output
553
    def remove_atom_padding(self, output):
        """remove atom padding from the output

        Thin public wrapper around `fennol.models.preprocessing.atom_unpadding`.
        """
        return atom_unpadding(output)
557
    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """return the model and its variables

        Useful to call the pure module chain directly (e.g. under jax
        transforms) without the stateful wrapper.
        """
        return self.modules, self.variables
561
    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """return the preprocessing chain and its state"""
        return self.preprocessing, self.preproc_state
565
566    def __setattr__(self, __name: str, __value: Any) -> None:
567        if __name == "variables":
568            if __value is not None:
569                if not (
570                    isinstance(__value, dict)
571                    or isinstance(__value, OrderedDict)
572                    or isinstance(__value, FrozenDict)
573                ):
574                    raise ValueError(f"{__name} must be a dict")
575                object.__setattr__(self, __name, JaxConverter()(__value))
576            else:
577                raise ValueError(f"{__name} cannot be None")
578        elif __name == "preproc_state":
579            if __value is not None:
580                if not (
581                    isinstance(__value, dict)
582                    or isinstance(__value, OrderedDict)
583                    or isinstance(__value, FrozenDict)
584                ):
585                    raise ValueError(f"{__name} must be a FrozenDict")
586                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
587            else:
588                raise ValueError(f"{__name} cannot be None")
589
590        elif self._initializing:
591            object.__setattr__(self, __name, __value)
592        else:
593            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")
594
    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """
        Generate dummy system for initialization

        Builds a single random system of `n_atoms` atoms of species 1 inside
        a cube of side `box_size`. If box_size is None, it defaults to twice
        the smallest graph cutoff (so that the generated graphs are non-trivial
        — presumably; TODO confirm intent of the min()).
        """
        if box_size is None:
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * g["cutoff"])
        # NOTE(review): dtype=np.float64 requested here; actual precision
        # depends on jax x64 being enabled — confirm.
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            # "graph": graph,
            "batch_index": batch_index,
            "natoms": natoms,
        }
620
621    def summarize(
622        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
623    ) -> str:
624        """Summarize the model architecture and parameters"""
625        if rng_key is None:
626            head = "Summarizing with example data:\n"
627            rng_key = jax.random.PRNGKey(0)
628        if example_data is None:
629            head = "Summarizing with dummy 10 atoms system:\n"
630            rng_key, rng_key_sys = jax.random.split(rng_key)
631            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
632        rng_key, rng_key_pre = jax.random.split(rng_key)
633        _, inputs = self.preprocessing.init_with_output(example_data)
634        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
635
636    def to_dict(self,convert_numpy=False):
637        """return a dictionary representation of the model"""
638        if convert_numpy:
639            variables = convert_to_numpy(self.variables)
640        else:
641            variables = deepcopy(self.variables)
642        return {
643            **self._input_args,
644            "energy_terms": self.energy_terms,
645            "variables": variables,
646        }
647
    def save(self, filename):
        """save the model to a file

        ``.pkl``/``.pickle`` extensions select pickle (variables converted to
        numpy first); any other extension uses flax msgpack serialization.
        """
        filename_str = str(filename)
        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
        # pickle stores numpy arrays; msgpack serializes jax leaves directly
        state_dict = self.to_dict(convert_numpy=do_pickle)
        # store the ordered dicts as [key, value] lists so that module order
        # survives serialization
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        if do_pickle:
            import pickle
            with open(filename, "wb") as f:
                pickle.dump(state_dict, f)
        else:
            with open(filename, "wb") as f:
                f.write(serialization.msgpack_serialize(state_dict))
664
665    @classmethod
666    def load(
667        cls,
668        filename,
669        use_atom_padding=False,
670        graph_config={},
671    ):
672        """load a model from a file"""
673        filename_str = str(filename)
674        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
675        if do_pickle:
676            import pickle
677            with open(filename, "rb") as f:
678                state_dict = pickle.load(f)
679                state_dict["variables"] = convert_to_jax(state_dict["variables"])
680        else:
681            with open(filename, "rb") as f:
682                state_dict = serialization.msgpack_restore(f.read())
683        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
684        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
685        return cls(
686            **state_dict,
687            graph_config=graph_config,
688            use_atom_padding=use_atom_padding,
689        )

Static wrapper for FENNIX models

The underlying model is a fennol.models.modules.FENNIXModules built from the modules dictionary which references registered modules in fennol.models.modules.MODULES and provides the parameters for initialization.

Since the model is static and contains variables, it must be initialized right away with either example_data, variables or rng_key. If variables is provided, it is used directly. If example_data is provided, the model is initialized with example_data and the resulting variables are stored in the wrapper. If only rng_key is provided, the model is initialized with a dummy system and the resulting variables are stored in the wrapper.

FENNIX( cutoff: float, modules: collections.OrderedDict, preprocessing: collections.OrderedDict = OrderedDict(), example_data=None, rng_key: Optional[PRNGKey] = None, variables: Optional[dict] = None, energy_terms: Optional[Sequence[str]] = None, use_atom_padding: bool = False, graph_config: Dict = {}, energy_unit: str = 'Ha', **kwargs)
 55    def __init__(
 56        self,
 57        cutoff: float,
 58        modules: OrderedDict,
 59        preprocessing: OrderedDict = OrderedDict(),
 60        example_data=None,
 61        rng_key: Optional[jax.random.PRNGKey] = None,
 62        variables: Optional[dict] = None,
 63        energy_terms: Optional[Sequence[str]] = None,
 64        use_atom_padding: bool = False,
 65        graph_config: Dict = {},
 66        energy_unit: str = "Ha",
 67        **kwargs,
 68    ) -> None:
 69        """Initialize the FENNIX model
 70
 71        Arguments:
 72        ----------
 73        cutoff: float
 74            The cutoff radius for the model
 75        modules: OrderedDict
 76            The dictionary defining the sequence of FeNNol modules and their parameters.
 77        preprocessing: OrderedDict
 78            The dictionary defining the sequence of preprocessing modules and their parameters.
 79        example_data: dict
 80            Example data to initialize the model. If not provided, a dummy system is generated.
 81        rng_key: jax.random.PRNGKey
 82            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
 83        variables: dict
 84            The variables of the model (i.e. weights, biases and all other tunable parameters).
 85            If not provided, the variables are initialized (usually at random)
 86        energy_terms: Sequence[str]
 87            The energy terms in the model output that will be summed to compute the total energy.
 88            If None, the total energy is always zero (useful for non-PES models).
 89        use_atom_padding: bool
 90            If True, the model will use atom padding for the input data.
 91            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
 92        graph_config: dict
 93            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
 94
 95        """
 96        self._input_args = {
 97            "cutoff": cutoff,
 98            "modules": OrderedDict(modules),
 99            "preprocessing": OrderedDict(preprocessing),
100            "energy_unit": energy_unit,
101        }
102        self.energy_unit = energy_unit
103        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
104        self.cutoff = cutoff
105        self.energy_terms = energy_terms
106        self.use_atom_padding = use_atom_padding
107
108        # add non-differentiable/non-jittable modules
109        preprocessing = deepcopy(preprocessing)
110        if cutoff is None:
111            preprocessing_modules = []
112        else:
113            prep_keys = list(preprocessing.keys())
114            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
115            if len(prep_keys) > 0 and prep_keys[0] == "graph":
116                graph_params = {
117                    **graph_params,
118                    **preprocessing.pop("graph"),
119                }
120            graph_params = {**graph_params, **graph_config}
121
122            preprocessing_modules = [
123                GraphGenerator(**graph_params),
124            ]
125
126        for name, params in preprocessing.items():
127            key = str(params.pop("module_name")) if "module_name" in params else name
128            key = str(params.pop("FID")) if "FID" in params else key
129            mod = PREPROCESSING[key.upper()](**freeze(params))
130            preprocessing_modules.append(mod)
131
132        self.preprocessing = PreprocessingChain(
133            tuple(preprocessing_modules), use_atom_padding
134        )
135        graphs_properties = self.preprocessing.get_graphs_properties()
136        self._graphs_properties = freeze(graphs_properties)
137        # add preprocessing modules that should be differentiated/jitted
138        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
139        # mods = self.preprocessing.get_processors(return_list=True)
140
141        # build the model
142        modules = deepcopy(modules)
143        modules_names = []
144        for name, params in modules.items():
145            key = str(params.pop("module_name")) if "module_name" in params else name
146            key = str(params.pop("FID")) if "FID" in params else key
147            if name in modules_names:
148                raise ValueError(f"Module {name} already exists")
149            modules_names.append(name)
150            params["name"] = name
151            mod = MODULES[key.upper()]
152            fields = [f.name for f in dataclasses.fields(mod)]
153            if "_graphs_properties" in fields:
154                params["_graphs_properties"] = graphs_properties
155            if "_energy_unit" in fields:
156                params["_energy_unit"] = energy_unit
157            mods.append((mod, params))
158
159        self.modules = FENNIXModules(mods)
160
161        self.__apply = self.modules.apply
162        self._apply = jax.jit(self.modules.apply)
163
164        self.set_energy_terms(energy_terms)
165
166        # initialize the model
167
168        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)
169
170        if variables is not None:
171            self.variables = variables
172        elif rng_key is not None:
173            self.variables = self.modules.init(rng_key, inputs)
174        else:
175            raise ValueError(
176                "Either variables or a jax.random.PRNGKey must be provided for initialization"
177            )
178
179        self._initializing = False

Initialize the FENNIX model

Arguments:

cutoff: float — The cutoff radius for the model.
modules: OrderedDict — The dictionary defining the sequence of FeNNol modules and their parameters.
preprocessing: OrderedDict — The dictionary defining the sequence of preprocessing modules and their parameters.
example_data: dict — Example data to initialize the model. If not provided, a dummy system is generated.
rng_key: jax.random.PRNGKey — The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
variables: dict — The variables of the model (i.e. weights, biases and all other tunable parameters). If not provided, the variables are initialized (usually at random).
energy_terms: Sequence[str] — The energy terms in the model output that will be summed to compute the total energy. If None, the total energy is always zero (useful for non-PES models).
use_atom_padding: bool — If True, the model will use atom padding for the input data. This is useful when one plans to frequently change the number of atoms in the system (for example during training).
graph_config: dict — Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.

cutoff: Optional[float]
variables: Dict
preproc_state: Dict
energy_terms: Optional[Sequence[str]] = None
use_atom_padding: bool = False
energy_unit
Ha_to_model_energy
def set_energy_terms(self, energy_terms: Optional[Sequence[str]], jit: bool = True) -> None:
def set_energy_terms(
    self, energy_terms: Union[Sequence[str], None], jit: bool = True
) -> None:
    """Set the energy terms to be computed by the model and prepare the energy and force functions.

    energy_terms: output keys whose values are summed into the total energy.
        A single string is accepted and treated as a one-element list.
        None (or an empty sequence) produces functions returning zero
        energy/forces/virial (useful for non-PES models).
    jit: if True, wrap the generated functions with jax.jit.
    """
    # Fix: normalize a lone string *before* storing it. The closures below
    # iterate self.energy_terms; iterating a raw str would yield its
    # individual characters instead of a single term name.
    if isinstance(energy_terms, str):
        energy_terms = [energy_terms]
    object.__setattr__(self, "energy_terms", energy_terms)

    if energy_terms is None or len(energy_terms) == 0:
        # Non-PES model: all energy-related outputs are zero arrays.
        def total_energy(variables, data):
            out = self.__apply(variables, data)
            coords = out["coordinates"]
            nsys = out["natoms"].shape[0]
            nat = coords.shape[0]
            dtype = coords.dtype
            e = jnp.zeros(nsys, dtype=dtype)
            eat = jnp.zeros(nat, dtype=dtype)
            out["total_energy"] = e
            out["atomic_energies"] = eat
            return e, out

        def energy_and_forces(variables, data):
            e, out = total_energy(variables, data)
            f = jnp.zeros_like(out["coordinates"])
            out["forces"] = f
            return e, f, out

        def energy_and_forces_and_virial(variables, data):
            e, f, out = energy_and_forces(variables, data)
            v = jnp.zeros(
                (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
            )
            out["virial_tensor"] = v
            return e, f, v, out

    else:
        # build the energy and force functions
        def total_energy(variables, data):
            out = self.__apply(variables, data)
            atomic_energies = 0.0
            system_energies = 0.0
            species = out["species"]
            nsys = out["natoms"].shape[0]
            for term in self.energy_terms:
                e = out[term]
                if e.ndim > 1 and e.shape[-1] == 1:
                    e = jnp.squeeze(e, axis=-1)
                # per-system terms are accumulated separately from per-atom terms
                if e.shape[0] == nsys and nsys != species.shape[0]:
                    system_energies += e
                    continue
                assert e.shape == species.shape
                atomic_energies += e
            if isinstance(atomic_energies, jnp.ndarray):
                if "true_atoms" in out:
                    # zero out contributions from padding atoms
                    atomic_energies = jnp.where(
                        out["true_atoms"], atomic_energies, 0.0
                    )
                atomic_energies = atomic_energies.flatten()
                assert atomic_energies.shape == out["species"].shape
                out["atomic_energies"] = atomic_energies
                # sum per-atom energies into per-system energies
                energies = jax.ops.segment_sum(
                    atomic_energies,
                    data["batch_index"],
                    num_segments=len(data["natoms"]),
                )
            else:
                energies = 0.0

            if isinstance(system_energies, jnp.ndarray):
                if "true_sys" in out:
                    # zero out contributions from padding systems
                    system_energies = jnp.where(
                        out["true_sys"], system_energies, 0.0
                    )
                out["system_energies"] = system_energies

            out["total_energy"] = energies + system_energies
            return out["total_energy"], out

        def energy_and_forces(variables, data):
            # forces = -dE/dx via reverse-mode autodiff
            def _etot(variables, coordinates):
                energy, out = total_energy(
                    variables, {**data, "coordinates": coordinates}
                )
                return energy.sum(), out

            de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                variables, data["coordinates"]
            )
            out["forces"] = -de

            return out["total_energy"], out["forces"], out

        def energy_and_forces_and_virial(variables, data):
            # Virial obtained as the gradient w.r.t. an identity scaling matrix
            # applied to coordinates (and to cell vectors when present).
            x = data["coordinates"]
            scaling = jnp.asarray(
                np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
            )

            def _etot(variables, coordinates, scaling):
                batch_index = data["batch_index"]
                coordinates = jax.vmap(jnp.matmul)(
                    coordinates, scaling[batch_index]
                )
                inputs = {**data, "coordinates": coordinates}
                if "cells" in data:
                    # cells is a nbatch x 3 x 3 matrix whose lines are cell vectors
                    # (i.e. cells[0,0,:] is the first cell vector of the first system)
                    cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                    reciprocal_cells = jnp.linalg.inv(cells)
                    inputs["cells"] = cells
                    inputs["reciprocal_cells"] = reciprocal_cells
                energy, out = total_energy(variables, inputs)
                return energy.sum(), out

            (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                variables, x, scaling
            )
            f = -dedx
            out["forces"] = f
            out["virial_tensor"] = vir

            return out["total_energy"], f, vir, out

    object.__setattr__(self, "_total_energy_raw", total_energy)
    if jit:
        object.__setattr__(self, "_total_energy", jax.jit(total_energy))
        object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
        object.__setattr__(
            self,
            "_energy_and_forces_and_virial",
            jax.jit(energy_and_forces_and_virial),
        )
    else:
        object.__setattr__(self, "_total_energy", total_energy)
        object.__setattr__(self, "_energy_and_forces", energy_and_forces)
        object.__setattr__(
            self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
        )

Set the energy terms to be computed by the model and prepare the energy and force functions.

def get_gradient_function( self, *gradient_keys: Sequence[str], jit: bool = True, variables_as_input: bool = False):
def get_gradient_function(
    self,
    *gradient_keys: Sequence[str],
    jit: bool = True,
    variables_as_input: bool = False,
):
    """Return a function that computes the energy and the gradient of the
    energy with respect to the keys in gradient_keys.

    The returned function yields ``(total_energy, gradients, output)`` where
    ``output`` additionally contains one ``dEd_<key>`` entry per gradient key.
    """

    def _grad_fn(variables, data):
        def _scalar_energy(variables, diff_inputs):
            merged = dict(diff_inputs)
            if "strain" in merged:
                # apply the strain scaling to coordinates (and cells if any)
                strain = merged["strain"]
                coords = merged.get("coordinates", data["coordinates"])
                merged["coordinates"] = jax.vmap(jnp.matmul)(
                    coords, strain[data["batch_index"]]
                )
                if "cells" in merged or "cells" in data:
                    base_cells = merged.get("cells", data["cells"])
                    merged["cells"] = jax.vmap(jnp.matmul)(base_cells, strain)
            if "cells" in merged:
                merged["reciprocal_cells"] = jnp.linalg.inv(merged["cells"])
            energy, out = self._total_energy_raw(variables, {**data, **merged})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            # seed the strain with identity matrices (one per system)
            eye = np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
            data = {**data, "strain": jnp.array(eye)}
        diff_inputs = {key: data[key] for key in gradient_keys}
        grads, out = jax.grad(_scalar_energy, argnums=1, has_aux=True)(
            variables, diff_inputs
        )
        labeled = {"dEd_" + key: grads[key] for key in gradient_keys}
        return out["total_energy"], grads, {**out, **labeled}

    if variables_as_input:
        energy_gradient = _grad_fn
    else:

        def energy_gradient(data):
            return _grad_fn(self.variables, data)

    return jax.jit(energy_gradient) if jit else energy_gradient

Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
    """apply preprocessing to the input data

    !!! This is not a pure function => do not apply jax transforms !!!"""
    if self.preproc_state is None:
        # No preprocessing state yet: initialize it from this input.
        # reinitialize_preprocessing stores the fresh state on self.
        out, _ = self.reinitialize_preprocessing(example_data=inputs)
        # Fix: `preproc_state` was previously left unbound in this branch,
        # raising UnboundLocalError at the final setattr below.
        preproc_state = self.preproc_state
    elif use_gpu:
        do_check_input = self.preproc_state.get("check_input", True)
        if do_check_input:
            inputs = check_input(inputs)
        preproc_state, inputs = self.preprocessing.atom_padding(
            self.preproc_state, inputs
        )
        inputs = self.preprocessing.process(preproc_state, inputs)
        # reallocate neighbor lists if they overflowed
        preproc_state, state_up, out, overflow = (
            self.preprocessing.check_reallocate(preproc_state, inputs)
        )
        if verbose and overflow:
            print("GPU preprocessing: nblist overflow => reallocating nblist")
            print("size updates:", state_up)
    else:
        preproc_state, out = self.preprocessing(self.preproc_state, inputs)

    object.__setattr__(self, "preproc_state", preproc_state)
    return out

apply preprocessing to the input data

!!! This is not a pure function => do not apply jax transforms !!!

def reinitialize_preprocessing(self, rng_key: Optional[PRNGKey] = None, example_data=None) -> None:
def reinitialize_preprocessing(
    self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
) -> None:
    ### TODO ###
    """(Re)initialize the preprocessing chain state and store it on self.

    If no example data is given, a 10-atom dummy system is generated.
    Returns the preprocessed inputs and the remaining rng key
    (None when no key was provided).
    """
    if rng_key is not None:
        rng_key, key_pre = jax.random.split(rng_key)
    else:
        key_pre = jax.random.PRNGKey(0)

    if example_data is None:
        key_sys, key_pre = jax.random.split(key_pre)
        example_data = self.generate_dummy_system(key_sys, n_atoms=10)

    preproc_state, processed = self.preprocessing.init_with_output(example_data)
    object.__setattr__(self, "preproc_state", preproc_state)
    return processed, rng_key
def total_energy( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, Dict]:
def total_energy(
    self,
    variables: Optional[dict] = None,
    unit: Optional[Union[float, str]] = None,
    gpu_preprocessing=False,
    **inputs,
) -> Tuple[jnp.ndarray, Dict]:
    """compute the total energy of the system

    !!! This is not a pure function => do not apply jax transforms !!!
    if you want to apply jax transforms, use self._total_energy(variables, inputs)
    which is pure and preprocess the input using self.preprocess

    unit: target energy unit, either a conversion factor from Hartree or a
        unit name understood by `au.get_multiplier`. If None, energies are
        returned in the model's native unit.
    """
    if variables is None:
        variables = self.variables
    inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
    _, output = self._total_energy(variables, inputs)
    if self.use_atom_padding:
        output = atom_unpadding(output)
    e = output["total_energy"]
    if unit is not None:
        # convert from the model's energy unit to the requested unit
        model_energy_unit = self.Ha_to_model_energy
        if isinstance(unit, str):
            unit = au.get_multiplier(unit)
        e = e * (unit / model_energy_unit)
    return e, output

compute the total energy of the system

!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess

def energy_and_forces( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, jax.Array, Dict]:
def energy_and_forces(
    self,
    variables: Optional[dict] = None,
    unit: Optional[Union[float, str]] = None,
    gpu_preprocessing=False,
    **inputs,
) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
    """compute the total energy and forces of the system

    !!! This is not a pure function => do not apply jax transforms !!!
    if you want to apply jax transforms, use self._energy_and_forces(variables, inputs)
    which is pure and preprocess the input using self.preprocess

    unit: target energy unit, either a conversion factor from Hartree or a
        unit name understood by `au.get_multiplier`. If None, quantities are
        returned in the model's native unit.
    """
    if variables is None:
        variables = self.variables
    inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
    _, _, output = self._energy_and_forces(variables, inputs)
    if self.use_atom_padding:
        output = atom_unpadding(output)
    e = output["total_energy"]
    f = output["forces"]
    if unit is not None:
        # convert both energy and forces to the requested unit
        model_energy_unit = self.Ha_to_model_energy
        if isinstance(unit, str):
            unit = au.get_multiplier(unit)
        e = e * (unit / model_energy_unit)
        f = f * (unit / model_energy_unit)
    return e, f, output

compute the total energy and forces of the system

!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess

def energy_and_forces_and_virial( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, jax.Array, Dict]:
def energy_and_forces_and_virial(
    self,
    variables: Optional[dict] = None,
    unit: Optional[Union[float, str]] = None,
    gpu_preprocessing=False,
    **inputs,
) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
    """compute the total energy, forces and virial tensor of the system

    !!! This is not a pure function => do not apply jax transforms !!!
    if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs)
    which is pure and preprocess the input using self.preprocess

    unit: target energy unit, either a conversion factor from Hartree or a
        unit name understood by `au.get_multiplier`. If None, quantities are
        returned in the model's native unit.
    """
    if variables is None:
        variables = self.variables
    inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
    _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
    if self.use_atom_padding:
        output = atom_unpadding(output)
    e = output["total_energy"]
    f = output["forces"]
    v = output["virial_tensor"]
    if unit is not None:
        # convert energy, forces and virial to the requested unit
        model_energy_unit = self.Ha_to_model_energy
        if isinstance(unit, str):
            unit = au.get_multiplier(unit)
        e = e * (unit / model_energy_unit)
        f = f * (unit / model_energy_unit)
        v = v * (unit / model_energy_unit)
    return e, f, v, output

compute the total energy and forces of the system

!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess

def remove_atom_padding(self, output):
def remove_atom_padding(self, output):
    """Strip padding atoms from a model output dictionary (see `atom_unpadding`)."""
    return atom_unpadding(output)

remove atom padding from the output

def get_model(self) -> Tuple[fennol.models.modules.FENNIXModules, Dict]:
def get_model(self) -> Tuple[FENNIXModules, Dict]:
    """Return the underlying flax modules together with their variables."""
    return (self.modules, self.variables)

return the model and its variables

def get_preprocessing(self) -> Tuple[fennol.models.preprocessing.PreprocessingChain, Dict]:
def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
    """Return the preprocessing chain together with its current state."""
    return (self.preprocessing, self.preproc_state)

return the preprocessing chain and its state

def generate_dummy_system( self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10) -> Dict[str, Any]:
def generate_dummy_system(
    self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
) -> Dict[str, Any]:
    """
    Generate dummy system for initialization
    """
    if box_size is None:
        # default box: twice the smallest finite cutoff among all graphs
        box_size = 2 * self.cutoff
        for props in self._graphs_properties.values():
            if props["cutoff"] is not None:
                box_size = min(box_size, 2 * props["cutoff"])
    coordinates = np.array(
        jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size),
        dtype=np.float64,
    )
    # single system: all atoms belong to batch 0, all species set to 1
    return {
        "species": np.ones((n_atoms,), dtype=np.int32),
        "coordinates": coordinates,
        "batch_index": np.zeros((n_atoms,), dtype=np.int32),
        "natoms": np.array([n_atoms], dtype=np.int32),
    }

Generate dummy system for initialization

def summarize( self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None, **kwargs) -> str:
def summarize(
    self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None, **kwargs
) -> str:
    """Summarize the model architecture and parameters

    Uses `example_data` when given, otherwise a 10-atom dummy system.
    """
    # Fix: `head` was previously only assigned inside the two `if` branches,
    # raising UnboundLocalError when both rng_key and example_data were provided.
    head = "Summarizing with example data:\n"
    if rng_key is None:
        rng_key = jax.random.PRNGKey(0)
    if example_data is None:
        head = "Summarizing with dummy 10 atoms system:\n"
        rng_key, rng_key_sys = jax.random.split(rng_key)
        example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
    rng_key, rng_key_pre = jax.random.split(rng_key)
    _, inputs = self.preprocessing.init_with_output(example_data)
    return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

Summarize the model architecture and parameters

def to_dict(self, convert_numpy=False):
def to_dict(self, convert_numpy=False):
    """Return a dictionary representation of the model (constructor args,
    energy terms and variables).

    convert_numpy: if True, convert the variables to numpy arrays
        (useful before pickling); otherwise they are deep-copied as-is.
    """
    if convert_numpy:
        variables = convert_to_numpy(self.variables)
    else:
        variables = deepcopy(self.variables)
    state = dict(self._input_args)
    state["energy_terms"] = self.energy_terms
    state["variables"] = variables
    return state

return a dictionary representation of the model

def save(self, filename):
def save(self, filename):
    """Serialize the model to disk.

    Uses pickle for `.pkl`/`.pickle` filenames, flax msgpack otherwise.
    """
    name = str(filename)
    as_pickle = name.endswith((".pkl", ".pickle"))
    state = self.to_dict(convert_numpy=as_pickle)
    # store the ordered mappings as [key, value] pairs so that both
    # serialization backends preserve the module order
    state["preprocessing"] = [[k, v] for k, v in state["preprocessing"].items()]
    state["modules"] = [[k, v] for k, v in state["modules"].items()]
    if as_pickle:
        import pickle

        with open(filename, "wb") as f:
            pickle.dump(state, f)
    else:
        with open(filename, "wb") as f:
            f.write(serialization.msgpack_serialize(state))

save the model to a file

@classmethod
def load(cls, filename, use_atom_padding=False, graph_config={}):
@classmethod
def load(
    cls,
    filename,
    use_atom_padding=False,
    graph_config=None,
):
    """load a model from a file

    Reads back a file written by `save` (pickle for `.pkl`/`.pickle`,
    flax msgpack otherwise) and rebuilds the FENNIX instance.

    graph_config: optional overrides for the graph configuration,
        forwarded to the constructor. Defaults to an empty dict.
    """
    # Fix: avoid a mutable default argument (previously `graph_config={}`),
    # which is shared across all calls in Python.
    if graph_config is None:
        graph_config = {}
    filename_str = str(filename)
    do_pickle = filename_str.endswith((".pkl", ".pickle"))
    if do_pickle:
        import pickle

        with open(filename, "rb") as f:
            state_dict = pickle.load(f)
        state_dict["variables"] = convert_to_jax(state_dict["variables"])
    else:
        with open(filename, "rb") as f:
            state_dict = serialization.msgpack_restore(f.read())
    # [key, value] pairs back to mappings (see `save`)
    state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
    state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
    return cls(
        **state_dict,
        graph_config=graph_config,
        use_atom_padding=use_atom_padding,
    )

load a model from a file