fennol.models.fennix
from typing import Any, Sequence, Callable, Union, Optional, Tuple, Dict
from copy import deepcopy
import dataclasses
from collections import OrderedDict

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax import serialization
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
from ..utils.atomic_units import au

from .preprocessing import (
    GraphGenerator,
    PreprocessingChain,
    JaxConverter,
    atom_unpadding,
    check_input,
    convert_to_jax,
    convert_to_numpy,
)
from .modules import MODULES, PREPROCESSING, FENNIXModules


@dataclasses.dataclass
class FENNIX:
    """
    Static wrapper for FENNIX models

    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary
    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.

    Since the model is static and contains variables, it must be initialized right away with either
    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
    is provided, the model is initialized with `example_data` and the resulting variables are stored
    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the
    resulting variables are stored in the wrapper.

    After construction the instance is frozen: `__setattr__` rejects mutation of
    attributes (except the validated `variables`/`preproc_state` channels), so all
    internal updates go through `object.__setattr__`.
    """

    # cutoff radius of the main neighbor graph (None => no graph preprocessing)
    cutoff: Union[float, None]
    # the jittable/differentiable stack of FeNNol modules
    modules: FENNIXModules
    # model parameters (weights, biases, ...) as a (Frozen)Dict pytree
    variables: Dict
    # non-jittable preprocessing chain (graph building, padding, ...)
    preprocessing: PreprocessingChain
    # jitted pure apply: (variables, preprocessed inputs) -> outputs dict
    _apply: Callable[[Dict, Dict], Dict]
    # jitted pure energy function: (variables, inputs) -> (energies, outputs)
    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
    # jitted pure energy+forces function: (variables, inputs) -> (energies, forces, outputs)
    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
    # constructor arguments kept for serialization in to_dict/save
    _input_args: Dict
    # frozen description of the graphs produced by the preprocessing chain
    _graphs_properties: Dict
    # mutable state of the preprocessing chain (nblist sizes, flags, ...)
    preproc_state: Dict
    # output keys summed to form the total energy (None => zero-energy model)
    energy_terms: Optional[Sequence[str]] = None
    # True only during __init__, to let __setattr__ accept attribute creation
    _initializing: bool = True
    use_atom_padding: bool = False

    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters.
        example_data: dict
            Example data to initialize the model. If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random)
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data.
            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
        energy_unit: str
            Unit in which the model expresses energies (converted via `au.get_multiplier`).

        NOTE(review): `preprocessing=OrderedDict()` and `graph_config={}` are mutable
        default arguments; they are only read/deep-copied here, but replacing them
        with None-sentinels would be safer.
        """
        # keep the raw constructor arguments so the model can be re-created by save()/load()
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        # (deepcopy so the pops below do not mutate the caller's dict)
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            # a leading "graph" entry in the preprocessing dict overrides the defaults
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            # graph_config has the last word (e.g. box-dependent long-range cutoff)
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        for name, params in preprocessing.items():
            # registry key: explicit "module_name"/"FID" entries win over the dict key
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
        # mods = self.preprocessing.get_processors(return_list=True)

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            # inject framework-provided fields only if the module declares them
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        # self.__apply is name-mangled to _FENNIX__apply; the un-jitted version is
        # captured by the closures built in set_energy_terms (so they can be
        # differentiated/jitted as a whole), while _apply is the jitted entry point.
        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model

        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        # freeze the instance: __setattr__ now rejects plain attribute assignment
        self._initializing = False

    def set_energy_terms(
        self, energy_terms: Union[Sequence[str], None], jit: bool = True
    ) -> None:
        """Set the energy terms to be computed by the model and prepare the energy and force functions.

        Builds (and optionally jits) `_total_energy`, `_energy_and_forces` and
        `_energy_and_forces_and_virial`. With no energy terms, all three return zeros
        of the appropriate shapes (useful for non-PES models).

        NOTE(review): the raw `energy_terms` is stored on `self` *before* the
        str -> [str] normalization below, and the closures iterate
        `self.energy_terms`; passing a bare string therefore appears to iterate
        its characters — confirm callers always pass a sequence.
        """
        # object.__setattr__ bypasses the immutability enforced by __setattr__
        object.__setattr__(self, "energy_terms", energy_terms)
        if isinstance(energy_terms, str):
            energy_terms = [energy_terms]

        if energy_terms is None or len(energy_terms) == 0:
            # zero-energy model: outputs are computed but total energy/forces/virial are zeros

            def total_energy(variables, data):
                out = self.__apply(variables, data)
                coords = out["coordinates"]
                nsys = out["natoms"].shape[0]
                nat = coords.shape[0]
                dtype = coords.dtype
                e = jnp.zeros(nsys, dtype=dtype)
                eat = jnp.zeros(nat, dtype=dtype)
                out["total_energy"] = e
                out["atomic_energies"] = eat
                return e, out

            def energy_and_forces(variables, data):
                e, out = total_energy(variables, data)
                f = jnp.zeros_like(out["coordinates"])
                out["forces"] = f
                return e, f, out

            def energy_and_forces_and_virial(variables, data):
                e, f, out = energy_and_forces(variables, data)
                v = jnp.zeros(
                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
                )
                out["virial_tensor"] = v
                return e, f, v, out

        else:
            # build the energy and force functions
            def total_energy(variables, data):
                out = self.__apply(variables, data)
                atomic_energies = 0.0
                system_energies = 0.0
                species = out["species"]
                nsys = out["natoms"].shape[0]
                for term in self.energy_terms:
                    e = out[term]
                    # drop a trailing singleton feature axis if present
                    if e.ndim > 1 and e.shape[-1] == 1:
                        e = jnp.squeeze(e, axis=-1)
                    # per-system terms are summed separately from per-atom terms;
                    # the nsys != natoms check disambiguates when shapes coincide
                    if e.shape[0] == nsys and nsys != species.shape[0]:
                        system_energies += e
                        continue
                    assert e.shape == species.shape
                    atomic_energies += e
                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
                if isinstance(atomic_energies, jnp.ndarray):
                    # zero out padded (fake) atoms before reducing per system
                    if "true_atoms" in out:
                        atomic_energies = jnp.where(
                            out["true_atoms"], atomic_energies, 0.0
                        )
                    atomic_energies = atomic_energies.flatten()
                    assert atomic_energies.shape == out["species"].shape
                    out["atomic_energies"] = atomic_energies
                    energies = jax.ops.segment_sum(
                        atomic_energies,
                        data["batch_index"],
                        num_segments=len(data["natoms"]),
                    )
                else:
                    energies = 0.0

                if isinstance(system_energies, jnp.ndarray):
                    # zero out padded (fake) systems
                    if "true_sys" in out:
                        system_energies = jnp.where(
                            out["true_sys"], system_energies, 0.0
                        )
                    out["system_energies"] = system_energies

                out["total_energy"] = energies + system_energies
                return out["total_energy"], out

            def energy_and_forces(variables, data):
                # forces = -dE/dx via reverse-mode autodiff on the summed energy
                def _etot(variables, coordinates):
                    energy, out = total_energy(
                        variables, {**data, "coordinates": coordinates}
                    )
                    return energy.sum(), out

                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                    variables, data["coordinates"]
                )
                out["forces"] = -de

                return out["total_energy"], out["forces"], out

            # def energy_and_forces_and_virial(variables, data):
            #     x = data["coordinates"]
            #     batch_index = data["batch_index"]
            #     if "cells" in data:
            #         cells = data["cells"]
            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)

            #         def _etot(variables, coordinates, cells):
            #             reciprocal_cells = jnp.linalg.inv(cells)
            #             energy, out = total_energy(
            #                 variables,
            #                 {
            #                     **data,
            #                     "coordinates": coordinates,
            #                     "cells": cells,
            #                     "reciprocal_cells": reciprocal_cells,
            #                 },
            #             )
            #             return energy.sum(), out

            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
            #             variables, x, cells
            #         )
            #         f= -dedx
            #         out["forces"] = f
            #     else:
            #         _,f,out = energy_and_forces(variables, data)

            #     vir = -jax.ops.segment_sum(
            #         f[:, :, None] * x[:, None, :],
            #         batch_index,
            #         num_segments=len(data["natoms"]),
            #     )

            #     if "cells" in data:
            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
            #         nsys = data["natoms"].shape[0]
            #         if cells.shape[0]==1 and nsys>1:
            #             dvir = dvir / nsys
            #         vir = vir + dvir

            #     out["virial_tensor"] = vir

            #     return out["total_energy"], f, vir, out

            def energy_and_forces_and_virial(variables, data):
                # virial via the strain trick: differentiate E w.r.t. a per-system
                # 3x3 scaling matrix applied to coordinates (and cells, if periodic),
                # evaluated at identity
                x = data["coordinates"]
                scaling = jnp.asarray(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                )

                def _etot(variables, coordinates, scaling):
                    batch_index = data["batch_index"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**data, "coordinates": coordinates}
                    if "cells" in data:
                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                        reciprocal_cells = jnp.linalg.inv(cells)
                        inputs["cells"] = cells
                        inputs["reciprocal_cells"] = reciprocal_cells
                    energy, out = total_energy(variables, inputs)
                    return energy.sum(), out

                # gradient w.r.t. coordinates gives forces; w.r.t. scaling gives the virial
                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                    variables, x, scaling
                )
                f = -dedx
                out["forces"] = f
                out["virial_tensor"] = vir

                return out["total_energy"], f, vir, out

        # keep an un-jitted version for composition in get_gradient_function
        object.__setattr__(self, "_total_energy_raw", total_energy)
        if jit:
            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
            object.__setattr__(
                self,
                "_energy_and_forces_and_virial",
                jax.jit(energy_and_forces_and_virial),
            )
        else:
            object.__setattr__(self, "_total_energy", total_energy)
            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
            object.__setattr__(
                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
            )

    def get_gradient_function(
        self,
        *gradient_keys: Sequence[str],
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

        The returned function maps `data` (or `(variables, data)` when
        `variables_as_input` is True) to
        `(total_energy, gradients, outputs)` where `outputs` additionally
        contains one `"dEd_<key>"` entry per requested key. The special key
        `"strain"` triggers the strain-scaling trick (a per-system 3x3 matrix
        applied to coordinates and cells, initialized to identity), whose
        gradient is the virial-like strain derivative.
        """

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "strain" in inputs:
                    # apply the strain scaling to coordinates (and cells if present)
                    scaling = inputs["strain"]
                    batch_index = data["batch_index"]
                    coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**inputs, "coordinates": coordinates}
                    if "cells" in inputs or "cells" in data:
                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
                        cells = jax.vmap(jnp.matmul)(cells, scaling)
                        inputs["cells"] = cells
                if "cells" in inputs:
                    # keep reciprocal cells consistent with the (possibly strained) cells
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            # default strain is the identity (zero deformation)
            if "strain" in gradient_keys and "strain" not in data:
                data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:
            # close over the stored variables for convenience

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient

    def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
        """apply preprocessing to the input data

        !!! This is not a pure function => do not apply jax transforms !!!

        Updates `self.preproc_state` in place (via object.__setattr__) and returns
        the preprocessed inputs ready for `self._apply` / the energy functions.

        NOTE(review): in the `self.preproc_state is None` branch the local
        `preproc_state` is never assigned, so the final object.__setattr__ would
        raise UnboundLocalError — confirm that branch is actually reachable
        (reinitialize_preprocessing always sets a non-None state, and __setattr__
        rejects None).
        """
        if self.preproc_state is None:
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
        elif use_gpu:
            # GPU path: run padding + processing first, then check whether the
            # neighbor lists overflowed and need reallocation
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = (
                self.preprocessing.check_reallocate(
                    preproc_state, inputs
                )
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        object.__setattr__(self, "preproc_state", preproc_state)
        return out

    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Reset the preprocessing state from example data (or a dummy system).

        Returns the preprocessed example inputs and the (possibly split) rng key,
        so __init__ can reuse the key for variable initialization.
        """
        ### TODO ###
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key

    def __call__(
        self, variables: Optional[dict] = None, gpu_preprocessing=False, **inputs
    ) -> Dict[str, Any]:
        """Apply the FENNIX model (preprocess + modules)

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        output = self._apply(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        return output

    def total_energy(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, Dict]:
        """compute the total energy of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess

        `unit` may be a unit name (converted via au.get_multiplier) or a raw
        multiplier relative to Hartree; the returned energy is rescaled from the
        model's energy unit accordingly.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        # def print_shape(path,value):
        #     if isinstance(value,jnp.ndarray):
        #         print(path,value.shape)
        #     else:
        #         print(path,value)
        # jax.tree_util.tree_map_with_path(print_shape,inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output

    def energy_and_forces(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy and forces of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output

    def energy_and_forces_and_virial(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy, forces and virial tensor of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output

    def remove_atom_padding(self, output):
        """remove atom padding from the output"""
        return atom_unpadding(output)

    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """return the model and its variables"""
        return self.modules, self.variables

    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """return the preprocessing chain and its state"""
        return self.preprocessing, self.preproc_state

    def __setattr__(self, __name: str, __value: Any) -> None:
        """Enforce immutability after __init__.

        `variables` and `preproc_state` are always settable but validated and
        converted (to jax arrays / frozen dicts); any other attribute may only be
        set while `_initializing` is True. Internal code bypasses this with
        `object.__setattr__`.
        """
        if __name == "variables":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a dict")
                object.__setattr__(self, __name, JaxConverter()(__value))
            else:
                raise ValueError(f"{__name} cannot be None")
        elif __name == "preproc_state":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a FrozenDict")
                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
            else:
                raise ValueError(f"{__name} cannot be None")

        elif self._initializing:
            object.__setattr__(self, __name, __value)
        else:
            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")

    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """
        Generate dummy system for initialization

        Builds a single random system (uniform coordinates in a cube) with the
        minimal set of input keys the preprocessing chain expects. When box_size
        is not given, it is derived from the smallest graph cutoff so all graphs
        get some neighbors.
        """
        if box_size is None:
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * g["cutoff"])
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
        )
        # all-hydrogen single-system batch
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            # "graph": graph,
            "batch_index": batch_index,
            "natoms": natoms,
        }

    def summarize(
        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters

        NOTE(review): `head` is only assigned when rng_key or example_data is
        None; calling this with both provided would raise UnboundLocalError —
        confirm intended usage.
        """
        if rng_key is None:
            head = "Summarizing with example data:\n"
            rng_key = jax.random.PRNGKey(0)
        if example_data is None:
            head = "Summarizing with dummy 10 atoms system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

    def to_dict(self, convert_numpy=False):
        """return a dictionary representation of the model

        With convert_numpy=True the variables are converted to numpy arrays
        (needed for pickling); otherwise they are deep-copied as-is.
        """
        if convert_numpy:
            variables = convert_to_numpy(self.variables)
        else:
            variables = deepcopy(self.variables)
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": variables,
        }

    def save(self, filename):
        """save the model to a file

        `.pkl`/`.pickle` extensions select pickle; anything else uses flax
        msgpack serialization. The preprocessing/modules OrderedDicts are stored
        as [key, value] pair lists to preserve ordering across formats.
        """
        filename_str = str(filename)
        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
        state_dict = self.to_dict(convert_numpy=do_pickle)
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        if do_pickle:
            import pickle

            with open(filename, "wb") as f:
                pickle.dump(state_dict, f)
        else:
            with open(filename, "wb") as f:
                f.write(serialization.msgpack_serialize(state_dict))

    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """load a model from a file

        Mirrors `save`: format is chosen from the file extension, the pair lists
        are rebuilt into dicts, and the state dict is fed back to the
        constructor. NOTE(review): pickle.load on untrusted files can execute
        arbitrary code — only load trusted model files.
        """
        filename_str = str(filename)
        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
        if do_pickle:
            import pickle

            with open(filename, "rb") as f:
                state_dict = pickle.load(f)
            state_dict["variables"] = convert_to_jax(state_dict["variables"])
        else:
            with open(filename, "rb") as f:
                state_dict = serialization.msgpack_restore(f.read())
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
27@dataclasses.dataclass 28class FENNIX: 29 """ 30 Static wrapper for FENNIX models 31 32 The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary 33 which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization. 34 35 Since the model is static and contains variables, it must be initialized right away with either 36 `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data` 37 is provided, the model is initialized with `example_data` and the resulting variables are stored 38 in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the resulting. 39 """ 40 41 cutoff: Union[float, None] 42 modules: FENNIXModules 43 variables: Dict 44 preprocessing: PreprocessingChain 45 _apply: Callable[[Dict, Dict], Dict] 46 _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]] 47 _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]] 48 _input_args: Dict 49 _graphs_properties: Dict 50 preproc_state: Dict 51 energy_terms: Optional[Sequence[str]] = None 52 _initializing: bool = True 53 use_atom_padding: bool = False 54 55 def __init__( 56 self, 57 cutoff: float, 58 modules: OrderedDict, 59 preprocessing: OrderedDict = OrderedDict(), 60 example_data=None, 61 rng_key: Optional[jax.random.PRNGKey] = None, 62 variables: Optional[dict] = None, 63 energy_terms: Optional[Sequence[str]] = None, 64 use_atom_padding: bool = False, 65 graph_config: Dict = {}, 66 energy_unit: str = "Ha", 67 **kwargs, 68 ) -> None: 69 """Initialize the FENNIX model 70 71 Arguments: 72 ---------- 73 cutoff: float 74 The cutoff radius for the model 75 modules: OrderedDict 76 The dictionary defining the sequence of FeNNol modules and their parameters. 77 preprocessing: OrderedDict 78 The dictionary defining the sequence of preprocessing modules and their parameters. 
79 example_data: dict 80 Example data to initialize the model. If not provided, a dummy system is generated. 81 rng_key: jax.random.PRNGKey 82 The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided). 83 variables: dict 84 The variables of the model (i.e. weights, biases and all other tunable parameters). 85 If not provided, the variables are initialized (usually at random) 86 energy_terms: Sequence[str] 87 The energy terms in the model output that will be summed to compute the total energy. 88 If None, the total energy is always zero (useful for non-PES models). 89 use_atom_padding: bool 90 If True, the model will use atom padding for the input data. 91 This is useful when one plans to frequently change the number of atoms in the system (for example during training). 92 graph_config: dict 93 Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size. 94 95 """ 96 self._input_args = { 97 "cutoff": cutoff, 98 "modules": OrderedDict(modules), 99 "preprocessing": OrderedDict(preprocessing), 100 "energy_unit": energy_unit, 101 } 102 self.energy_unit = energy_unit 103 self.Ha_to_model_energy = au.get_multiplier(energy_unit) 104 self.cutoff = cutoff 105 self.energy_terms = energy_terms 106 self.use_atom_padding = use_atom_padding 107 108 # add non-differentiable/non-jittable modules 109 preprocessing = deepcopy(preprocessing) 110 if cutoff is None: 111 preprocessing_modules = [] 112 else: 113 prep_keys = list(preprocessing.keys()) 114 graph_params = {"cutoff": cutoff, "graph_key": "graph"} 115 if len(prep_keys) > 0 and prep_keys[0] == "graph": 116 graph_params = { 117 **graph_params, 118 **preprocessing.pop("graph"), 119 } 120 graph_params = {**graph_params, **graph_config} 121 122 preprocessing_modules = [ 123 GraphGenerator(**graph_params), 124 ] 125 126 for name, params in preprocessing.items(): 127 key = str(params.pop("module_name")) if "module_name" in params else 
name 128 key = str(params.pop("FID")) if "FID" in params else key 129 mod = PREPROCESSING[key.upper()](**freeze(params)) 130 preprocessing_modules.append(mod) 131 132 self.preprocessing = PreprocessingChain( 133 tuple(preprocessing_modules), use_atom_padding 134 ) 135 graphs_properties = self.preprocessing.get_graphs_properties() 136 self._graphs_properties = freeze(graphs_properties) 137 # add preprocessing modules that should be differentiated/jitted 138 mods = [(JaxConverter, {})] + self.preprocessing.get_processors() 139 # mods = self.preprocessing.get_processors(return_list=True) 140 141 # build the model 142 modules = deepcopy(modules) 143 modules_names = [] 144 for name, params in modules.items(): 145 key = str(params.pop("module_name")) if "module_name" in params else name 146 key = str(params.pop("FID")) if "FID" in params else key 147 if name in modules_names: 148 raise ValueError(f"Module {name} already exists") 149 modules_names.append(name) 150 params["name"] = name 151 mod = MODULES[key.upper()] 152 fields = [f.name for f in dataclasses.fields(mod)] 153 if "_graphs_properties" in fields: 154 params["_graphs_properties"] = graphs_properties 155 if "_energy_unit" in fields: 156 params["_energy_unit"] = energy_unit 157 mods.append((mod, params)) 158 159 self.modules = FENNIXModules(mods) 160 161 self.__apply = self.modules.apply 162 self._apply = jax.jit(self.modules.apply) 163 164 self.set_energy_terms(energy_terms) 165 166 # initialize the model 167 168 inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data) 169 170 if variables is not None: 171 self.variables = variables 172 elif rng_key is not None: 173 self.variables = self.modules.init(rng_key, inputs) 174 else: 175 raise ValueError( 176 "Either variables or a jax.random.PRNGKey must be provided for initialization" 177 ) 178 179 self._initializing = False 180 181 def set_energy_terms( 182 self, energy_terms: Union[Sequence[str], None], jit: bool = True 183 ) -> None: 184 """Set the 
energy terms to be computed by the model and prepare the energy and force functions.""" 185 object.__setattr__(self, "energy_terms", energy_terms) 186 if isinstance(energy_terms, str): 187 energy_terms = [energy_terms] 188 189 if energy_terms is None or len(energy_terms) == 0: 190 191 def total_energy(variables, data): 192 out = self.__apply(variables, data) 193 coords = out["coordinates"] 194 nsys = out["natoms"].shape[0] 195 nat = coords.shape[0] 196 dtype = coords.dtype 197 e = jnp.zeros(nsys, dtype=dtype) 198 eat = jnp.zeros(nat, dtype=dtype) 199 out["total_energy"] = e 200 out["atomic_energies"] = eat 201 return e, out 202 203 def energy_and_forces(variables, data): 204 e, out = total_energy(variables, data) 205 f = jnp.zeros_like(out["coordinates"]) 206 out["forces"] = f 207 return e, f, out 208 209 def energy_and_forces_and_virial(variables, data): 210 e, f, out = energy_and_forces(variables, data) 211 v = jnp.zeros( 212 (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype 213 ) 214 out["virial_tensor"] = v 215 return e, f, v, out 216 217 else: 218 # build the energy and force functions 219 def total_energy(variables, data): 220 out = self.__apply(variables, data) 221 atomic_energies = 0.0 222 system_energies = 0.0 223 species = out["species"] 224 nsys = out["natoms"].shape[0] 225 for term in self.energy_terms: 226 e = out[term] 227 if e.ndim > 1 and e.shape[-1] == 1: 228 e = jnp.squeeze(e, axis=-1) 229 if e.shape[0] == nsys and nsys != species.shape[0]: 230 system_energies += e 231 continue 232 assert e.shape == species.shape 233 atomic_energies += e 234 # atomic_energies = jnp.squeeze(atomic_energies, axis=-1) 235 if isinstance(atomic_energies, jnp.ndarray): 236 if "true_atoms" in out: 237 atomic_energies = jnp.where( 238 out["true_atoms"], atomic_energies, 0.0 239 ) 240 atomic_energies = atomic_energies.flatten() 241 assert atomic_energies.shape == out["species"].shape 242 out["atomic_energies"] = atomic_energies 243 energies = 
jax.ops.segment_sum( 244 atomic_energies, 245 data["batch_index"], 246 num_segments=len(data["natoms"]), 247 ) 248 else: 249 energies = 0.0 250 251 if isinstance(system_energies, jnp.ndarray): 252 if "true_sys" in out: 253 system_energies = jnp.where( 254 out["true_sys"], system_energies, 0.0 255 ) 256 out["system_energies"] = system_energies 257 258 out["total_energy"] = energies + system_energies 259 return out["total_energy"], out 260 261 def energy_and_forces(variables, data): 262 def _etot(variables, coordinates): 263 energy, out = total_energy( 264 variables, {**data, "coordinates": coordinates} 265 ) 266 return energy.sum(), out 267 268 de, out = jax.grad(_etot, argnums=1, has_aux=True)( 269 variables, data["coordinates"] 270 ) 271 out["forces"] = -de 272 273 return out["total_energy"], out["forces"], out 274 275 # def energy_and_forces_and_virial(variables, data): 276 # x = data["coordinates"] 277 # batch_index = data["batch_index"] 278 # if "cells" in data: 279 # cells = data["cells"] 280 # ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. 
cells[0,0,:] is the first cell vector of the first system) 281 282 # def _etot(variables, coordinates, cells): 283 # reciprocal_cells = jnp.linalg.inv(cells) 284 # energy, out = total_energy( 285 # variables, 286 # { 287 # **data, 288 # "coordinates": coordinates, 289 # "cells": cells, 290 # "reciprocal_cells": reciprocal_cells, 291 # }, 292 # ) 293 # return energy.sum(), out 294 295 # (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)( 296 # variables, x, cells 297 # ) 298 # f= -dedx 299 # out["forces"] = f 300 # else: 301 # _,f,out = energy_and_forces(variables, data) 302 303 # vir = -jax.ops.segment_sum( 304 # f[:, :, None] * x[:, None, :], 305 # batch_index, 306 # num_segments=len(data["natoms"]), 307 # ) 308 309 # if "cells" in data: 310 # # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1)) 311 # dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells) 312 # nsys = data["natoms"].shape[0] 313 # if cells.shape[0]==1 and nsys>1: 314 # dvir = dvir / nsys 315 # vir = vir + dvir 316 317 # out["virial_tensor"] = vir 318 319 # return out["total_energy"], f, vir, out 320 321 def energy_and_forces_and_virial(variables, data): 322 x = data["coordinates"] 323 scaling = jnp.asarray( 324 np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0) 325 ) 326 def _etot(variables, coordinates, scaling): 327 batch_index = data["batch_index"] 328 coordinates = jax.vmap(jnp.matmul)( 329 coordinates, scaling[batch_index] 330 ) 331 inputs = {**data, "coordinates": coordinates} 332 if "cells" in data: 333 ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. 
cells[0,0,:] is the first cell vector of the first system) 334 cells = jax.vmap(jnp.matmul)(data["cells"], scaling) 335 reciprocal_cells = jnp.linalg.inv(cells) 336 inputs["cells"] = cells 337 inputs["reciprocal_cells"] = reciprocal_cells 338 energy, out = total_energy(variables, inputs) 339 return energy.sum(), out 340 341 (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)( 342 variables, x, scaling 343 ) 344 f = -dedx 345 out["forces"] = f 346 out["virial_tensor"] = vir 347 348 return out["total_energy"], f, vir, out 349 350 object.__setattr__(self, "_total_energy_raw", total_energy) 351 if jit: 352 object.__setattr__(self, "_total_energy", jax.jit(total_energy)) 353 object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces)) 354 object.__setattr__( 355 self, 356 "_energy_and_forces_and_virial", 357 jax.jit(energy_and_forces_and_virial), 358 ) 359 else: 360 object.__setattr__(self, "_total_energy", total_energy) 361 object.__setattr__(self, "_energy_and_forces", energy_and_forces) 362 object.__setattr__( 363 self, "_energy_and_forces_and_virial", energy_and_forces_and_virial 364 ) 365 366 def get_gradient_function( 367 self, 368 *gradient_keys: Sequence[str], 369 jit: bool = True, 370 variables_as_input: bool = False, 371 ): 372 """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys""" 373 374 def _energy_gradient(variables, data): 375 def _etot(variables, inputs): 376 if "strain" in inputs: 377 scaling = inputs["strain"] 378 batch_index = data["batch_index"] 379 coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"] 380 coordinates = jax.vmap(jnp.matmul)( 381 coordinates, scaling[batch_index] 382 ) 383 inputs = {**inputs, "coordinates": coordinates} 384 if "cells" in inputs or "cells" in data: 385 cells = inputs["cells"] if "cells" in inputs else data["cells"] 386 cells = jax.vmap(jnp.matmul)(cells, scaling) 387 inputs["cells"] = cells 
388 if "cells" in inputs: 389 reciprocal_cells = jnp.linalg.inv(inputs["cells"]) 390 inputs = {**inputs, "reciprocal_cells": reciprocal_cells} 391 energy, out = self._total_energy_raw(variables, {**data, **inputs}) 392 return energy.sum(), out 393 394 if "strain" in gradient_keys and "strain" not in data: 395 data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))} 396 inputs = {k: data[k] for k in gradient_keys} 397 de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs) 398 399 return ( 400 out["total_energy"], 401 de, 402 {**out, **{"dEd_" + k: de[k] for k in gradient_keys}}, 403 ) 404 405 if variables_as_input: 406 energy_gradient = _energy_gradient 407 else: 408 409 def energy_gradient(data): 410 return _energy_gradient(self.variables, data) 411 412 if jit: 413 return jax.jit(energy_gradient) 414 else: 415 return energy_gradient 416 417 def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]: 418 """apply preprocessing to the input data 419 420 !!! 
        This is not a pure function => do not apply jax transforms !!!"""
        if self.preproc_state is None:
            # Preprocessing state was never built: initialize it from this input.
            # NOTE(review): this branch leaves the local `preproc_state` unbound,
            # so the object.__setattr__ call below raises UnboundLocalError —
            # reinitialize_preprocessing already stores the new state on self;
            # confirm and fix by reusing self.preproc_state here.
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
        elif use_gpu:
            # GPU path: optionally validate input, pad atoms, process, then check
            # whether the preallocated neighbor lists overflowed and reallocate.
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = (
                self.preprocessing.check_reallocate(
                    preproc_state, inputs
                )
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            # CPU path: the preprocessing chain handles padding/allocation itself.
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        # store the (possibly updated) preprocessing state on the frozen wrapper
        object.__setattr__(self, "preproc_state", preproc_state)
        return out

    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Rebuild the preprocessing state from scratch.

        Returns the preprocessed example inputs and the (possibly split) rng key;
        the returned key is None when no key was provided.
        """
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            # no example provided: initialize on a small random dummy system
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key

    def __call__(
        self, variables: Optional[dict] = None, gpu_preprocessing=False, **inputs
    ) -> Dict[str, Any]:
        """Apply the FENNIX model (preprocess + modules)

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        output = self._apply(variables, inputs)
        if self.use_atom_padding:
            # drop padding atoms so callers see only real atoms
            output = atom_unpadding(output)
        return output

    def total_energy(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, Dict]:
        """compute the total energy of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            # model energies are in self.energy_unit; rescale through the
            # Hartree multipliers provided by `au.get_multiplier`
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output

    def energy_and_forces(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy and forces of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            # forces scale with the same energy conversion (per unit length)
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output

    def energy_and_forces_and_virial(
        self, variables: Optional[dict] = None, unit: Optional[Union[float, str]] = None, gpu_preprocessing=False, **inputs
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy, forces and virial tensor of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output

    def remove_atom_padding(self, output):
        """remove atom padding from the output"""
        return atom_unpadding(output)

    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """return the model and its variables"""
        return self.modules, self.variables

    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """return the preprocessing chain and its state"""
        return self.preprocessing, self.preproc_state

    def __setattr__(self, __name: str, __value: Any) -> None:
        # FENNIX is (mostly) immutable after construction: only `variables` and
        # `preproc_state` may be reassigned later; anything else can only be set
        # while `_initializing` is True (i.e. inside __init__).
        if __name == "variables":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a dict")
                # convert leaves to jax arrays before storing
                object.__setattr__(self, __name, JaxConverter()(__value))
            else:
                raise ValueError(f"{__name} cannot be None")
        elif __name == "preproc_state":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a FrozenDict")
                # stored frozen so the state cannot be mutated in place
                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
            else:
                raise ValueError(f"{__name} cannot be None")

        elif self._initializing:
            object.__setattr__(self, __name, __value)
        else:
            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")

    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """
        Generate dummy system for initialization

        Builds a single random system of `n_atoms` atoms of species 1 inside a
        cubic box, returned in the model's standard input-dict format.
        """
        if box_size is None:
            # default box: twice the smallest cutoff among all configured graphs
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * g["cutoff"])
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)  # single system (batch 0)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            "batch_index": batch_index,
            "natoms": natoms,
        }

    def summarize(
        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters"""
        # NOTE(review): `head` is only assigned inside the two `is None`
        # branches below; calling with BOTH rng_key and example_data provided
        # raises NameError on the return line — confirm intended usage.
        if rng_key is None:
            head = "Summarizing with example data:\n"
            rng_key = jax.random.PRNGKey(0)
        if example_data is None:
            head = "Summarizing with dummy 10 atoms system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

    def to_dict(self, convert_numpy=False):
        """return a dictionary representation of the model"""
        if convert_numpy:
            # numpy leaves are needed for e.g. pickling
            variables = convert_to_numpy(self.variables)
        else:
            variables = deepcopy(self.variables)
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": variables,
        }

    def save(self, filename):
        """save the model to a file

        Uses pickle when the filename ends in .pkl/.pickle, msgpack otherwise.
        """
        filename_str = str(filename)
        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
        state_dict = self.to_dict(convert_numpy=do_pickle)
        # dicts are stored as [key, value] pair lists to preserve module order
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        if do_pickle:
            import pickle

            with open(filename, "wb") as f:
                pickle.dump(state_dict, f)
        else:
            with open(filename, "wb") as f:
                f.write(serialization.msgpack_serialize(state_dict))

    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """load a model from a file

        Accepts files produced by `save`: pickle for .pkl/.pickle, msgpack
        otherwise. Returns a freshly constructed FENNIX instance.
        """
        filename_str = str(filename)
        do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle")
        if do_pickle:
            import pickle

            # NOTE(review): pickle.load on untrusted files can execute code —
            # only load model files from trusted sources.
            with open(filename, "rb") as f:
                state_dict = pickle.load(f)
            state_dict["variables"] = convert_to_jax(state_dict["variables"])
        else:
            with open(filename, "rb") as f:
                state_dict = serialization.msgpack_restore(f.read())
        # restore the [key, value] pair lists written by `save` back into dicts
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
Static wrapper for FENNIX models
The underlying model is a fennol.models.modules.FENNIXModules built from the modules dictionary
which references registered modules in fennol.models.modules.MODULES and provides the parameters for initialization.
Since the model is static and contains variables, it must be initialized right away with either
example_data, variables or rng_key. If variables is provided, it is used directly. If example_data
is provided, the model is initialized with example_data and the resulting variables are stored
in the wrapper. If only rng_key is provided, the model is initialized with a dummy system and the resulting variables are stored in the wrapper.
55 def __init__( 56 self, 57 cutoff: float, 58 modules: OrderedDict, 59 preprocessing: OrderedDict = OrderedDict(), 60 example_data=None, 61 rng_key: Optional[jax.random.PRNGKey] = None, 62 variables: Optional[dict] = None, 63 energy_terms: Optional[Sequence[str]] = None, 64 use_atom_padding: bool = False, 65 graph_config: Dict = {}, 66 energy_unit: str = "Ha", 67 **kwargs, 68 ) -> None: 69 """Initialize the FENNIX model 70 71 Arguments: 72 ---------- 73 cutoff: float 74 The cutoff radius for the model 75 modules: OrderedDict 76 The dictionary defining the sequence of FeNNol modules and their parameters. 77 preprocessing: OrderedDict 78 The dictionary defining the sequence of preprocessing modules and their parameters. 79 example_data: dict 80 Example data to initialize the model. If not provided, a dummy system is generated. 81 rng_key: jax.random.PRNGKey 82 The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided). 83 variables: dict 84 The variables of the model (i.e. weights, biases and all other tunable parameters). 85 If not provided, the variables are initialized (usually at random) 86 energy_terms: Sequence[str] 87 The energy terms in the model output that will be summed to compute the total energy. 88 If None, the total energy is always zero (useful for non-PES models). 89 use_atom_padding: bool 90 If True, the model will use atom padding for the input data. 91 This is useful when one plans to frequently change the number of atoms in the system (for example during training). 92 graph_config: dict 93 Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size. 
94 95 """ 96 self._input_args = { 97 "cutoff": cutoff, 98 "modules": OrderedDict(modules), 99 "preprocessing": OrderedDict(preprocessing), 100 "energy_unit": energy_unit, 101 } 102 self.energy_unit = energy_unit 103 self.Ha_to_model_energy = au.get_multiplier(energy_unit) 104 self.cutoff = cutoff 105 self.energy_terms = energy_terms 106 self.use_atom_padding = use_atom_padding 107 108 # add non-differentiable/non-jittable modules 109 preprocessing = deepcopy(preprocessing) 110 if cutoff is None: 111 preprocessing_modules = [] 112 else: 113 prep_keys = list(preprocessing.keys()) 114 graph_params = {"cutoff": cutoff, "graph_key": "graph"} 115 if len(prep_keys) > 0 and prep_keys[0] == "graph": 116 graph_params = { 117 **graph_params, 118 **preprocessing.pop("graph"), 119 } 120 graph_params = {**graph_params, **graph_config} 121 122 preprocessing_modules = [ 123 GraphGenerator(**graph_params), 124 ] 125 126 for name, params in preprocessing.items(): 127 key = str(params.pop("module_name")) if "module_name" in params else name 128 key = str(params.pop("FID")) if "FID" in params else key 129 mod = PREPROCESSING[key.upper()](**freeze(params)) 130 preprocessing_modules.append(mod) 131 132 self.preprocessing = PreprocessingChain( 133 tuple(preprocessing_modules), use_atom_padding 134 ) 135 graphs_properties = self.preprocessing.get_graphs_properties() 136 self._graphs_properties = freeze(graphs_properties) 137 # add preprocessing modules that should be differentiated/jitted 138 mods = [(JaxConverter, {})] + self.preprocessing.get_processors() 139 # mods = self.preprocessing.get_processors(return_list=True) 140 141 # build the model 142 modules = deepcopy(modules) 143 modules_names = [] 144 for name, params in modules.items(): 145 key = str(params.pop("module_name")) if "module_name" in params else name 146 key = str(params.pop("FID")) if "FID" in params else key 147 if name in modules_names: 148 raise ValueError(f"Module {name} already exists") 149 
modules_names.append(name) 150 params["name"] = name 151 mod = MODULES[key.upper()] 152 fields = [f.name for f in dataclasses.fields(mod)] 153 if "_graphs_properties" in fields: 154 params["_graphs_properties"] = graphs_properties 155 if "_energy_unit" in fields: 156 params["_energy_unit"] = energy_unit 157 mods.append((mod, params)) 158 159 self.modules = FENNIXModules(mods) 160 161 self.__apply = self.modules.apply 162 self._apply = jax.jit(self.modules.apply) 163 164 self.set_energy_terms(energy_terms) 165 166 # initialize the model 167 168 inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data) 169 170 if variables is not None: 171 self.variables = variables 172 elif rng_key is not None: 173 self.variables = self.modules.init(rng_key, inputs) 174 else: 175 raise ValueError( 176 "Either variables or a jax.random.PRNGKey must be provided for initialization" 177 ) 178 179 self._initializing = False
Initialize the FENNIX model
Arguments:
cutoff: float The cutoff radius for the model modules: OrderedDict The dictionary defining the sequence of FeNNol modules and their parameters. preprocessing: OrderedDict The dictionary defining the sequence of preprocessing modules and their parameters. example_data: dict Example data to initialize the model. If not provided, a dummy system is generated. rng_key: jax.random.PRNGKey The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided). variables: dict The variables of the model (i.e. weights, biases and all other tunable parameters). If not provided, the variables are initialized (usually at random) energy_terms: Sequence[str] The energy terms in the model output that will be summed to compute the total energy. If None, the total energy is always zero (useful for non-PES models). use_atom_padding: bool If True, the model will use atom padding for the input data. This is useful when one plans to frequently change the number of atoms in the system (for example during training). graph_config: dict Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
    def set_energy_terms(
        self, energy_terms: Union[Sequence[str], None], jit: bool = True
    ) -> None:
        """Set the energy terms to be computed by the model and prepare the energy and force functions.

        Builds (and stores on self via object.__setattr__ since the wrapper is
        immutable) the pure functions `_total_energy`, `_energy_and_forces`,
        `_energy_and_forces_and_virial` and the un-jitted `_total_energy_raw`.
        When `energy_terms` is None/empty, all of them return zeros.
        """
        object.__setattr__(self, "energy_terms", energy_terms)
        if isinstance(energy_terms, str):
            energy_terms = [energy_terms]

        if energy_terms is None or len(energy_terms) == 0:
            # no PES: energies, forces and virials are all zero

            def total_energy(variables, data):
                out = self.__apply(variables, data)
                coords = out["coordinates"]
                nsys = out["natoms"].shape[0]
                nat = coords.shape[0]
                dtype = coords.dtype
                e = jnp.zeros(nsys, dtype=dtype)
                eat = jnp.zeros(nat, dtype=dtype)
                out["total_energy"] = e
                out["atomic_energies"] = eat
                return e, out

            def energy_and_forces(variables, data):
                e, out = total_energy(variables, data)
                f = jnp.zeros_like(out["coordinates"])
                out["forces"] = f
                return e, f, out

            def energy_and_forces_and_virial(variables, data):
                e, f, out = energy_and_forces(variables, data)
                v = jnp.zeros(
                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
                )
                out["virial_tensor"] = v
                return e, f, v, out

        else:
            # build the energy and force functions
            def total_energy(variables, data):
                out = self.__apply(variables, data)
                atomic_energies = 0.0
                system_energies = 0.0
                species = out["species"]
                nsys = out["natoms"].shape[0]
                for term in self.energy_terms:
                    e = out[term]
                    if e.ndim > 1 and e.shape[-1] == 1:
                        e = jnp.squeeze(e, axis=-1)
                    # a term whose leading dim matches the number of systems
                    # (and not the number of atoms) is a per-system energy
                    if e.shape[0] == nsys and nsys != species.shape[0]:
                        system_energies += e
                        continue
                    assert e.shape == species.shape
                    atomic_energies += e
                if isinstance(atomic_energies, jnp.ndarray):
                    if "true_atoms" in out:
                        # zero out contributions from padding atoms
                        atomic_energies = jnp.where(
                            out["true_atoms"], atomic_energies, 0.0
                        )
                    atomic_energies = atomic_energies.flatten()
                    assert atomic_energies.shape == out["species"].shape
                    out["atomic_energies"] = atomic_energies
                    # reduce per-atom energies to per-system energies
                    energies = jax.ops.segment_sum(
                        atomic_energies,
                        data["batch_index"],
                        num_segments=len(data["natoms"]),
                    )
                else:
                    energies = 0.0

                if isinstance(system_energies, jnp.ndarray):
                    if "true_sys" in out:
                        # zero out contributions from padding systems
                        system_energies = jnp.where(
                            out["true_sys"], system_energies, 0.0
                        )
                    out["system_energies"] = system_energies

                out["total_energy"] = energies + system_energies
                return out["total_energy"], out

            def energy_and_forces(variables, data):
                # forces = -dE/dx via reverse-mode autodiff on the summed energy
                def _etot(variables, coordinates):
                    energy, out = total_energy(
                        variables, {**data, "coordinates": coordinates}
                    )
                    return energy.sum(), out

                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                    variables, data["coordinates"]
                )
                out["forces"] = -de

                return out["total_energy"], out["forces"], out

            def energy_and_forces_and_virial(variables, data):
                # Differentiate w.r.t. a per-system 3x3 scaling of coordinates
                # (and cells): dE/dscaling evaluated at identity is the virial.
                x = data["coordinates"]
                scaling = jnp.asarray(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                )

                def _etot(variables, coordinates, scaling):
                    batch_index = data["batch_index"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**data, "coordinates": coordinates}
                    if "cells" in data:
                        # cells is a nbatch x 3 x 3 matrix whose lines are cell
                        # vectors (i.e. cells[0,0,:] is the first cell vector of
                        # the first system); keep reciprocal cells consistent
                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                        reciprocal_cells = jnp.linalg.inv(cells)
                        inputs["cells"] = cells
                        inputs["reciprocal_cells"] = reciprocal_cells
                    energy, out = total_energy(variables, inputs)
                    return energy.sum(), out

                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                    variables, x, scaling
                )
                f = -dedx
                out["forces"] = f
                out["virial_tensor"] = vir

                return out["total_energy"], f, vir, out

        # store the (optionally jitted) functions on the frozen dataclass
        object.__setattr__(self, "_total_energy_raw", total_energy)
        if jit:
            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
            object.__setattr__(
                self,
                "_energy_and_forces_and_virial",
                jax.jit(energy_and_forces_and_virial),
            )
        else:
            object.__setattr__(self, "_total_energy", total_energy)
            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
            object.__setattr__(
                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
            )
Set the energy terms to be computed by the model and prepare the energy and force functions.
366 def get_gradient_function( 367 self, 368 *gradient_keys: Sequence[str], 369 jit: bool = True, 370 variables_as_input: bool = False, 371 ): 372 """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys""" 373 374 def _energy_gradient(variables, data): 375 def _etot(variables, inputs): 376 if "strain" in inputs: 377 scaling = inputs["strain"] 378 batch_index = data["batch_index"] 379 coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"] 380 coordinates = jax.vmap(jnp.matmul)( 381 coordinates, scaling[batch_index] 382 ) 383 inputs = {**inputs, "coordinates": coordinates} 384 if "cells" in inputs or "cells" in data: 385 cells = inputs["cells"] if "cells" in inputs else data["cells"] 386 cells = jax.vmap(jnp.matmul)(cells, scaling) 387 inputs["cells"] = cells 388 if "cells" in inputs: 389 reciprocal_cells = jnp.linalg.inv(inputs["cells"]) 390 inputs = {**inputs, "reciprocal_cells": reciprocal_cells} 391 energy, out = self._total_energy_raw(variables, {**data, **inputs}) 392 return energy.sum(), out 393 394 if "strain" in gradient_keys and "strain" not in data: 395 data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))} 396 inputs = {k: data[k] for k in gradient_keys} 397 de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs) 398 399 return ( 400 out["total_energy"], 401 de, 402 {**out, **{"dEd_" + k: de[k] for k in gradient_keys}}, 403 ) 404 405 if variables_as_input: 406 energy_gradient = _energy_gradient 407 else: 408 409 def energy_gradient(data): 410 return _energy_gradient(self.variables, data) 411 412 if jit: 413 return jax.jit(energy_gradient) 414 else: 415 return energy_gradient
Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys
417 def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]: 418 """apply preprocessing to the input data 419 420 !!! This is not a pure function => do not apply jax transforms !!!""" 421 if self.preproc_state is None: 422 out, _ = self.reinitialize_preprocessing(example_data=inputs) 423 elif use_gpu: 424 do_check_input = self.preproc_state.get("check_input", True) 425 if do_check_input: 426 inputs = check_input(inputs) 427 preproc_state, inputs = self.preprocessing.atom_padding( 428 self.preproc_state, inputs 429 ) 430 inputs = self.preprocessing.process(preproc_state, inputs) 431 preproc_state, state_up, out, overflow = ( 432 self.preprocessing.check_reallocate( 433 preproc_state, inputs 434 ) 435 ) 436 if verbose and overflow: 437 print("GPU preprocessing: nblist overflow => reallocating nblist") 438 print("size updates:", state_up) 439 else: 440 preproc_state, out = self.preprocessing(self.preproc_state, inputs) 441 442 object.__setattr__(self, "preproc_state", preproc_state) 443 return out
apply preprocessing to the input data
!!! This is not a pure function => do not apply jax transforms !!!
445 def reinitialize_preprocessing( 446 self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None 447 ) -> None: 448 ### TODO ### 449 if rng_key is None: 450 rng_key_pre = jax.random.PRNGKey(0) 451 else: 452 rng_key, rng_key_pre = jax.random.split(rng_key) 453 454 if example_data is None: 455 rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre) 456 example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10) 457 458 preproc_state, inputs = self.preprocessing.init_with_output(example_data) 459 object.__setattr__(self, "preproc_state", preproc_state) 460 return inputs, rng_key
476 def total_energy( 477 self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False,**inputs 478 ) -> Tuple[jnp.ndarray, Dict]: 479 """compute the total energy of the system 480 481 !!! This is not a pure function => do not apply jax transforms !!! 482 if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess 483 """ 484 if variables is None: 485 variables = self.variables 486 inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs) 487 # def print_shape(path,value): 488 # if isinstance(value,jnp.ndarray): 489 # print(path,value.shape) 490 # else: 491 # print(path,value) 492 # jax.tree_util.tree_map_with_path(print_shape,inputs) 493 _, output = self._total_energy(variables, inputs) 494 if self.use_atom_padding: 495 output = atom_unpadding(output) 496 e = output["total_energy"] 497 if unit is not None: 498 model_energy_unit = self.Ha_to_model_energy 499 if isinstance(unit, str): 500 unit = au.get_multiplier(unit) 501 e = e * (unit / model_energy_unit) 502 return e, output
compute the total energy of the system
!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
504 def energy_and_forces( 505 self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False,**inputs 506 ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]: 507 """compute the total energy and forces of the system 508 509 !!! This is not a pure function => do not apply jax transforms !!! 510 if you want to apply jax transforms, use self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess 511 """ 512 if variables is None: 513 variables = self.variables 514 inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs) 515 _, _, output = self._energy_and_forces(variables, inputs) 516 if self.use_atom_padding: 517 output = atom_unpadding(output) 518 e = output["total_energy"] 519 f = output["forces"] 520 if unit is not None: 521 model_energy_unit = self.Ha_to_model_energy 522 if isinstance(unit, str): 523 unit = au.get_multiplier(unit) 524 e = e * (unit / model_energy_unit) 525 f = f * (unit / model_energy_unit) 526 return e, f, output
compute the total energy and forces of the system
!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
528 def energy_and_forces_and_virial( 529 self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False, **inputs 530 ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]: 531 """compute the total energy and forces of the system 532 533 !!! This is not a pure function => do not apply jax transforms !!! 534 if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess 535 """ 536 if variables is None: 537 variables = self.variables 538 inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs) 539 _, _, _, output = self._energy_and_forces_and_virial(variables, inputs) 540 if self.use_atom_padding: 541 output = atom_unpadding(output) 542 e = output["total_energy"] 543 f = output["forces"] 544 v = output["virial_tensor"] 545 if unit is not None: 546 model_energy_unit = self.Ha_to_model_energy 547 if isinstance(unit, str): 548 unit = au.get_multiplier(unit) 549 e = e * (unit / model_energy_unit) 550 f = f * (unit / model_energy_unit) 551 v = v * (unit / model_energy_unit) 552 return e, f, v, output
compute the total energy and forces of the system
!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
554 def remove_atom_padding(self, output): 555 """remove atom padding from the output""" 556 return atom_unpadding(output)
remove atom padding from the output
558 def get_model(self) -> Tuple[FENNIXModules, Dict]: 559 """return the model and its variables""" 560 return self.modules, self.variables
return the model and its variables
562 def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]: 563 """return the preprocessing chain and its state""" 564 return self.preprocessing, self.preproc_state
return the preprocessing chain and its state
595 def generate_dummy_system( 596 self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10 597 ) -> Dict[str, Any]: 598 """ 599 Generate dummy system for initialization 600 """ 601 if box_size is None: 602 box_size = 2 * self.cutoff 603 for g in self._graphs_properties.values(): 604 cutoff = g["cutoff"] 605 if cutoff is not None: 606 box_size = min(box_size, 2 * g["cutoff"]) 607 coordinates = np.array( 608 jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64 609 ) 610 species = np.ones((n_atoms,), dtype=np.int32) 611 batch_index = np.zeros((n_atoms,), dtype=np.int32) 612 natoms = np.array([n_atoms], dtype=np.int32) 613 return { 614 "species": species, 615 "coordinates": coordinates, 616 # "graph": graph, 617 "batch_index": batch_index, 618 "natoms": natoms, 619 }
Generate dummy system for initialization
621 def summarize( 622 self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs 623 ) -> str: 624 """Summarize the model architecture and parameters""" 625 if rng_key is None: 626 head = "Summarizing with example data:\n" 627 rng_key = jax.random.PRNGKey(0) 628 if example_data is None: 629 head = "Summarizing with dummy 10 atoms system:\n" 630 rng_key, rng_key_sys = jax.random.split(rng_key) 631 example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10) 632 rng_key, rng_key_pre = jax.random.split(rng_key) 633 _, inputs = self.preprocessing.init_with_output(example_data) 634 return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
Summarize the model architecture and parameters
636 def to_dict(self,convert_numpy=False): 637 """return a dictionary representation of the model""" 638 if convert_numpy: 639 variables = convert_to_numpy(self.variables) 640 else: 641 variables = deepcopy(self.variables) 642 return { 643 **self._input_args, 644 "energy_terms": self.energy_terms, 645 "variables": variables, 646 }
return a dictionary representation of the model
648 def save(self, filename): 649 """save the model to a file""" 650 filename_str = str(filename) 651 do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle") 652 state_dict = self.to_dict(convert_numpy=do_pickle) 653 state_dict["preprocessing"] = [ 654 [k, v] for k, v in state_dict["preprocessing"].items() 655 ] 656 state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()] 657 if do_pickle: 658 import pickle 659 with open(filename, "wb") as f: 660 pickle.dump(state_dict, f) 661 else: 662 with open(filename, "wb") as f: 663 f.write(serialization.msgpack_serialize(state_dict))
save the model to a file
665 @classmethod 666 def load( 667 cls, 668 filename, 669 use_atom_padding=False, 670 graph_config={}, 671 ): 672 """load a model from a file""" 673 filename_str = str(filename) 674 do_pickle = filename_str.endswith(".pkl") or filename_str.endswith(".pickle") 675 if do_pickle: 676 import pickle 677 with open(filename, "rb") as f: 678 state_dict = pickle.load(f) 679 state_dict["variables"] = convert_to_jax(state_dict["variables"]) 680 else: 681 with open(filename, "rb") as f: 682 state_dict = serialization.msgpack_restore(f.read()) 683 state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]} 684 state_dict["modules"] = {k: v for k, v in state_dict["modules"]} 685 return cls( 686 **state_dict, 687 graph_config=graph_config, 688 use_atom_padding=use_atom_padding, 689 )
load a model from a file