fennol.training.io

  1import os, io, sys
  2import numpy as np
  3from scipy.spatial.transform import Rotation
  4from collections import defaultdict
  5import pickle
  6import glob
  7from flax import traverse_util
  8from typing import Dict, List, Tuple, Union, Optional, Callable, Sequence
  9from .databases import DBDataset, H5Dataset
 10from ..models.preprocessing import AtomPadding
 11import re
 12import json
 13import yaml
 14
 15try:
 16    import tomlkit
 17except ImportError:
 18    tomlkit = None
 19
 20try:
 21    from torch.utils.data import DataLoader
 22except ImportError:
 23    raise ImportError(
 24        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
 25    )
 26
 27from ..models import FENNIX
 28
 29
 30def load_configuration(config_file: str) -> Dict[str, any]:
 31    if config_file.endswith(".json"):
 32        parameters = json.load(open(config_file))
 33    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
 34        parameters = yaml.load(open(config_file), Loader=yaml.FullLoader)
 35    elif tomlkit is not None and config_file.endswith(".toml"):
 36        parameters = tomlkit.loads(open(config_file).read())
 37    else:
 38        supported_formats = [".json", ".yaml", ".yml"]
 39        if tomlkit is not None:
 40            supported_formats.append(".toml")
 41        raise ValueError(
 42            f"Unknown config file format. Supported formats: {supported_formats}"
 43        )
 44    return parameters
 45
 46
 47def load_dataset(
 48    dspath: str,
 49    batch_size: int,
 50    batch_size_val: Optional[int] = None,
 51    rename_refs: dict = {},
 52    infinite_iterator: bool = False,
 53    atom_padding: bool = False,
 54    ref_keys: Optional[Sequence[str]] = None,
 55    split_data_inputs: bool = False,
 56    np_rng: Optional[np.random.Generator] = None,
 57    train_val_split: bool = True,
 58    training_parameters: dict = {},
 59    add_flags: Sequence[str] = ["training"],
 60    fprec: str = "float32",
 61):
 62    """
 63    Load a dataset from a pickle file and return two iterators for training and validation batches.
 64
 65    Args:
 66        training_parameters (dict): A dictionary with the following keys:
 67            - 'dspath': str. Path to the pickle file containing the dataset.
 68            - 'batch_size': int. Number of samples per batch.
 69        rename_refs (list, optional): A list of strings with the names of the reference properties to rename.
 70            Default is an empty list.
 71
 72    Returns:
 73        tuple: A tuple of two infinite iterators, one for training batches and one for validation batches.
 74            For each element in the batch, we expect a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems:
 75            - 'natoms': np.ndarray. Array with the number of atoms in each system.
 76            - 'batch_index': np.ndarray. Array with the index of the system to which each atom
 77            if the keys "forces", "total_energy", "atomic_energies" or any of the elements in rename_refs are present, the keys are renamed by prepending "true_" to the key name.
 78    """
 79
 80    assert isinstance(
 81        training_parameters, dict
 82    ), "training_parameters must be a dictionary."
 83    assert isinstance(
 84        rename_refs, dict
 85    ), "rename_refs must be a dictionary with the keys to rename."
 86
 87    pbc_training = training_parameters.get("pbc_training", False)
 88    minimum_image = training_parameters.get("minimum_image", False)
 89
 90    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
 91
 92    input_keys = [
 93        "species",
 94        "coordinates",
 95        "natoms",
 96        "batch_index",
 97        "total_charge",
 98        "flags",
 99    ]
100    if pbc_training:
101        input_keys += ["cells", "reciprocal_cells"]
102    if atom_padding:
103        input_keys += ["true_atoms", "true_sys"]
104    if coordinates_ref_key is not None:
105        input_keys += ["system_index", "system_sign"]
106
107    flags = {f: None for f in add_flags}
108    if minimum_image and pbc_training:
109        flags["minimum_image"] = None
110
111    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
112    additional_input_keys_ = set()
113    for key in additional_input_keys:
114        if key not in input_keys:
115            additional_input_keys_.add(key)
116    additional_input_keys = additional_input_keys_
117
118    all_inputs = set(input_keys + list(additional_input_keys))
119
120    extract_all_keys = ref_keys is None
121    if ref_keys is not None:
122        ref_keys = set(ref_keys)
123        ref_keys_ = set()
124        for key in ref_keys:
125            if key not in all_inputs:
126                ref_keys_.add(key)
127
128    random_rotation = training_parameters.get("random_rotation", False)
129    if random_rotation:
130        assert np_rng is not None, "np_rng must be provided for adding noise."
131
132        apply_rotation = {
133            1: lambda x, r: x @ r,
134            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
135            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
136        }
137        def rotate_2f(x,r):
138            assert x.shape[-1]==6
139            # select from 6 components (xx,yy,zz,xy,xz,yz) to form the 3x3 tensor
140            indices = np.array([0,3,4,3,1,5,4,5,2])
141            x=x[...,indices].reshape(*x.shape[:-1],3,3)
142            x=np.einsum("li,...lk,kj->...ij", r, x, r)
143            # select back the 6 components
144            indices = np.array([[0,0],[1,1],[2,2],[0,1],[0,2],[1,2]])
145            x=x[...,indices[:,0],indices[:,1]]
146            return x
147        apply_rotation[-2]=rotate_2f
148
149        valid_rotations = tuple(apply_rotation.keys())
150        rotated_keys = {
151            "coordinates": 1,
152            "forces": 1,
153            "virial_tensor": 2,
154            "stress_tensor": 2,
155            "virial": 2,
156            "stress": 2,
157        }
158        if pbc_training:
159            rotated_keys["cells"] = 1
160        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
161        for k, v in user_rotated_keys.items():
162            assert (
163                v in valid_rotations
164            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
165            rotated_keys[k] = v
166
167        # rotated_vector_keys = set(
168        #     ["coordinates", "forces"]
169        #     + list(training_parameters.get("rotated_vector_keys", []))
170        # )
171        # if pbc_training:
172        #     rotated_vector_keys.add("cells")
173
174        # rotated_tensor_keys = set(
175        #     ["virial_tensor", "stress_tensor", "virial", "stress"]
176        #     + list(training_parameters.get("rotated_tensor_keys", []))
177        # )
178        # assert rotated_vector_keys.isdisjoint(
179        #     rotated_tensor_keys
180        # ), "Rotated vector keys and rotated tensor keys must be disjoint."
181        # rotated_keys = rotated_vector_keys.union(rotated_tensor_keys)
182
183        print(
184            "Applying random rotations to the following keys if present:",
185            list(rotated_keys.keys()),
186        )
187
188        def apply_random_rotations(output, nbatch):
189            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
190            r = [
191                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
192                for i in range(nbatch)
193            ]
194            for k, l in rotated_keys.items():
195                if k in output:
196                    for i in range(nbatch):
197                        output[k][i] = apply_rotation[l](output[k][i], r[i])
198
199    else:
200
201        def apply_random_rotations(output, nbatch):
202            pass
203
204    flow_matching = training_parameters.get("flow_matching", False)
205    if flow_matching:
206        if ref_keys is not None:
207            ref_keys.add("flow_matching_target")
208            if "flow_matching_target" in ref_keys_:
209                ref_keys_.remove("flow_matching_target")
210        all_inputs.add("flow_matching_time")
211
212        def add_flow_matching(output, nbatch):
213            ts = np_rng.uniform(0.0, 1.0, (nbatch,))
214            targets = []
215            for i in range(nbatch):
216                x1 = output["coordinates"][i]
217                com = x1.mean(axis=0, keepdims=True)
218                x1 = x1 - com
219                x0 = np_rng.normal(0.0, 1.0, x1.shape)
220                xt = (1 - ts[i]) * x0 + ts[i] * x1
221                output["coordinates"][i] = xt
222                targets.append(x1 - x0)
223            output["flow_matching_target"] = targets
224            output["flow_matching_time"] = [np.array(t) for t in ts]
225
226    else:
227
228        def add_flow_matching(output, nbatch):
229            pass
230
231    if pbc_training:
232        print("Periodic boundary conditions are active.")
233        length_nopbc = training_parameters.get("length_nopbc", 1000.0)
234
235        def add_cell(d, output):
236            if "cell" not in d:
237                cell = np.asarray(
238                    [
239                        [length_nopbc, 0.0, 0.0],
240                        [0.0, length_nopbc, 0.0],
241                        [0.0, 0.0, length_nopbc],
242                    ],
243                    dtype=fprec,
244                )
245            else:
246                cell = np.asarray(d["cell"], dtype=fprec)
247            output["cells"].append(cell.reshape(1, 3, 3))
248
249    else:
250
251        def add_cell(d, output):
252            if "cell" in d:
253                print(
254                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
255                )
256
257    if extract_all_keys:
258
259        def add_other_keys(d, output, atom_shift):
260            for k, v in d.items():
261                if k in ("cell", "total_charge"):
262                    continue
263                v_array = np.array(v)
264                # Shift atom number if necessary
265                if k.endswith("_atidx"):
266                    mask = (v_array > 0).astype(int)
267                    v_array = v_array + atom_shift*mask
268                output[k].append(v_array)
269
270    else:
271
272        def add_other_keys(d, output, atom_shift):
273            output["species"].append(np.asarray(d["species"]))
274            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
275            for k in additional_input_keys:
276                v_array = np.array(d[k])
277                # Shift atom number if necessary
278                if k.endswith("_atidx"):
279                    mask = (v_array > 0).astype(int)
280                    v_array =v_array + atom_shift*mask
281                output[k].append(v_array)
282            for k in ref_keys_:
283                v_array = np.array(d[k])
284                # Shift atom number if necessary
285                if k.endswith("_atidx"):
286                    mask = (v_array > 0).astype(int)
287                    v_array =v_array + atom_shift*mask
288                output[k].append(v_array)
289                if k + "_mask" in d:
290                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))
291
292    def add_keys(d, output, atom_shift, batch_index):
293        nat = d["species"].shape[0]
294
295        output["natoms"].append(np.asarray([nat]))
296        output["batch_index"].append(np.asarray([batch_index] * nat))
297        if "total_charge" not in d:
298            total_charge = np.asarray(0.0, dtype=fprec)
299        else:
300            total_charge = np.asarray(d["total_charge"], dtype=fprec)
301        output["total_charge"].append(total_charge)
302
303        add_cell(d, output)
304        add_other_keys(d, output, atom_shift)
305
306        return atom_shift + nat
307
308    def collate_fn_(batch):
309        output = defaultdict(list)
310        atom_shift = 0
311        batch_index = 0
312
313        for d in batch:
314            atom_shift = add_keys(d, output, atom_shift, batch_index)
315            batch_index += 1
316
317            if coordinates_ref_key is not None:
318                output["system_index"].append(np.asarray([batch_index - 1]))
319                output["system_sign"].append(np.asarray([1]))
320                if coordinates_ref_key in d:
321                    dref = {**d, "coordinates": d[coordinates_ref_key]}
322                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
323                    output["system_index"].append(np.asarray([batch_index - 1]))
324                    output["system_sign"].append(np.asarray([-1]))
325                    batch_index += 1
326
327        nbatch_ = len(output["natoms"])
328        apply_random_rotations(output,nbatch_)
329        add_flow_matching(output,nbatch_)
330
331        # Stack and concatenate the arrays
332        for k, v in output.items():
333            if v[0].ndim == 0:
334                v = np.stack(v)
335            else:
336                v = np.concatenate(v, axis=0)
337            if np.issubdtype(v.dtype, np.floating):
338                v = v.astype(fprec)
339            output[k] = v
340
341        if "cells" in output and pbc_training:
342            output["reciprocal_cells"] = np.linalg.inv(output["cells"])
343
344        # Rename necessary keys
345        # for key in rename_refs:
346        #     if key in output:
347        #         output["true_" + key] = output.pop(key)
348        for kold, knew in rename_refs.items():
349            assert (
350                knew not in output
351            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
352            if kold in output:
353                output[knew] = output.pop(kold)
354
355        output["flags"] = flags
356        return output
357
358    collate_layers_train = [collate_fn_]
359    collate_layers_valid = [collate_fn_]
360
361    ### collate preprocessing
362    # add noise to the training data
363    noise_sigma = training_parameters.get("noise_sigma", None)
364    if noise_sigma is not None:
365        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"
366
367        for sigma in noise_sigma.values():
368            assert sigma >= 0, "Noise sigma should be a positive number"
369
370        print("Adding noise to the training data:")
371        for key, sigma in noise_sigma.items():
372            print(f"  - {key} with sigma = {sigma}")
373
374        assert np_rng is not None, "np_rng must be provided for adding noise."
375
376        def collate_with_noise(batch):
377            for key, sigma in noise_sigma.items():
378                if key in batch and sigma > 0:
379                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
380                        batch[key].dtype
381                    )
382            return batch
383
384        collate_layers_train.append(collate_with_noise)
385
386    if atom_padding:
387        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
388        padder_state = padder.init()
389
390        def collate_with_padding(batch):
391            padder_state_up, output = padder(padder_state, batch)
392            padder_state.update(padder_state_up)
393            return output
394
395        collate_layers_train.append(collate_with_padding)
396        collate_layers_valid.append(collate_with_padding)
397
398    if split_data_inputs:
399
400        # input_keys += additional_input_keys
401        # input_keys = set(input_keys)
402        print("Input keys:", all_inputs)
403        print("Ref keys:", ref_keys)
404
405        def collate_split(batch):
406            inputs = {}
407            refs = {}
408            for k, v in batch.items():
409                if k in all_inputs:
410                    inputs[k] = v
411                if k in ref_keys:
412                    refs[k] = v
413                if k.endswith("_mask") and k[:-5] in ref_keys:
414                    refs[k] = v
415            return inputs, refs
416
417        collate_layers_train.append(collate_split)
418        collate_layers_valid.append(collate_split)
419
420    ### apply all collate preprocessing
421    if len(collate_layers_train) == 1:
422        collate_fn_train = collate_layers_train[0]
423    else:
424
425        def collate_fn_train(batch):
426            for layer in collate_layers_train:
427                batch = layer(batch)
428            return batch
429
430    if len(collate_layers_valid) == 1:
431        collate_fn_valid = collate_layers_valid[0]
432    else:
433
434        def collate_fn_valid(batch):
435            for layer in collate_layers_valid:
436                batch = layer(batch)
437            return batch
438
439    if not os.path.exists(dspath):
440        raise ValueError(f"Dataset file '{dspath}' not found.")
441    # dspath = training_parameters.get("dspath", None)
442    print(f"Loading dataset from {dspath}...", end="")
443    # print(f"   the following keys will be renamed if present : {rename_refs}")
444    sharded_training = False
445    if dspath.endswith(".db"):
446        dataset = {}
447        if train_val_split:
448            dataset["training"] = DBDataset(dspath, table="training")
449            dataset["validation"] = DBDataset(dspath, table="validation")
450        else:
451            dataset = DBDataset(dspath)
452    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
453        dataset = {}
454        if train_val_split:
455            dataset["training"] = H5Dataset(dspath, table="training")
456            dataset["validation"] = H5Dataset(dspath, table="validation")
457        else:
458            dataset = H5Dataset(dspath)
459    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
460        with open(dspath, "rb") as f:
461            dataset = pickle.load(f)
462        if not train_val_split and isinstance(dataset, dict):
463            dataset = dataset["training"]
464    elif os.path.isdir(dspath):
465        if train_val_split:
466            dataset = {}
467            with open(dspath + "/validation.pkl", "rb") as f:
468                dataset["validation"] = pickle.load(f)
469        else:
470            dataset = None
471
472        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
473        nshards = len(shard_files)
474        if nshards == 0:
475            raise ValueError("No dataset shards found.")
476        elif nshards == 1:
477            with open(shard_files[0], "rb") as f:
478                if train_val_split:
479                    dataset["training"] = pickle.load(f)
480                else:
481                    dataset = pickle.load(f)
482        else:
483            print(f"Found {nshards} dataset shards.")
484            sharded_training = True
485
486    else:
487        raise ValueError(
488            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
489        )
490    print(" done.")
491
492    ### BUILD DATALOADERS
493    # batch_size = training_parameters.get("batch_size", 16)
494    shuffle = training_parameters.get("shuffle_dataset", True)
495    if train_val_split:
496        if batch_size_val is None:
497            batch_size_val = batch_size
498        dataloader_validation = DataLoader(
499            dataset["validation"],
500            batch_size=batch_size_val,
501            shuffle=shuffle,
502            collate_fn=collate_fn_valid,
503        )
504
505    if sharded_training:
506
507        def iterate_sharded_dataset():
508            indices = np.arange(nshards)
509            if shuffle:
510                assert np_rng is not None, "np_rng must be provided for shuffling."
511                np_rng.shuffle(indices)
512            for i in indices:
513                filename = shard_files[i]
514                print(f"# Loading dataset shard from {filename}...", end="")
515                with open(filename, "rb") as f:
516                    dataset = pickle.load(f)
517                print(" done.")
518                dataloader = DataLoader(
519                    dataset,
520                    batch_size=batch_size,
521                    shuffle=shuffle,
522                    collate_fn=collate_fn_train,
523                )
524                for batch in dataloader:
525                    yield batch
526
527        class DataLoaderSharded:
528            def __iter__(self):
529                return iterate_sharded_dataset()
530
531        dataloader_training = DataLoaderSharded()
532    else:
533        dataloader_training = DataLoader(
534            dataset["training"] if train_val_split else dataset,
535            batch_size=batch_size,
536            shuffle=shuffle,
537            collate_fn=collate_fn_train,
538        )
539
540    if not infinite_iterator:
541        if train_val_split:
542            return dataloader_training, dataloader_validation
543        return dataloader_training
544
545    def next_batch_factory(dataloader):
546        while True:
547            for batch in dataloader:
548                yield batch
549
550    training_iterator = next_batch_factory(dataloader_training)
551    if train_val_split:
552        validation_iterator = next_batch_factory(dataloader_validation)
553        return training_iterator, validation_iterator
554    return training_iterator
555
556
557def load_model(
558    parameters: Dict[str, any],
559    model_file: Optional[str] = None,
560    rng_key: Optional[str] = None,
561) -> FENNIX:
562    """
563    Load a FENNIX model from a file or create a new one.
564
565    Args:
566        parameters (dict): A dictionary of parameters for the model.
567        model_file (str, optional): The path to a saved model file to load.
568
569    Returns:
570        FENNIX: A FENNIX model object.
571    """
572    print_model = parameters["training"].get("print_model", False)
573    if model_file is None:
574        model_file = parameters.get("model_file", None)
575    if model_file is not None and os.path.exists(model_file):
576        model = FENNIX.load(model_file, use_atom_padding=False)
577        if print_model:
578            print(model.summarize())
579        print(f"Restored model from '{model_file}'.")
580    else:
581        assert (
582            rng_key is not None
583        ), "rng_key must be specified if model_file is not provided."
584        model_params = parameters["model"]
585        if isinstance(model_params, str):
586            assert os.path.exists(
587                model_params
588            ), f"Model file '{model_params}' not found."
589            model = FENNIX.load(model_params, use_atom_padding=False)
590            print(f"Restored model from '{model_params}'.")
591        else:
592            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
593        if print_model:
594            print(model.summarize())
595    return model
596
597
def copy_parameters(variables, variables_ref, params=(".*",)):
    """Return ``variables`` with selected leaves taken from ``variables_ref``.

    Both nested dicts are flattened with ``flax.traverse_util.flatten_dict``
    (empty nodes dropped). Each leaf of ``variables`` is checked in turn. If
    the same key path also exists in ``variables_ref`` and the path's string
    form matches one of ``params``, the leaf from ``variables_ref`` is used
    instead. The string form joins every path component except the first with
    ``"/"`` and is lowercased. Patterns are applied with ``re.match``, which
    anchors at the start only, and matching is case-insensitive.

    Leaves missing from ``variables_ref`` are kept, and leaves that exist only
    in ``variables_ref`` are ignored.

    Args:
        variables: Nested dict of parameters to update.
        variables_ref: Nested dict providing the replacement values.
        params: Regular expression(s) selecting the paths to copy. A single
            string is treated as one pattern. The default copies every shared
            leaf.

    Returns:
        A new nested dict built with ``traverse_util.unflatten_dict``.
    """
    if isinstance(params, str):
        # a bare string would otherwise be iterated character by character
        params = [params]
    # IGNORECASE instead of lowercasing the patterns: lowercasing would change
    # the meaning of escapes such as \D, \S or \W
    patterns = [re.compile(p, re.IGNORECASE) for p in params]

    def merge_params(key_path, v, v_ref):
        full_path = "/".join(key_path[1:]).lower()
        if any(p.match(full_path) for p in patterns):
            return v_ref
        return v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )
617
618
class TeeLogger(object):
    """Text sink that copies everything written to it into a log file and,
    while bound, to the previous ``sys.stdout``.

    Typical use::

        tee = TeeLogger("train.log")
        tee.bind_stdout()    # print() now goes to the terminal and the file
        ...
        tee.unbind_stdout()  # restore the previous sys.stdout
        tee.close()
    """

    def __init__(self, name):
        # set first so close()/__del__ work even if opening the file fails
        self.stdout = None
        # the file is opened in binary mode, wrapped in a write-through text
        # layer; write() flushes after every call
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)

    def __del__(self):
        self.close()

    def write(self, data):
        """Write ``data`` to the log file and, if bound, to the saved stdout."""
        self.file.write(data)
        if self.stdout is not None:
            self.stdout.write(data)
        self.file.flush()

    def close(self):
        """Close the log file. Safe to call several times or after a failed init."""
        file = getattr(self, "file", None)
        if file is not None:
            file.close()

    def flush(self):
        """Flush the log file and, if bound, the saved stdout."""
        self.file.flush()
        if self.stdout is not None:
            self.stdout.flush()

    def bind_stdout(self):
        """Replace ``sys.stdout`` by this logger, keeping the old one for mirroring.

        Raises:
            ValueError: If ``sys.stdout`` is already a TeeLogger or this logger
                is already bound.
        """
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        """Restore the saved ``sys.stdout`` so the logger can be bound again.

        Raises:
            ValueError: If the logger is not bound.
        """
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout
        # forget the stream so bind_stdout() can be called again
        self.stdout = None
def load_configuration(config_file: str) -> Dict[str, <built-in function any>]:
def load_configuration(config_file: str) -> Dict[str, any]:
    """Read a configuration file and return its content.

    The parser is selected from the file extension:

    - ``.json``: ``json.load``;
    - ``.yaml`` / ``.yml``: ``yaml.load`` with ``yaml.FullLoader``;
    - ``.toml``: ``tomlkit.loads`` (only when the optional ``tomlkit`` package
      could be imported).

    Files are read as UTF-8 and closed after parsing.

    Args:
        config_file: Path to the configuration file.

    Returns:
        The parsed configuration (for TOML files, the ``tomlkit`` document).

    Raises:
        ValueError: If the extension is not supported, or if a ``.toml`` file
            is given while ``tomlkit`` is not installed.
    """
    if config_file.endswith(".json"):
        with open(config_file, encoding="utf-8") as f:
            return json.load(f)
    if config_file.endswith((".yaml", ".yml")):
        with open(config_file, encoding="utf-8") as f:
            return yaml.load(f, Loader=yaml.FullLoader)
    if config_file.endswith(".toml"):
        if tomlkit is None:
            # give the real reason instead of a generic "unknown format" error
            raise ValueError(
                "TOML configuration files require the 'tomlkit' package."
            )
        with open(config_file, encoding="utf-8") as f:
            return tomlkit.loads(f.read())

    supported_formats = [".json", ".yaml", ".yml"]
    if tomlkit is not None:
        supported_formats.append(".toml")
    raise ValueError(
        f"Unknown config file format. Supported formats: {supported_formats}"
    )
def load_dataset( dspath: str, batch_size: int, batch_size_val: Optional[int] = None, rename_refs: dict = {}, infinite_iterator: bool = False, atom_padding: bool = False, ref_keys: Optional[Sequence[str]] = None, split_data_inputs: bool = False, np_rng: Optional[numpy.random._generator.Generator] = None, train_val_split: bool = True, training_parameters: dict = {}, add_flags: Sequence[str] = ['training'], fprec: str = 'float32'):
def load_dataset(
    dspath: str,
    batch_size: int,
    batch_size_val: Optional[int] = None,
    rename_refs: dict = {},
    infinite_iterator: bool = False,
    atom_padding: bool = False,
    ref_keys: Optional[Sequence[str]] = None,
    split_data_inputs: bool = False,
    np_rng: Optional[np.random.Generator] = None,
    train_val_split: bool = True,
    training_parameters: dict = {},
    add_flags: Sequence[str] = ["training"],
    fprec: str = "float32",
):
    """
    Load a dataset and build loaders for training and, optionally, validation batches.

    Supported sources, selected from ``dspath``:
      - ``.db``: ``DBDataset`` ("training"/"validation" tables when ``train_val_split``).
      - ``.h5`` / ``.hdf5``: ``H5Dataset`` (same tables).
      - ``.pkl`` / ``.pickle``: a pickled object. When ``train_val_split`` it is
        indexed with "training" and "validation". Otherwise a pickled dict is
        reduced to its "training" entry.
      - a directory with ``training_*.pkl`` shard files (and ``validation.pkl``
        when ``train_val_split``). With more than one shard, shards are loaded
        one at a time while iterating over the training loader.

    Each dataset item is a mapping with at least "species" (its ``.shape[0]`` is
    used as the number of atoms) and "coordinates". Per-item arrays of a batch are
    concatenated along the first axis (0-d values are stacked). The collate
    function also adds:
      - 'natoms': number of atoms of each system.
      - 'batch_index': for each atom, the index of the system it belongs to.
      - 'total_charge': taken from the item, 0.0 when missing.
      - 'cells' and 'reciprocal_cells' when "pbc_training" is enabled.
      - 'flags': a dict whose keys are ``add_flags``, plus "minimum_image" when
        both "minimum_image" and "pbc_training" are enabled.
    Floating-point arrays are cast to ``fprec``. Keys are then renamed according
    to ``rename_refs``.

    Args:
        dspath (str): Path to the dataset file or shard directory.
        batch_size (int): Number of systems per training batch.
        batch_size_val (int, optional): Number of systems per validation batch.
            Defaults to ``batch_size``.
        rename_refs (dict, optional): Mapping ``{old_key: new_key}`` applied to each
            collated batch. Renaming onto a key that is already present raises
            AssertionError.
        infinite_iterator (bool): If True, return generators that cycle over the
            loaders forever instead of the loaders themselves.
        atom_padding (bool): Pad batches with ``AtomPadding`` and add "true_atoms" /
            "true_sys" to the input keys.
        ref_keys (sequence of str, optional): Reference keys to extract. If None,
            every key of each item is kept.
        split_data_inputs (bool): If True, each batch is returned as ``(inputs, refs)``.
        np_rng (np.random.Generator, optional): Random generator used for random
            rotations, noise, flow matching and shard shuffling.
        train_val_split (bool): Whether a validation loader is built as well.
        training_parameters (dict): Extra options read with ``.get``, among them
            "pbc_training", "minimum_image", "coordinates_ref_key",
            "additional_input_keys", "random_rotation", "rotated_keys",
            "flow_matching", "length_nopbc", "noise_sigma", "padder_add_sys" and
            "shuffle_dataset".
        add_flags (sequence of str): Names stored as keys of the "flags" entry.
        fprec (str): Floating-point dtype name used for float arrays.

    Returns:
        If ``infinite_iterator`` is False: ``(training_loader, validation_loader)``
        when ``train_val_split``, else ``training_loader``. If it is True, the same
        structure with infinite generators in place of the loaders.

    Raises:
        ValueError: if ``dspath`` does not exist, has an unsupported format, or is
            a directory without ``training_*.pkl`` files.
        AssertionError: on invalid argument types, missing ``np_rng`` when random
            rotations or noise are requested, an invalid rotation type, or a
            ``rename_refs`` collision.
    """

    assert isinstance(
        training_parameters, dict
    ), "training_parameters must be a dictionary."
    assert isinstance(
        rename_refs, dict
    ), "rename_refs must be a dictionary with the keys to rename."

    pbc_training = training_parameters.get("pbc_training", False)
    minimum_image = training_parameters.get("minimum_image", False)

    # Optional item key holding a second set of coordinates. When present, an
    # extra "reference" system is appended to the batch (see collate_fn_).
    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)

    # Keys treated as model inputs (the rest are references when splitting).
    input_keys = [
        "species",
        "coordinates",
        "natoms",
        "batch_index",
        "total_charge",
        "flags",
    ]
    if pbc_training:
        input_keys += ["cells", "reciprocal_cells"]
    if atom_padding:
        input_keys += ["true_atoms", "true_sys"]
    if coordinates_ref_key is not None:
        input_keys += ["system_index", "system_sign"]

    # The same flags dict object is attached to every batch.
    flags = {f: None for f in add_flags}
    if minimum_image and pbc_training:
        flags["minimum_image"] = None

    # User-requested extra input keys, excluding those already listed above.
    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
    additional_input_keys_ = set()
    for key in additional_input_keys:
        if key not in input_keys:
            additional_input_keys_.add(key)
    additional_input_keys = additional_input_keys_

    all_inputs = set(input_keys + list(additional_input_keys))

    # When ref_keys is None every item key is copied into the batch.
    # Otherwise ref_keys_ holds the reference keys that are not already
    # extracted as inputs (ref_keys_ is only defined in that case).
    extract_all_keys = ref_keys is None
    if ref_keys is not None:
        ref_keys = set(ref_keys)
        ref_keys_ = set()
        for key in ref_keys:
            if key not in all_inputs:
                ref_keys_.add(key)

    random_rotation = training_parameters.get("random_rotation", False)
    if random_rotation:
        # NOTE(review): this assertion guards random rotations although its
        # message mentions noise.
        assert np_rng is not None, "np_rng must be provided for adding noise."

        # Rotation rules, selected by the integer type assigned to each key:
        #   1  -> x @ r (rotates the last axis of size 3)
        #  -1  -> rotates the second-to-last axis
        #   2  -> r^T x r for 3x3 tensors
        #  -2  -> symmetric tensors stored as 6 components (rotate_2f)
        apply_rotation = {
            1: lambda x, r: x @ r,
            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
        }
        def rotate_2f(x,r):
            assert x.shape[-1]==6
            # select from 6 components (xx,yy,zz,xy,xz,yz) to form the 3x3 tensor
            indices = np.array([0,3,4,3,1,5,4,5,2])
            x=x[...,indices].reshape(*x.shape[:-1],3,3)
            x=np.einsum("li,...lk,kj->...ij", r, x, r)
            # select back the 6 components
            indices = np.array([[0,0],[1,1],[2,2],[0,1],[0,2],[1,2]])
            x=x[...,indices[:,0],indices[:,1]]
            return x
        apply_rotation[-2]=rotate_2f

        # Default keys to rotate; "rotated_keys" in training_parameters may add
        # keys or override their rotation type.
        valid_rotations = tuple(apply_rotation.keys())
        rotated_keys = {
            "coordinates": 1,
            "forces": 1,
            "virial_tensor": 2,
            "stress_tensor": 2,
            "virial": 2,
            "stress": 2,
        }
        if pbc_training:
            rotated_keys["cells"] = 1
        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
        for k, v in user_rotated_keys.items():
            assert (
                v in valid_rotations
            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
            rotated_keys[k] = v

        # rotated_vector_keys = set(
        #     ["coordinates", "forces"]
        #     + list(training_parameters.get("rotated_vector_keys", []))
        # )
        # if pbc_training:
        #     rotated_vector_keys.add("cells")

        # rotated_tensor_keys = set(
        #     ["virial_tensor", "stress_tensor", "virial", "stress"]
        #     + list(training_parameters.get("rotated_tensor_keys", []))
        # )
        # assert rotated_vector_keys.isdisjoint(
        #     rotated_tensor_keys
        # ), "Rotated vector keys and rotated tensor keys must be disjoint."
        # rotated_keys = rotated_vector_keys.union(rotated_tensor_keys)

        print(
            "Applying random rotations to the following keys if present:",
            list(rotated_keys.keys()),
        )

        # Draw one rotation per system and apply it in place to the per-system
        # entries of `output`. These are still lists at this point; concatenation
        # happens later in collate_fn_.
        # NOTE(review): uniform Euler angles do not sample rotations uniformly
        # over SO(3) (scipy's Rotation.random would); confirm this is intended.
        def apply_random_rotations(output, nbatch):
            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
            r = [
                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
                for i in range(nbatch)
            ]
            for k, l in rotated_keys.items():
                if k in output:
                    for i in range(nbatch):
                        output[k][i] = apply_rotation[l](output[k][i], r[i])

    else:

        def apply_random_rotations(output, nbatch):
            pass

    flow_matching = training_parameters.get("flow_matching", False)
    if flow_matching:
        # "flow_matching_target" becomes a reference key and
        # "flow_matching_time" an input key.
        if ref_keys is not None:
            ref_keys.add("flow_matching_target")
            if "flow_matching_target" in ref_keys_:
                ref_keys_.remove("flow_matching_target")
        all_inputs.add("flow_matching_time")

        # For each system: center the coordinates (x1), draw Gaussian noise x0 of
        # the same shape, and replace the coordinates by xt = (1 - t) x0 + t x1 with
        # t ~ U(0, 1). x1 - x0 is stored as the target.
        # NOTE(review): np_rng is used here without an explicit not-None check.
        def add_flow_matching(output, nbatch):
            ts = np_rng.uniform(0.0, 1.0, (nbatch,))
            targets = []
            for i in range(nbatch):
                x1 = output["coordinates"][i]
                com = x1.mean(axis=0, keepdims=True)
                x1 = x1 - com
                x0 = np_rng.normal(0.0, 1.0, x1.shape)
                xt = (1 - ts[i]) * x0 + ts[i] * x1
                output["coordinates"][i] = xt
                targets.append(x1 - x0)
            output["flow_matching_target"] = targets
            output["flow_matching_time"] = [np.array(t) for t in ts]

    else:

        def add_flow_matching(output, nbatch):
            pass

    if pbc_training:
        print("Periodic boundary conditions are active.")
        length_nopbc = training_parameters.get("length_nopbc", 1000.0)

        # Items without a "cell" get a cubic box of side length_nopbc.
        def add_cell(d, output):
            if "cell" not in d:
                cell = np.asarray(
                    [
                        [length_nopbc, 0.0, 0.0],
                        [0.0, length_nopbc, 0.0],
                        [0.0, 0.0, length_nopbc],
                    ],
                    dtype=fprec,
                )
            else:
                cell = np.asarray(d["cell"], dtype=fprec)
            output["cells"].append(cell.reshape(1, 3, 3))

    else:

        def add_cell(d, output):
            if "cell" in d:
                print(
                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
                )

    if extract_all_keys:

        # Copy every item key except "cell" and "total_charge" (handled elsewhere).
        # Keys ending in "_atidx" hold atom indices. Their entries > 0 are offset
        # by the number of atoms of the preceding systems in the batch.
        # NOTE(review): an index equal to 0 is never shifted; presumably 0 and
        # negative values are padding markers. Confirm.
        def add_other_keys(d, output, atom_shift):
            for k, v in d.items():
                if k in ("cell", "total_charge"):
                    continue
                v_array = np.array(v)
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array = v_array + atom_shift*mask
                output[k].append(v_array)

    else:

        # Copy only species, coordinates, the additional input keys and the
        # non-input reference keys (plus their optional "<key>_mask" entries).
        def add_other_keys(d, output, atom_shift):
            output["species"].append(np.asarray(d["species"]))
            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
            for k in additional_input_keys:
                v_array = np.array(d[k])
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array =v_array + atom_shift*mask
                output[k].append(v_array)
            for k in ref_keys_:
                v_array = np.array(d[k])
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    mask = (v_array > 0).astype(int)
                    v_array =v_array + atom_shift*mask
                output[k].append(v_array)
                if k + "_mask" in d:
                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))

    # Append one system to `output` and return the updated atom offset.
    # d["species"] must expose .shape (array-like), as it is not converted here.
    def add_keys(d, output, atom_shift, batch_index):
        nat = d["species"].shape[0]

        output["natoms"].append(np.asarray([nat]))
        output["batch_index"].append(np.asarray([batch_index] * nat))
        if "total_charge" not in d:
            total_charge = np.asarray(0.0, dtype=fprec)
        else:
            total_charge = np.asarray(d["total_charge"], dtype=fprec)
        output["total_charge"].append(total_charge)

        add_cell(d, output)
        add_other_keys(d, output, atom_shift)

        return atom_shift + nat

    # Base collate function shared by training and validation loaders: gathers
    # the items, applies random rotations / flow matching (when enabled), then
    # concatenates, casts floats to fprec and renames keys.
    def collate_fn_(batch):
        output = defaultdict(list)
        atom_shift = 0
        batch_index = 0

        for d in batch:
            atom_shift = add_keys(d, output, atom_shift, batch_index)
            batch_index += 1

            # With coordinates_ref_key, the item is tagged +1. If the item has
            # reference coordinates, a copy using them is appended as a separate
            # system tagged -1 with the same system_index.
            if coordinates_ref_key is not None:
                output["system_index"].append(np.asarray([batch_index - 1]))
                output["system_sign"].append(np.asarray([1]))
                if coordinates_ref_key in d:
                    dref = {**d, "coordinates": d[coordinates_ref_key]}
                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
                    output["system_index"].append(np.asarray([batch_index - 1]))
                    output["system_sign"].append(np.asarray([-1]))
                    batch_index += 1

        nbatch_ = len(output["natoms"])
        apply_random_rotations(output,nbatch_)
        add_flow_matching(output,nbatch_)

        # Stack and concatenate the arrays
        for k, v in output.items():
            if v[0].ndim == 0:
                v = np.stack(v)
            else:
                v = np.concatenate(v, axis=0)
            if np.issubdtype(v.dtype, np.floating):
                v = v.astype(fprec)
            output[k] = v

        if "cells" in output and pbc_training:
            output["reciprocal_cells"] = np.linalg.inv(output["cells"])

        # Rename necessary keys
        # for key in rename_refs:
        #     if key in output:
        #         output["true_" + key] = output.pop(key)
        for kold, knew in rename_refs.items():
            assert (
                knew not in output
            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
            if kold in output:
                output[knew] = output.pop(kold)

        output["flags"] = flags
        return output

    # Collate pipelines: each layer receives the output of the previous one.
    collate_layers_train = [collate_fn_]
    collate_layers_valid = [collate_fn_]

    ### collate preprocessing
    # add noise to the training data
    noise_sigma = training_parameters.get("noise_sigma", None)
    if noise_sigma is not None:
        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"

        for sigma in noise_sigma.values():
            assert sigma >= 0, "Noise sigma should be a positive number"

        print("Adding noise to the training data:")
        for key, sigma in noise_sigma.items():
            print(f"  - {key} with sigma = {sigma}")

        assert np_rng is not None, "np_rng must be provided for adding noise."

        # Gaussian noise is added in place to the listed keys; training only.
        def collate_with_noise(batch):
            for key, sigma in noise_sigma.items():
                if key in batch and sigma > 0:
                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
                        batch[key].dtype
                    )
            return batch

        collate_layers_train.append(collate_with_noise)

    if atom_padding:
        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
        padder_state = padder.init()

        # The padder state is shared by training and validation and updated
        # after every call.
        def collate_with_padding(batch):
            padder_state_up, output = padder(padder_state, batch)
            padder_state.update(padder_state_up)
            return output

        collate_layers_train.append(collate_with_padding)
        collate_layers_valid.append(collate_with_padding)

    if split_data_inputs:

        # input_keys += additional_input_keys
        # input_keys = set(input_keys)
        print("Input keys:", all_inputs)
        print("Ref keys:", ref_keys)

        # Split a batch into (inputs, refs). A key may appear in both when it is
        # an input key and also listed in ref_keys.
        # NOTE(review): `k in ref_keys` raises TypeError when ref_keys is None;
        # split_data_inputs currently requires ref_keys to be given.
        def collate_split(batch):
            inputs = {}
            refs = {}
            for k, v in batch.items():
                if k in all_inputs:
                    inputs[k] = v
                if k in ref_keys:
                    refs[k] = v
                if k.endswith("_mask") and k[:-5] in ref_keys:
                    refs[k] = v
            return inputs, refs

        collate_layers_train.append(collate_split)
        collate_layers_valid.append(collate_split)

    ### apply all collate preprocessing
    if len(collate_layers_train) == 1:
        collate_fn_train = collate_layers_train[0]
    else:

        def collate_fn_train(batch):
            for layer in collate_layers_train:
                batch = layer(batch)
            return batch

    if len(collate_layers_valid) == 1:
        collate_fn_valid = collate_layers_valid[0]
    else:

        def collate_fn_valid(batch):
            for layer in collate_layers_valid:
                batch = layer(batch)
            return batch

    if not os.path.exists(dspath):
        raise ValueError(f"Dataset file '{dspath}' not found.")
    # dspath = training_parameters.get("dspath", None)
    print(f"Loading dataset from {dspath}...", end="")
    # print(f"   the following keys will be renamed if present : {rename_refs}")
    sharded_training = False
    if dspath.endswith(".db"):
        dataset = {}
        if train_val_split:
            dataset["training"] = DBDataset(dspath, table="training")
            dataset["validation"] = DBDataset(dspath, table="validation")
        else:
            dataset = DBDataset(dspath)
    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
        dataset = {}
        if train_val_split:
            dataset["training"] = H5Dataset(dspath, table="training")
            dataset["validation"] = H5Dataset(dspath, table="validation")
        else:
            dataset = H5Dataset(dspath)
    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted dataset files.
        with open(dspath, "rb") as f:
            dataset = pickle.load(f)
        if not train_val_split and isinstance(dataset, dict):
            dataset = dataset["training"]
    elif os.path.isdir(dspath):
        # Directory layout: validation.pkl plus training_*.pkl shards. A single
        # shard is loaded eagerly; several shards are streamed lazily below.
        if train_val_split:
            dataset = {}
            with open(dspath + "/validation.pkl", "rb") as f:
                dataset["validation"] = pickle.load(f)
        else:
            dataset = None

        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
        nshards = len(shard_files)
        if nshards == 0:
            raise ValueError("No dataset shards found.")
        elif nshards == 1:
            with open(shard_files[0], "rb") as f:
                if train_val_split:
                    dataset["training"] = pickle.load(f)
                else:
                    dataset = pickle.load(f)
        else:
            print(f"Found {nshards} dataset shards.")
            sharded_training = True

    else:
        raise ValueError(
            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
        )
    print(" done.")

    ### BUILD DATALOADERS
    # batch_size = training_parameters.get("batch_size", 16)
    # "shuffle_dataset" also applies to the validation loader.
    shuffle = training_parameters.get("shuffle_dataset", True)
    if train_val_split:
        if batch_size_val is None:
            batch_size_val = batch_size
        dataloader_validation = DataLoader(
            dataset["validation"],
            batch_size=batch_size_val,
            shuffle=shuffle,
            collate_fn=collate_fn_valid,
        )

    if sharded_training:

        # Visit the shards (in random order when shuffling), loading one shard at
        # a time and yielding its batches.
        def iterate_sharded_dataset():
            indices = np.arange(nshards)
            if shuffle:
                assert np_rng is not None, "np_rng must be provided for shuffling."
                np_rng.shuffle(indices)
            for i in indices:
                filename = shard_files[i]
                print(f"# Loading dataset shard from {filename}...", end="")
                with open(filename, "rb") as f:
                    dataset = pickle.load(f)
                print(" done.")
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    collate_fn=collate_fn_train,
                )
                for batch in dataloader:
                    yield batch

        # Re-iterable wrapper: each iteration starts a new pass over all shards.
        class DataLoaderSharded:
            def __iter__(self):
                return iterate_sharded_dataset()

        dataloader_training = DataLoaderSharded()
    else:
        dataloader_training = DataLoader(
            dataset["training"] if train_val_split else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_train,
        )

    if not infinite_iterator:
        if train_val_split:
            return dataloader_training, dataloader_validation
        return dataloader_training

    # Cycle over a loader forever, starting a new pass when one is exhausted.
    def next_batch_factory(dataloader):
        while True:
            for batch in dataloader:
                yield batch

    training_iterator = next_batch_factory(dataloader_training)
    if train_val_split:
        validation_iterator = next_batch_factory(dataloader_validation)
        return training_iterator, validation_iterator
    return training_iterator

Load a dataset from a '.db', '.h5'/'.hdf5' or '.pkl'/'.pickle' file, or from a directory of pickled shards, and return loaders for training batches and, optionally, validation batches.

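Args: dspath (str): Path to the dataset file or shard directory. batch_size (int): Number of samples per training batch. batch_size_val (int, optional): Number of samples per validation batch; defaults to batch_size. rename_refs (dict, optional): Mapping {old_key: new_key} used to rename keys of each collated batch. Default is an empty dict. training_parameters (dict, optional): Additional options such as 'pbc_training', 'random_rotation', 'noise_sigma' or 'shuffle_dataset'.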

Returns: The training loader, or a tuple (training loader, validation loader) when train_val_split is True. If infinite_iterator is True, infinite generators that cycle over the loaders are returned instead. For each element in the batch, we expect a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems: - 'natoms': np.ndarray. Array with the number of atoms in each system. - 'batch_index': np.ndarray. Array with, for each atom, the index of the system to which it belongs. Keys listed in rename_refs are renamed from their old name to the associated new name when present.

def load_model( parameters: Dict[str, <built-in function any>], model_file: Optional[str] = None, rng_key: Optional[str] = None) -> fennol.models.fennix.FENNIX:
def load_model(
    parameters: Dict[str, any],
    model_file: Optional[str] = None,
    rng_key: Optional[str] = None,
) -> FENNIX:
    """
    Load a FENNIX model from a file or create a new one.

    The model source is resolved in this order:

    1. ``model_file``, or ``parameters["model_file"]`` when ``model_file`` is
       None. If that path exists, the model is restored from it.
    2. ``parameters["model"]``. A string is treated as the path of a saved
       model, which must exist. Otherwise it is unpacked as keyword arguments
       for the ``FENNIX`` constructor, together with ``rng_key``.

    Models are always created or loaded with ``use_atom_padding=False``.

    Args:
        parameters (dict): Configuration dictionary. Reads the optional
            ``parameters["training"]["print_model"]`` flag (print the model
            summary), the optional ``"model_file"`` entry and, when no model file
            is found, the ``"model"`` entry.
        model_file (str, optional): The path to a saved model file to load.
        rng_key (optional): Random key forwarded to the FENNIX constructor. It is
            only required when a new model is built from ``parameters["model"]``.

    Returns:
        FENNIX: A FENNIX model object.

    Raises:
        AssertionError: if a new model must be constructed and ``rng_key`` is
            None, or if ``parameters["model"]`` is a path that does not exist.
        KeyError: if no model file is found and ``parameters`` has no "model" entry.
    """
    # The "training" section is optional here: it only provides the print flag.
    print_model = (parameters.get("training") or {}).get("print_model", False)
    if model_file is None:
        model_file = parameters.get("model_file", None)

    if model_file is not None and os.path.exists(model_file):
        model = FENNIX.load(model_file, use_atom_padding=False)
        if print_model:
            print(model.summarize())
        print(f"Restored model from '{model_file}'.")
        return model

    model_params = parameters["model"]
    if isinstance(model_params, str):
        assert os.path.exists(
            model_params
        ), f"Model file '{model_params}' not found."
        model = FENNIX.load(model_params, use_atom_padding=False)
        print(f"Restored model from '{model_params}'.")
    else:
        # rng_key is only forwarded to the constructor, so it is required here
        # and not when restoring from a file.
        assert (
            rng_key is not None
        ), "rng_key must be specified if model_file is not provided."
        model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
    if print_model:
        print(model.summarize())
    return model

Load a FENNIX model from a file or create a new one.

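Args: parameters (dict): A dictionary of parameters for the model. model_file (str, optional): The path to a saved model file to load. rng_key (optional): Random key passed to the FENNIX constructor when a new model is created from the parameters.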

Returns: FENNIX: A FENNIX model object.

def copy_parameters(variables, variables_ref, params=['.*']):
def copy_parameters(variables, variables_ref, params=(".*",)):
    """
    Return a copy of ``variables`` in which selected leaves are taken from ``variables_ref``.

    Both nested dicts are flattened with ``flax.traverse_util.flatten_dict``
    (empty nodes are not kept). A leaf of ``variables`` is replaced by the
    leaf of ``variables_ref`` when both conditions hold:

    * the same key path exists in ``variables_ref``;
    * the path, without its first component and joined with "/", matches
      (``re.match``, i.e. anchored at the start, case-insensitive) at least
      one pattern of ``params``.

    All other leaves are kept unchanged.

    Args:
        variables: Nested dict of variables to update.
        variables_ref: Nested dict providing the reference values.
        params: A regex pattern or a sequence of regex patterns. The default
            matches every path.

    Returns:
        A nested dict with the same structure as ``variables``.
    """
    if isinstance(params, str):
        params = [params]
    # Case-insensitive matching via a flag rather than by lowercasing the
    # pattern: lowercasing would turn escapes such as \W, \D or \S into their
    # opposites (\w, \d, \s).
    patterns = [re.compile(path, re.IGNORECASE) for path in params]

    def merge_params(full_path_, v, v_ref):
        # The first path component is skipped (presumably the collection name,
        # e.g. "params").
        full_path = "/".join(full_path_[1:])
        if any(pattern.match(full_path) for pattern in patterns):
            return v_ref
        return v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )
class TeeLogger:
class TeeLogger(object):
    """
    File-like object that copies everything written to it into a log file and,
    while bound, into the stream that was ``sys.stdout`` at bind time.

    Typical use::

        tee = TeeLogger("output.log")
        tee.bind_stdout()    # print() now goes to the previous stdout and the file
        ...
        tee.unbind_stdout()  # restore the previous sys.stdout
        tee.close()
    """

    def __init__(self, name):
        # Binary file wrapped in a text layer; write_through=True forwards every
        # text write directly to the underlying binary buffer.
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
        # Stream replaced by bind_stdout(); None while the instance is not bound.
        self.stdout = None

    def __del__(self):
        # __init__ may have failed before self.file was set (e.g. the log file
        # could not be opened); do not raise again from the finalizer.
        if hasattr(self, "file"):
            self.close()

    def write(self, data):
        """Write ``data`` to the log file and, when bound, to the saved stdout."""
        self.file.write(data)
        if self.stdout is not None:
            self.stdout.write(data)
        # Only the log file is flushed on each write; the saved stream keeps
        # its own buffering and is flushed through flush().
        self.file.flush()

    def close(self):
        """
        Close the log file.

        If this instance is still installed as ``sys.stdout``, the saved stream
        is restored first so later ``print`` calls do not hit a closed file.
        Calling ``close()`` repeatedly is harmless.
        """
        if getattr(self, "stdout", None) is not None and sys.stdout is self:
            self.unbind_stdout()
        self.file.close()

    def flush(self):
        """Flush the log file and, when bound, the saved stdout stream."""
        self.file.flush()
        if self.stdout is not None:
            self.stdout.flush()

    def bind_stdout(self):
        """
        Install this instance as ``sys.stdout`` and remember the current stream.

        Raises:
            ValueError: if ``sys.stdout`` is already a TeeLogger or this
                instance already holds a saved stream.
        """
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        """
        Restore the stream saved by ``bind_stdout()`` as ``sys.stdout``.

        The saved reference is cleared, returning the instance to its unbound
        state so that ``bind_stdout()`` can be called again.

        Raises:
            ValueError: if the instance is not bound.
        """
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout
        self.stdout = None
TeeLogger(name)
621    def __init__(self, name):
622        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
623        self.stdout = None
file
stdout
def write(self, data):
628    def write(self, data):
629        self.file.write(data)
630        self.stdout.write(data)
631        self.flush()
def close(self):
633    def close(self):
634        self.file.close()
def flush(self):
636    def flush(self):
637        self.file.flush()
def bind_stdout(self):
639    def bind_stdout(self):
640        if isinstance(sys.stdout, TeeLogger):
641            raise ValueError("stdout already bound to a Tee instance.")
642        if self.stdout is not None:
643            raise ValueError("stdout already bound.")
644        self.stdout = sys.stdout
645        sys.stdout = self
def unbind_stdout(self):
647    def unbind_stdout(self):
648        if self.stdout is None:
649            raise ValueError("stdout is not bound.")
650        sys.stdout = self.stdout