# fennol.training.io — training I/O utilities (configuration, datasets, models, logging)
import os, io, sys
import numpy as np
from scipy.spatial.transform import Rotation
from collections import defaultdict
import pickle
import glob
from flax import traverse_util
from typing import Any, Dict, List, Tuple, Union, Optional, Callable, Sequence
from .databases import DBDataset, H5Dataset
from ..models.preprocessing import AtomPadding
import re
import json
import yaml

# TOML support is optional.
try:
    import tomlkit
except ImportError:
    tomlkit = None

try:
    from torch.utils.data import DataLoader
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )

from ..models import FENNIX


def load_configuration(config_file: str) -> Dict[str, Any]:
    """Load a training configuration file.

    Supported formats: '.json', '.yaml'/'.yml' and (when the optional
    `tomlkit` package is installed) '.toml'.

    Args:
        config_file: Path to the configuration file.

    Returns:
        The parsed configuration as a mapping.

    Raises:
        ValueError: If the file extension is not a supported format.
    """
    if config_file.endswith(".json"):
        with open(config_file) as f:
            parameters = json.load(f)
    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
        with open(config_file) as f:
            parameters = yaml.load(f, Loader=yaml.FullLoader)
    elif tomlkit is not None and config_file.endswith(".toml"):
        with open(config_file) as f:
            parameters = tomlkit.loads(f.read())
    else:
        supported_formats = [".json", ".yaml", ".yml"]
        if tomlkit is not None:
            supported_formats.append(".toml")
        raise ValueError(
            f"Unknown config file format. Supported formats: {supported_formats}"
        )
    return parameters


def load_dataset(
    dspath: str,
    batch_size: int,
    batch_size_val: Optional[int] = None,
    rename_refs: dict = {},
    infinite_iterator: bool = False,
    atom_padding: bool = False,
    ref_keys: Optional[Sequence[str]] = None,
    split_data_inputs: bool = False,
    np_rng: Optional[np.random.Generator] = None,
    train_val_split: bool = True,
    training_parameters: dict = {},
    add_flags: Sequence[str] = ["training"],
    fprec: str = "float32",
):
    """
    Load a dataset and return training/validation dataloaders (or iterators).

    Supported sources: sqlite ('.db'), HDF5 ('.h5'/'.hdf5'), pickle
    ('.pkl'/'.pickle'), or a directory of pickled shards
    ('training_*.pkl' plus 'validation.pkl').

    Args:
        dspath: Path to the dataset file or shard directory.
        batch_size: Number of systems per training batch.
        batch_size_val: Validation batch size (defaults to `batch_size`).
        rename_refs: Mapping {old_key: new_key} applied to each collated batch.
        infinite_iterator: If True, return endless batch generators instead of
            epoch-bounded dataloaders.
        atom_padding: If True, pad atoms/systems with an `AtomPadding` layer.
        ref_keys: Reference (target) keys to extract. If None, all keys found
            in each sample are kept.
        split_data_inputs: If True, each batch is returned as (inputs, refs);
            requires `ref_keys`.
        np_rng: Numpy random generator; required for noise, random rotations
            and shard shuffling.
        train_val_split: If True, return a (training, validation) pair.
        training_parameters: Extra options (pbc_training, minimum_image,
            random_rotation, rotated_keys, noise_sigma, flow_matching,
            additional_input_keys, shuffle_dataset, length_nopbc, ...).
        add_flags: Flag names added to each batch's 'flags' entry.
        fprec: Floating-point dtype name for collated arrays.

    Returns:
        (training, validation) dataloaders or iterators when
        `train_val_split` is True, otherwise a single one. Each sample must
        provide a "species" array; per-system arrays are concatenated along
        the first axis, with 'natoms' and 'batch_index' identifying systems.
    """

    assert isinstance(
        training_parameters, dict
    ), "training_parameters must be a dictionary."
    assert isinstance(
        rename_refs, dict
    ), "rename_refs must be a dictionary with the keys to rename."

    pbc_training = training_parameters.get("pbc_training", False)
    minimum_image = training_parameters.get("minimum_image", False)

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)

    # Keys that are always model inputs (as opposed to reference targets).
    input_keys = [
        "species",
        "coordinates",
        "natoms",
        "batch_index",
        "total_charge",
        "flags",
    ]
    if pbc_training:
        input_keys += ["cells", "reciprocal_cells"]
    if atom_padding:
        input_keys += ["true_atoms", "true_sys"]
    if coordinates_ref_key is not None:
        input_keys += ["system_index", "system_sign"]

    flags = {f: None for f in add_flags}
    if minimum_image and pbc_training:
        flags["minimum_image"] = None

    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
    additional_input_keys_ = set()
    for key in additional_input_keys:
        if key not in input_keys:
            additional_input_keys_.add(key)
    additional_input_keys = additional_input_keys_

    all_inputs = set(input_keys + list(additional_input_keys))

    extract_all_keys = ref_keys is None
    if ref_keys is not None:
        ref_keys = set(ref_keys)
        # ref_keys_ = reference keys that must be read from the raw samples
        # (i.e. that are not already produced as inputs).
        ref_keys_ = set()
        for key in ref_keys:
            if key not in all_inputs:
                ref_keys_.add(key)

    random_rotation = training_parameters.get("random_rotation", False)
    if random_rotation:
        assert np_rng is not None, "np_rng must be provided for random rotations."

        # Rotation rules: 1 = row vectors (x @ R), -1 = column vectors,
        # 2 = rank-2 tensors (R^T x R), -2 = symmetric tensor stored as
        # 6 components (xx, yy, zz, xy, xz, yz).
        apply_rotation = {
            1: lambda x, r: x @ r,
            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
        }

        def rotate_2f(x, r):
            assert x.shape[-1] == 6
            # expand the 6 components (xx,yy,zz,xy,xz,yz) to the full 3x3 tensor
            indices = np.array([0, 3, 4, 3, 1, 5, 4, 5, 2])
            x = x[..., indices].reshape(*x.shape[:-1], 3, 3)
            x = np.einsum("li,...lk,kj->...ij", r, x, r)
            # select back the 6 independent components
            indices = np.array([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]])
            x = x[..., indices[:, 0], indices[:, 1]]
            return x

        apply_rotation[-2] = rotate_2f

        valid_rotations = tuple(apply_rotation.keys())
        rotated_keys = {
            "coordinates": 1,
            "forces": 1,
            "virial_tensor": 2,
            "stress_tensor": 2,
            "virial": 2,
            "stress": 2,
        }
        if pbc_training:
            rotated_keys["cells"] = 1
        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
        for k, v in user_rotated_keys.items():
            assert (
                v in valid_rotations
            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
            rotated_keys[k] = v

        print(
            "Applying random rotations to the following keys if present:",
            list(rotated_keys.keys()),
        )

        def apply_random_rotations(output, nbatch):
            # One independent random rotation per system in the batch.
            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
            r = [
                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
                for i in range(nbatch)
            ]
            for k, l in rotated_keys.items():
                if k in output:
                    for i in range(nbatch):
                        output[k][i] = apply_rotation[l](output[k][i], r[i])

    else:

        def apply_random_rotations(output, nbatch):
            pass

    flow_matching = training_parameters.get("flow_matching", False)
    if flow_matching:
        if ref_keys is not None:
            ref_keys.add("flow_matching_target")
            # The target is generated at collate time, never read from samples.
            if "flow_matching_target" in ref_keys_:
                ref_keys_.remove("flow_matching_target")
        all_inputs.add("flow_matching_time")

        def add_flow_matching(output, nbatch):
            # Linear interpolation between a Gaussian sample x0 and the
            # centered reference coordinates x1 at a random time t.
            ts = np_rng.uniform(0.0, 1.0, (nbatch,))
            targets = []
            for i in range(nbatch):
                x1 = output["coordinates"][i]
                com = x1.mean(axis=0, keepdims=True)
                x1 = x1 - com
                x0 = np_rng.normal(0.0, 1.0, x1.shape)
                xt = (1 - ts[i]) * x0 + ts[i] * x1
                output["coordinates"][i] = xt
                targets.append(x1 - x0)
            output["flow_matching_target"] = targets
            output["flow_matching_time"] = [np.array(t) for t in ts]

    else:

        def add_flow_matching(output, nbatch):
            pass

    if pbc_training:
        print("Periodic boundary conditions are active.")
        # Fallback cubic cell edge for samples that do not provide a cell.
        length_nopbc = training_parameters.get("length_nopbc", 1000.0)

        def add_cell(d, output):
            if "cell" not in d:
                cell = np.asarray(
                    [
                        [length_nopbc, 0.0, 0.0],
                        [0.0, length_nopbc, 0.0],
                        [0.0, 0.0, length_nopbc],
                    ],
                    dtype=fprec,
                )
            else:
                cell = np.asarray(d["cell"], dtype=fprec)
            output["cells"].append(cell.reshape(1, 3, 3))

    else:

        def add_cell(d, output):
            if "cell" in d:
                print(
                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
                )

    if extract_all_keys:

        def add_other_keys(d, output, atom_shift):
            for k, v in d.items():
                if k in ("cell", "total_charge"):
                    continue
                v_array = np.array(v)
                # Shift atom indices for batched numbering; the mask keeps
                # non-positive (padding) indices untouched.
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift * mask
                output[k].append(v_array)

    else:

        def add_other_keys(d, output, atom_shift):
            output["species"].append(np.asarray(d["species"]))
            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
            for k in additional_input_keys:
                v_array = np.array(d[k])
                # Shift atom indices for batched numbering.
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift * mask
                output[k].append(v_array)
            for k in ref_keys_:
                v_array = np.array(d[k])
                # Shift atom indices for batched numbering.
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift * mask
                output[k].append(v_array)
                if k + "_mask" in d:
                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))

    def add_keys(d, output, atom_shift, batch_index):
        # Append one system to the collated output; returns updated atom shift.
        nat = d["species"].shape[0]

        output["natoms"].append(np.asarray([nat]))
        output["batch_index"].append(np.asarray([batch_index] * nat))
        if "total_charge" not in d:
            total_charge = np.asarray(0.0, dtype=fprec)
        else:
            total_charge = np.asarray(d["total_charge"], dtype=fprec)
        output["total_charge"].append(total_charge)

        add_cell(d, output)
        add_other_keys(d, output, atom_shift)

        return atom_shift + nat

    def collate_fn_(batch):
        output = defaultdict(list)
        atom_shift = 0
        batch_index = 0

        for d in batch:
            atom_shift = add_keys(d, output, atom_shift, batch_index)
            batch_index += 1

            if coordinates_ref_key is not None:
                output["system_index"].append(np.asarray([batch_index - 1]))
                output["system_sign"].append(np.asarray([1]))
                if coordinates_ref_key in d:
                    # Add a mirror system with reference coordinates; it keeps
                    # the original system's index with a negative sign.
                    dref = {**d, "coordinates": d[coordinates_ref_key]}
                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
                    output["system_index"].append(np.asarray([batch_index - 1]))
                    output["system_sign"].append(np.asarray([-1]))
                    batch_index += 1

        nbatch_ = len(output["natoms"])
        apply_random_rotations(output, nbatch_)
        add_flow_matching(output, nbatch_)

        # Stack scalars, concatenate arrays along the first axis; cast floats.
        for k, v in output.items():
            if v[0].ndim == 0:
                v = np.stack(v)
            else:
                v = np.concatenate(v, axis=0)
            if np.issubdtype(v.dtype, np.floating):
                v = v.astype(fprec)
            output[k] = v

        if "cells" in output and pbc_training:
            output["reciprocal_cells"] = np.linalg.inv(output["cells"])

        # Rename requested keys.
        for kold, knew in rename_refs.items():
            assert (
                knew not in output
            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
            if kold in output:
                output[knew] = output.pop(kold)

        output["flags"] = flags
        return output

    collate_layers_train = [collate_fn_]
    collate_layers_valid = [collate_fn_]

    ### collate preprocessing
    # add noise to the training data
    noise_sigma = training_parameters.get("noise_sigma", None)
    if noise_sigma is not None:
        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"

        for sigma in noise_sigma.values():
            assert sigma >= 0, "Noise sigma should be a non-negative number"

        print("Adding noise to the training data:")
        for key, sigma in noise_sigma.items():
            print(f" - {key} with sigma = {sigma}")

        assert np_rng is not None, "np_rng must be provided for adding noise."

        def collate_with_noise(batch):
            for key, sigma in noise_sigma.items():
                if key in batch and sigma > 0:
                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
                        batch[key].dtype
                    )
            return batch

        collate_layers_train.append(collate_with_noise)

    if atom_padding:
        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
        padder_state = padder.init()

        def collate_with_padding(batch):
            padder_state_up, output = padder(padder_state, batch)
            padder_state.update(padder_state_up)
            return output

        collate_layers_train.append(collate_with_padding)
        collate_layers_valid.append(collate_with_padding)

    if split_data_inputs:
        assert (
            ref_keys is not None
        ), "ref_keys must be provided when split_data_inputs is True."
        print("Input keys:", all_inputs)
        print("Ref keys:", ref_keys)

        def collate_split(batch):
            inputs = {}
            refs = {}
            for k, v in batch.items():
                if k in all_inputs:
                    inputs[k] = v
                if k in ref_keys:
                    refs[k] = v
                if k.endswith("_mask") and k[:-5] in ref_keys:
                    refs[k] = v
            return inputs, refs

        collate_layers_train.append(collate_split)
        collate_layers_valid.append(collate_split)

    ### apply all collate preprocessing
    if len(collate_layers_train) == 1:
        collate_fn_train = collate_layers_train[0]
    else:

        def collate_fn_train(batch):
            for layer in collate_layers_train:
                batch = layer(batch)
            return batch

    if len(collate_layers_valid) == 1:
        collate_fn_valid = collate_layers_valid[0]
    else:

        def collate_fn_valid(batch):
            for layer in collate_layers_valid:
                batch = layer(batch)
            return batch

    if not os.path.exists(dspath):
        raise ValueError(f"Dataset file '{dspath}' not found.")
    print(f"Loading dataset from {dspath}...", end="")
    sharded_training = False
    if dspath.endswith(".db"):
        dataset = {}
        if train_val_split:
            dataset["training"] = DBDataset(dspath, table="training")
            dataset["validation"] = DBDataset(dspath, table="validation")
        else:
            dataset = DBDataset(dspath)
    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
        dataset = {}
        if train_val_split:
            dataset["training"] = H5Dataset(dspath, table="training")
            dataset["validation"] = H5Dataset(dspath, table="validation")
        else:
            dataset = H5Dataset(dspath)
    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
        with open(dspath, "rb") as f:
            dataset = pickle.load(f)
        if not train_val_split and isinstance(dataset, dict):
            dataset = dataset["training"]
    elif os.path.isdir(dspath):
        if train_val_split:
            dataset = {}
            with open(dspath + "/validation.pkl", "rb") as f:
                dataset["validation"] = pickle.load(f)
        else:
            dataset = None

        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
        nshards = len(shard_files)
        if nshards == 0:
            raise ValueError("No dataset shards found.")
        elif nshards == 1:
            with open(shard_files[0], "rb") as f:
                if train_val_split:
                    dataset["training"] = pickle.load(f)
                else:
                    dataset = pickle.load(f)
        else:
            print(f"Found {nshards} dataset shards.")
            sharded_training = True

    else:
        raise ValueError(
            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
        )
    print(" done.")

    ### BUILD DATALOADERS
    shuffle = training_parameters.get("shuffle_dataset", True)
    if train_val_split:
        if batch_size_val is None:
            batch_size_val = batch_size
        dataloader_validation = DataLoader(
            dataset["validation"],
            batch_size=batch_size_val,
            shuffle=shuffle,
            collate_fn=collate_fn_valid,
        )

    if sharded_training:

        def iterate_sharded_dataset():
            # One epoch = all shards, each loaded lazily and iterated in turn.
            indices = np.arange(nshards)
            if shuffle:
                assert np_rng is not None, "np_rng must be provided for shuffling."
                np_rng.shuffle(indices)
            for i in indices:
                filename = shard_files[i]
                print(f"# Loading dataset shard from {filename}...", end="")
                with open(filename, "rb") as f:
                    dataset = pickle.load(f)
                print(" done.")
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    collate_fn=collate_fn_train,
                )
                for batch in dataloader:
                    yield batch

        class DataLoaderSharded:
            def __iter__(self):
                return iterate_sharded_dataset()

        dataloader_training = DataLoaderSharded()
    else:
        dataloader_training = DataLoader(
            dataset["training"] if train_val_split else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_train,
        )

    if not infinite_iterator:
        if train_val_split:
            return dataloader_training, dataloader_validation
        return dataloader_training

    def next_batch_factory(dataloader):
        while True:
            for batch in dataloader:
                yield batch

    training_iterator = next_batch_factory(dataloader_training)
    if train_val_split:
        validation_iterator = next_batch_factory(dataloader_validation)
        return training_iterator, validation_iterator
    return training_iterator


def load_model(
    parameters: Dict[str, Any],
    model_file: Optional[str] = None,
    rng_key: Optional[str] = None,
) -> FENNIX:
    """
    Load a FENNIX model from a file or create a new one.

    Args:
        parameters: A dictionary of parameters for the model.
        model_file: The path to a saved model file to load.
        rng_key: RNG key used to initialize a new model; required when no
            model file is available.

    Returns:
        FENNIX: A FENNIX model object.
    """
    print_model = parameters["training"].get("print_model", False)
    if model_file is None:
        model_file = parameters.get("model_file", None)
    if model_file is not None and os.path.exists(model_file):
        model = FENNIX.load(model_file, use_atom_padding=False)
        if print_model:
            print(model.summarize())
        print(f"Restored model from '{model_file}'.")
    else:
        assert (
            rng_key is not None
        ), "rng_key must be specified if model_file is not provided."
        model_params = parameters["model"]
        if isinstance(model_params, str):
            # 'model' may itself point to a saved model file.
            assert os.path.exists(
                model_params
            ), f"Model file '{model_params}' not found."
            model = FENNIX.load(model_params, use_atom_padding=False)
            print(f"Restored model from '{model_params}'.")
        else:
            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
            if print_model:
                print(model.summarize())
    return model


def copy_parameters(variables, variables_ref, params=[".*"]):
    """Copy parameters from `variables_ref` into `variables`.

    A leaf is taken from `variables_ref` when its '/'-joined path (without
    the first segment, lowercased) matches one of the `params` regexes and
    the same path exists in `variables_ref`; otherwise the original value
    is kept.
    """

    def merge_params(full_path_, v, v_ref):
        full_path = "/".join(full_path_[1:]).lower()
        for path in params:
            if re.match(path.lower(), full_path):
                return v_ref
        return v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )


class TeeLogger(object):
    """Write-through logger that duplicates writes to a file and, once bound
    via `bind_stdout`, to the original `sys.stdout`."""

    def __init__(self, name):
        # write_through so log lines hit the file immediately.
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
        self.stdout = None

    def __del__(self):
        self.close()

    def write(self, data):
        self.file.write(data)
        # Only mirror to stdout once bound; before bind_stdout() the logger
        # is file-only instead of raising AttributeError on None.
        if self.stdout is not None:
            self.stdout.write(data)
        self.flush()

    def close(self):
        self.file.close()

    def flush(self):
        self.file.flush()

    def bind_stdout(self):
        """Replace sys.stdout with this logger (mirroring to the original)."""
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        """Restore the original sys.stdout."""
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout
def load_configuration(config_file: str) -> Dict[str, any]:
    """Load a training configuration file.

    Supported formats: '.json', '.yaml'/'.yml' and (when the optional
    `tomlkit` package is installed) '.toml'.

    Args:
        config_file: Path to the configuration file.

    Returns:
        The parsed configuration as a mapping.

    Raises:
        ValueError: If the file extension is not a supported format.
    """
    # Use context managers so file handles are closed deterministically
    # (the original left every opened handle to the garbage collector).
    if config_file.endswith(".json"):
        with open(config_file) as f:
            parameters = json.load(f)
    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
        with open(config_file) as f:
            parameters = yaml.load(f, Loader=yaml.FullLoader)
    elif tomlkit is not None and config_file.endswith(".toml"):
        with open(config_file) as f:
            parameters = tomlkit.loads(f.read())
    else:
        supported_formats = [".json", ".yaml", ".yml"]
        if tomlkit is not None:
            supported_formats.append(".toml")
        raise ValueError(
            f"Unknown config file format. Supported formats: {supported_formats}"
        )
    return parameters
def load_dataset(
    dspath: str,
    batch_size: int,
    batch_size_val: Optional[int] = None,
    rename_refs: dict = {},
    infinite_iterator: bool = False,
    atom_padding: bool = False,
    ref_keys: Optional[Sequence[str]] = None,
    split_data_inputs: bool = False,
    np_rng: Optional[np.random.Generator] = None,
    train_val_split: bool = True,
    training_parameters: dict = {},
    add_flags: Sequence[str] = ["training"],
    fprec: str = "float32",
):
    """
    Load a dataset and return training/validation dataloaders (or iterators).

    Supported sources: sqlite ('.db'), HDF5 ('.h5'/'.hdf5'), pickle
    ('.pkl'/'.pickle'), or a directory of pickled shards
    ('training_*.pkl' plus 'validation.pkl').

    Args:
        dspath: Path to the dataset file or shard directory.
        batch_size: Number of systems per training batch.
        batch_size_val: Validation batch size (defaults to `batch_size`).
        rename_refs: Mapping {old_key: new_key} applied to each collated batch.
        infinite_iterator: If True, return endless batch generators instead of
            epoch-bounded dataloaders.
        atom_padding: If True, pad atoms/systems with an `AtomPadding` layer.
        ref_keys: Reference (target) keys to extract. If None, all keys found
            in each sample are kept.
        split_data_inputs: If True, each batch is returned as (inputs, refs);
            requires `ref_keys`.
        np_rng: Numpy random generator; required for noise, random rotations
            and shard shuffling.
        train_val_split: If True, return a (training, validation) pair.
        training_parameters: Extra options (pbc_training, minimum_image,
            random_rotation, rotated_keys, noise_sigma, flow_matching,
            additional_input_keys, shuffle_dataset, length_nopbc, ...).
        add_flags: Flag names added to each batch's 'flags' entry.
        fprec: Floating-point dtype name for collated arrays.

    Returns:
        (training, validation) dataloaders or iterators when
        `train_val_split` is True, otherwise a single one. Each sample must
        provide a "species" array; per-system arrays are concatenated along
        the first axis, with 'natoms' and 'batch_index' identifying systems.
    """

    assert isinstance(
        training_parameters, dict
    ), "training_parameters must be a dictionary."
    assert isinstance(
        rename_refs, dict
    ), "rename_refs must be a dictionary with the keys to rename."

    pbc_training = training_parameters.get("pbc_training", False)
    minimum_image = training_parameters.get("minimum_image", False)

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)

    # Keys that are always model inputs (as opposed to reference targets).
    input_keys = [
        "species",
        "coordinates",
        "natoms",
        "batch_index",
        "total_charge",
        "flags",
    ]
    if pbc_training:
        input_keys += ["cells", "reciprocal_cells"]
    if atom_padding:
        input_keys += ["true_atoms", "true_sys"]
    if coordinates_ref_key is not None:
        input_keys += ["system_index", "system_sign"]

    flags = {f: None for f in add_flags}
    if minimum_image and pbc_training:
        flags["minimum_image"] = None

    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
    additional_input_keys_ = set()
    for key in additional_input_keys:
        if key not in input_keys:
            additional_input_keys_.add(key)
    additional_input_keys = additional_input_keys_

    all_inputs = set(input_keys + list(additional_input_keys))

    extract_all_keys = ref_keys is None
    if ref_keys is not None:
        ref_keys = set(ref_keys)
        # ref_keys_ = reference keys that must be read from the raw samples
        # (i.e. that are not already produced as inputs).
        ref_keys_ = set()
        for key in ref_keys:
            if key not in all_inputs:
                ref_keys_.add(key)

    random_rotation = training_parameters.get("random_rotation", False)
    if random_rotation:
        assert np_rng is not None, "np_rng must be provided for random rotations."

        # Rotation rules: 1 = row vectors (x @ R), -1 = column vectors,
        # 2 = rank-2 tensors (R^T x R), -2 = symmetric tensor stored as
        # 6 components (xx, yy, zz, xy, xz, yz).
        apply_rotation = {
            1: lambda x, r: x @ r,
            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
        }

        def rotate_2f(x, r):
            assert x.shape[-1] == 6
            # expand the 6 components (xx,yy,zz,xy,xz,yz) to the full 3x3 tensor
            indices = np.array([0, 3, 4, 3, 1, 5, 4, 5, 2])
            x = x[..., indices].reshape(*x.shape[:-1], 3, 3)
            x = np.einsum("li,...lk,kj->...ij", r, x, r)
            # select back the 6 independent components
            indices = np.array([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]])
            x = x[..., indices[:, 0], indices[:, 1]]
            return x

        apply_rotation[-2] = rotate_2f

        valid_rotations = tuple(apply_rotation.keys())
        rotated_keys = {
            "coordinates": 1,
            "forces": 1,
            "virial_tensor": 2,
            "stress_tensor": 2,
            "virial": 2,
            "stress": 2,
        }
        if pbc_training:
            rotated_keys["cells"] = 1
        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
        for k, v in user_rotated_keys.items():
            assert (
                v in valid_rotations
            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
            rotated_keys[k] = v

        print(
            "Applying random rotations to the following keys if present:",
            list(rotated_keys.keys()),
        )

        def apply_random_rotations(output, nbatch):
            # One independent random rotation per system in the batch.
            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
            r = [
                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
                for i in range(nbatch)
            ]
            for k, l in rotated_keys.items():
                if k in output:
                    for i in range(nbatch):
                        output[k][i] = apply_rotation[l](output[k][i], r[i])

    else:

        def apply_random_rotations(output, nbatch):
            pass

    flow_matching = training_parameters.get("flow_matching", False)
    if flow_matching:
        if ref_keys is not None:
            ref_keys.add("flow_matching_target")
            # The target is generated at collate time, never read from samples.
            if "flow_matching_target" in ref_keys_:
                ref_keys_.remove("flow_matching_target")
        all_inputs.add("flow_matching_time")

        def add_flow_matching(output, nbatch):
            # Linear interpolation between a Gaussian sample x0 and the
            # centered reference coordinates x1 at a random time t.
            ts = np_rng.uniform(0.0, 1.0, (nbatch,))
            targets = []
            for i in range(nbatch):
                x1 = output["coordinates"][i]
                com = x1.mean(axis=0, keepdims=True)
                x1 = x1 - com
                x0 = np_rng.normal(0.0, 1.0, x1.shape)
                xt = (1 - ts[i]) * x0 + ts[i] * x1
                output["coordinates"][i] = xt
                targets.append(x1 - x0)
            output["flow_matching_target"] = targets
            output["flow_matching_time"] = [np.array(t) for t in ts]

    else:

        def add_flow_matching(output, nbatch):
            pass

    if pbc_training:
        print("Periodic boundary conditions are active.")
        # Fallback cubic cell edge for samples that do not provide a cell.
        length_nopbc = training_parameters.get("length_nopbc", 1000.0)

        def add_cell(d, output):
            if "cell" not in d:
                cell = np.asarray(
                    [
                        [length_nopbc, 0.0, 0.0],
                        [0.0, length_nopbc, 0.0],
                        [0.0, 0.0, length_nopbc],
                    ],
                    dtype=fprec,
                )
            else:
                cell = np.asarray(d["cell"], dtype=fprec)
            output["cells"].append(cell.reshape(1, 3, 3))

    else:

        def add_cell(d, output):
            if "cell" in d:
                print(
                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
                )

    if extract_all_keys:

        def add_other_keys(d, output, atom_shift):
            for k, v in d.items():
                if k in ("cell", "total_charge"):
                    continue
                v_array = np.array(v)
                # Shift atom indices for batched numbering; the mask keeps
                # non-positive (padding) indices untouched.
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift * mask
                output[k].append(v_array)

    else:

        def add_other_keys(d, output, atom_shift):
            output["species"].append(np.asarray(d["species"]))
            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
            for k in additional_input_keys:
                v_array = np.array(d[k])
                # Shift atom indices for batched numbering.
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift * mask
                output[k].append(v_array)
            for k in ref_keys_:
                v_array = np.array(d[k])
                # Shift atom indices for batched numbering.
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift * mask
                output[k].append(v_array)
                if k + "_mask" in d:
                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))

    def add_keys(d, output, atom_shift, batch_index):
        # Append one system to the collated output; returns updated atom shift.
        nat = d["species"].shape[0]

        output["natoms"].append(np.asarray([nat]))
        output["batch_index"].append(np.asarray([batch_index] * nat))
        if "total_charge" not in d:
            total_charge = np.asarray(0.0, dtype=fprec)
        else:
            total_charge = np.asarray(d["total_charge"], dtype=fprec)
        output["total_charge"].append(total_charge)

        add_cell(d, output)
        add_other_keys(d, output, atom_shift)

        return atom_shift + nat

    def collate_fn_(batch):
        output = defaultdict(list)
        atom_shift = 0
        batch_index = 0

        for d in batch:
            atom_shift = add_keys(d, output, atom_shift, batch_index)
            batch_index += 1

            if coordinates_ref_key is not None:
                output["system_index"].append(np.asarray([batch_index - 1]))
                output["system_sign"].append(np.asarray([1]))
                if coordinates_ref_key in d:
                    # Add a mirror system with reference coordinates; it keeps
                    # the original system's index with a negative sign.
                    dref = {**d, "coordinates": d[coordinates_ref_key]}
                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
                    output["system_index"].append(np.asarray([batch_index - 1]))
                    output["system_sign"].append(np.asarray([-1]))
                    batch_index += 1

        nbatch_ = len(output["natoms"])
        apply_random_rotations(output, nbatch_)
        add_flow_matching(output, nbatch_)

        # Stack scalars, concatenate arrays along the first axis; cast floats.
        for k, v in output.items():
            if v[0].ndim == 0:
                v = np.stack(v)
            else:
                v = np.concatenate(v, axis=0)
            if np.issubdtype(v.dtype, np.floating):
                v = v.astype(fprec)
            output[k] = v

        if "cells" in output and pbc_training:
            output["reciprocal_cells"] = np.linalg.inv(output["cells"])

        # Rename requested keys.
        for kold, knew in rename_refs.items():
            assert (
                knew not in output
            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
            if kold in output:
                output[knew] = output.pop(kold)

        output["flags"] = flags
        return output

    collate_layers_train = [collate_fn_]
    collate_layers_valid = [collate_fn_]

    ### collate preprocessing
    # add noise to the training data
    noise_sigma = training_parameters.get("noise_sigma", None)
    if noise_sigma is not None:
        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"

        for sigma in noise_sigma.values():
            assert sigma >= 0, "Noise sigma should be a non-negative number"

        print("Adding noise to the training data:")
        for key, sigma in noise_sigma.items():
            print(f" - {key} with sigma = {sigma}")

        assert np_rng is not None, "np_rng must be provided for adding noise."

        def collate_with_noise(batch):
            for key, sigma in noise_sigma.items():
                if key in batch and sigma > 0:
                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
                        batch[key].dtype
                    )
            return batch

        collate_layers_train.append(collate_with_noise)

    if atom_padding:
        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
        padder_state = padder.init()

        def collate_with_padding(batch):
            padder_state_up, output = padder(padder_state, batch)
            padder_state.update(padder_state_up)
            return output

        collate_layers_train.append(collate_with_padding)
        collate_layers_valid.append(collate_with_padding)

    if split_data_inputs:
        assert (
            ref_keys is not None
        ), "ref_keys must be provided when split_data_inputs is True."
        print("Input keys:", all_inputs)
        print("Ref keys:", ref_keys)

        def collate_split(batch):
            inputs = {}
            refs = {}
            for k, v in batch.items():
                if k in all_inputs:
                    inputs[k] = v
                if k in ref_keys:
                    refs[k] = v
                if k.endswith("_mask") and k[:-5] in ref_keys:
                    refs[k] = v
            return inputs, refs

        collate_layers_train.append(collate_split)
        collate_layers_valid.append(collate_split)

    ### apply all collate preprocessing
    if len(collate_layers_train) == 1:
        collate_fn_train = collate_layers_train[0]
    else:

        def collate_fn_train(batch):
            for layer in collate_layers_train:
                batch = layer(batch)
            return batch

    if len(collate_layers_valid) == 1:
        collate_fn_valid = collate_layers_valid[0]
    else:

        def collate_fn_valid(batch):
            for layer in collate_layers_valid:
                batch = layer(batch)
            return batch

    if not os.path.exists(dspath):
        raise ValueError(f"Dataset file '{dspath}' not found.")
    print(f"Loading dataset from {dspath}...", end="")
    sharded_training = False
    if dspath.endswith(".db"):
        dataset = {}
        if train_val_split:
            dataset["training"] = DBDataset(dspath, table="training")
            dataset["validation"] = DBDataset(dspath, table="validation")
        else:
            dataset = DBDataset(dspath)
    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
        dataset = {}
        if train_val_split:
            dataset["training"] = H5Dataset(dspath, table="training")
            dataset["validation"] = H5Dataset(dspath, table="validation")
        else:
            dataset = H5Dataset(dspath)
    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
        with open(dspath, "rb") as f:
            dataset = pickle.load(f)
        if not train_val_split and isinstance(dataset, dict):
            dataset = dataset["training"]
    elif os.path.isdir(dspath):
        if train_val_split:
            dataset = {}
            with open(dspath + "/validation.pkl", "rb") as f:
                dataset["validation"] = pickle.load(f)
        else:
            dataset = None

        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
        nshards = len(shard_files)
        if nshards == 0:
            raise ValueError("No dataset shards found.")
        elif nshards == 1:
            with open(shard_files[0], "rb") as f:
                if train_val_split:
                    dataset["training"] = pickle.load(f)
                else:
                    dataset = pickle.load(f)
        else:
            print(f"Found {nshards} dataset shards.")
            sharded_training = True

    else:
        raise ValueError(
            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
        )
    print(" done.")

    ### BUILD DATALOADERS
    shuffle = training_parameters.get("shuffle_dataset", True)
    if train_val_split:
        if batch_size_val is None:
            batch_size_val = batch_size
        dataloader_validation = DataLoader(
            dataset["validation"],
            batch_size=batch_size_val,
            shuffle=shuffle,
            collate_fn=collate_fn_valid,
        )

    if sharded_training:

        def iterate_sharded_dataset():
            # One epoch = all shards, each loaded lazily and iterated in turn.
            indices = np.arange(nshards)
            if shuffle:
                assert np_rng is not None, "np_rng must be provided for shuffling."
                np_rng.shuffle(indices)
            for i in indices:
                filename = shard_files[i]
                print(f"# Loading dataset shard from {filename}...", end="")
                with open(filename, "rb") as f:
                    dataset = pickle.load(f)
                print(" done.")
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    collate_fn=collate_fn_train,
                )
                for batch in dataloader:
                    yield batch

        class DataLoaderSharded:
            def __iter__(self):
                return iterate_sharded_dataset()

        dataloader_training = DataLoaderSharded()
    else:
        dataloader_training = DataLoader(
            dataset["training"] if train_val_split else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_train,
        )

    if not infinite_iterator:
        if train_val_split:
            return dataloader_training, dataloader_validation
        return dataloader_training

    def next_batch_factory(dataloader):
        while True:
            for batch in dataloader:
                yield batch

    training_iterator = next_batch_factory(dataloader_training)
    if train_val_split:
        validation_iterator = next_batch_factory(dataloader_validation)
        return training_iterator, validation_iterator
    return training_iterator
Load a dataset from a pickle file and return two iterators for training and validation batches.
Args: training_parameters (dict): A dictionary with the following keys: - 'dspath': str. Path to the pickle file containing the dataset. - 'batch_size': int. Number of samples per batch. rename_refs (dict, optional): A dictionary mapping old reference property names to new names. Default is an empty dict.
Returns: tuple: A tuple of two infinite iterators, one for training batches and one for validation batches. For each element in the batch, we expect a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems: - 'natoms': np.ndarray. Array with the number of atoms in each system. - 'batch_index': np.ndarray. Array with the index of the system to which each atom belongs. If the keys "forces", "total_energy", "atomic_energies" or any of the elements in rename_refs are present, the keys are renamed by prepending "true_" to the key name.
def load_model(
    parameters: Dict[str, any],
    model_file: Optional[str] = None,
    rng_key: Optional[str] = None,
) -> FENNIX:
    """Restore a FENNIX model from disk, or build a fresh one from config.

    Args:
        parameters: Full configuration dictionary. Reads
            ``parameters["training"]["print_model"]`` (verbosity flag),
            ``parameters["model_file"]`` (checkpoint fallback) and
            ``parameters["model"]`` (constructor kwargs, or a path to a
            saved model).
        model_file: Path to a saved model; takes precedence over
            ``parameters["model_file"]`` when provided.
        rng_key: Key forwarded to the FENNIX constructor. Required only
            when a brand-new model must be created. (NOTE(review):
            annotated as ``str`` but presumably a JAX PRNG key — kept
            as-is to preserve the interface.)

    Returns:
        FENNIX: The restored or newly constructed model.
    """
    verbose = parameters["training"].get("print_model", False)

    # Fall back to the config-supplied checkpoint path.
    if model_file is None:
        model_file = parameters.get("model_file", None)

    if model_file is not None and os.path.exists(model_file):
        # A checkpoint exists on disk: restore it directly.
        model = FENNIX.load(model_file, use_atom_padding=False)
        if verbose:
            print(model.summarize())
        print(f"Restored model from '{model_file}'.")
        return model

    # No usable checkpoint: build from the 'model' section of the config.
    assert (
        rng_key is not None
    ), "rng_key must be specified if model_file is not provided."
    model_params = parameters["model"]
    if isinstance(model_params, str):
        # The 'model' entry may itself point at a saved model file.
        assert os.path.exists(
            model_params
        ), f"Model file '{model_params}' not found."
        model = FENNIX.load(model_params, use_atom_padding=False)
        print(f"Restored model from '{model_params}'.")
    else:
        model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
    if verbose:
        print(model.summarize())
    return model
Load a FENNIX model from a file or create a new one.
Args: parameters (dict): A dictionary of parameters for the model. model_file (str, optional): The path to a saved model file to load.
Returns: FENNIX: A FENNIX model object.
def copy_parameters(variables, variables_ref, params=(".*",)):
    """Selectively copy leaves from a reference parameter pytree.

    Flattens both pytrees; for every leaf of ``variables`` that also exists
    in ``variables_ref``, the value is taken from ``variables_ref`` when the
    leaf's '/'-joined path (components after the top-level collection name,
    lowercased) matches one of the regex ``params`` case-insensitively.
    All other leaves keep their value from ``variables``.

    Args:
        variables: Target parameter pytree (flax-style nested dict).
        variables_ref: Reference pytree providing replacement values.
        params: Sequence of regex patterns selecting which leaves to copy.
            Defaults to copying everything present in ``variables_ref``.

    Returns:
        A pytree with the same structure as ``variables``.
    """

    def merge_params(full_path_, v, v_ref):
        # Drop the leading collection name (e.g. 'params') and join the
        # remaining path components as 'layer/sublayer/kernel'.
        full_path = "/".join(full_path_[1:]).lower()
        for path in params:
            # Match case-insensitively with a flag instead of lowercasing
            # the pattern: lowering the pattern corrupts regex escapes
            # (e.g. \D would become \d, inverting its meaning).
            if re.match(path, full_path, flags=re.IGNORECASE):
                return v_ref
        return v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )
class TeeLogger(object):
    """Tee that duplicates writes into a log file and the real stdout.

    Typical usage::

        logger = TeeLogger("train.log")
        logger.bind_stdout()    # sys.stdout now flows through the tee
        ...
        logger.unbind_stdout()
        logger.close()
    """

    def __init__(self, name):
        # write_through=True pushes text straight to the underlying binary
        # file instead of buffering it in the text layer.
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
        # Real stdout, captured by bind_stdout(); None while unbound.
        self.stdout = None

    def __del__(self):
        # __init__ may have failed before self.file existed, and attribute
        # machinery can be partially torn down at interpreter shutdown, so
        # never let destruction raise.
        try:
            self.close()
        except Exception:
            pass

    def write(self, data):
        """Write ``data`` to the log file and, when bound, to stdout."""
        self.file.write(data)
        # Guard: writing before bind_stdout() previously raised
        # AttributeError on self.stdout (None). Skip the echo instead.
        if self.stdout is not None:
            self.stdout.write(data)
        self.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        if not self.file.closed:
            self.file.close()

    def flush(self):
        """Flush the log file; a no-op once the file is closed."""
        if not self.file.closed:
            self.file.flush()

    def bind_stdout(self):
        """Replace sys.stdout with this tee, remembering the original."""
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        """Restore the original sys.stdout saved by bind_stdout()."""
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout