fennol.models.preprocessing

   1import flax.linen as nn
   2from typing import Sequence, Callable, Union, Dict, Any, ClassVar
   3import jax.numpy as jnp
   4import jax
   5import numpy as np
   6from typing import Optional, Tuple
   7import numba
   8import dataclasses
   9from functools import partial
  10
  11from flax.core.frozen_dict import FrozenDict
  12
  13
  14from ..utils.activations import chain,safe_sqrt
  15from ..utils import deep_update, mask_filter_1d
  16from ..utils.kspace import get_reciprocal_space_parameters
  17from .misc.misc import SwitchFunction
  18from ..utils.periodic_table import PERIODIC_TABLE, PERIODIC_TABLE_REV_IDX,CHEMICAL_BLOCKS,CHEMICAL_BLOCKS_NAMES
  19
  20
  21@dataclasses.dataclass(frozen=True)
  22class GraphGenerator:
  23    """Generate a graph from a set of coordinates
  24
  25    FPID: GRAPH
  26
  27    For now, we generate all pairs of atoms and filter based on cutoff.
  28    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
  29    """
  30
  31    cutoff: float
  32    """Cutoff distance for the graph."""
  33    graph_key: str = "graph"
  34    """Key of the graph in the outputs."""
  35    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
  36    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
  37    kmax: int = 30
  38    """Maximum number of k-points to consider."""
  39    kthr: float = 1e-6
  40    """Threshold for k-point filtering."""
  41    k_space: bool = False
  42    """Whether to generate k-space information for the graph."""
  43    mult_size: float = 1.05
  44    """Multiplicative factor for resizing the nblist."""
  45    # covalent_cutoff: bool = False
  46
  47    FPID: ClassVar[str] = "GRAPH"
  48
  49    def init(self):
  50        return FrozenDict(
  51            {
  52                "max_nat": 1,
  53                "npairs": 1,
  54                "nblist_mult_size": self.mult_size,
  55            }
  56        )
  57
  58    def get_processor(self) -> Tuple[nn.Module, Dict]:
  59        return GraphProcessor, {
  60            "cutoff": self.cutoff,
  61            "graph_key": self.graph_key,
  62            "switch_params": self.switch_params,
  63            "name": f"{self.graph_key}_Processor",
  64        }
  65
  66    def get_graph_properties(self):
  67        return {
  68            self.graph_key: {
  69                "cutoff": self.cutoff,
  70                "directed": True,
  71            }
  72        }
  73
  74    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
  75        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
  76        if self.graph_key in inputs:
  77            graph = inputs[self.graph_key]
  78            if "keep_graph" in graph:
  79                return state, inputs
  80
  81        coords = np.array(inputs["coordinates"], dtype=np.float32)
  82        natoms = np.array(inputs["natoms"], dtype=np.int32)
  83        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
  84
  85        new_state = {**state}
  86        state_up = {}
  87
  88        mult_size = state.get("nblist_mult_size", self.mult_size)
  89        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
  90
  91        if natoms.shape[0] == 1:
  92            max_nat = coords.shape[0]
  93            true_max_nat = max_nat
  94        else:
  95            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
  96            true_max_nat = int(np.max(natoms))
  97            if true_max_nat > max_nat:
  98                add_atoms = state.get("add_atoms", 0)
  99                new_maxnat = true_max_nat + add_atoms
 100                state_up["max_nat"] = (new_maxnat, max_nat)
 101                new_state["max_nat"] = new_maxnat
 102
 103        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
 104
 105        ### compute indices of all pairs
 106        apply_pbc = "cells" in inputs
 107        minimage = "minimum_image" in inputs.get("flags", {})
 108        include_self_image = apply_pbc and not minimage
 109
 110        shift = 0 if include_self_image else 1
 111        p1, p2 = np.triu_indices(true_max_nat, shift)
 112        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
 113        pbc_shifts = None
 114        if natoms.shape[0] > 1:
 115            ## batching => mask irrelevant pairs
 116            mask_p12 = (
 117                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
 118            ).flatten()
 119            shift = np.concatenate(
 120                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
 121            )
 122            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
 123            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
 124
 125        if not apply_pbc:
 126            ### NO PBC
 127            vec = coords[p2] - coords[p1]
 128        else:
 129            cells = np.array(inputs["cells"], dtype=np.float32)
 130            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
 131            if minimage:
 132                ## MINIMUM IMAGE CONVENTION
 133                vec = coords[p2] - coords[p1]
 134                if cells.shape[0] == 1:
 135                    vecpbc = np.dot(vec, reciprocal_cells[0])
 136                    pbc_shifts = -np.round(vecpbc).astype(np.int32)
 137                    vec = vec + np.dot(pbc_shifts, cells[0])
 138                else:
 139                    batch_index_vec = batch_index[p1]
 140                    vecpbc = np.einsum(
 141                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
 142                    )
 143                    pbc_shifts = -np.round(vecpbc).astype(np.int32)
 144                    vec = vec + np.einsum(
 145                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
 146                    )
 147            else:
 148                ### GENERAL PBC
 149                ## put all atoms in central box
 150                if cells.shape[0] == 1:
 151                    coords_pbc = np.dot(coords, reciprocal_cells[0])
 152                    at_shifts = -np.floor(coords_pbc).astype(np.int32)
 153                    coords_pbc = coords + np.dot(at_shifts, cells[0])
 154                else:
 155                    coords_pbc = np.einsum(
 156                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
 157                    )
 158                    at_shifts = -np.floor(coords_pbc).astype(np.int32)
 159                    coords_pbc = coords + np.einsum(
 160                        "aj,aji->ai", at_shifts, cells[batch_index]
 161                    )
 162                vec = coords_pbc[p2] - coords_pbc[p1]
 163
 164                ## compute maximum number of repeats
 165                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
 166                cdinv = cutoff_skin * inv_distances
 167                num_repeats_all = np.ceil(cdinv).astype(np.int32)
 168                if "true_sys" in inputs:
 169                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
 170                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
 171                num_repeats = np.max(num_repeats_all, axis=0)
 172                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
 173                if np.any(num_repeats > num_repeats_prev):
 174                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
 175                    state_up["num_repeats_pbc"] = (
 176                        tuple(num_repeats_new),
 177                        tuple(num_repeats_prev),
 178                    )
 179                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
 180                ## build all possible shifts
 181                cell_shift_pbc = np.array(
 182                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
 183                    dtype=np.int32,
 184                ).T.reshape(-1, 3)
 185                ## shift applied to vectors
 186                if cells.shape[0] == 1:
 187                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
 188                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
 189                    pbc_shifts = np.broadcast_to(
 190                        cell_shift_pbc[None, :, :],
 191                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
 192                    ).reshape(-1, 3)
 193                    p1 = np.broadcast_to(
 194                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
 195                    ).flatten()
 196                    p2 = np.broadcast_to(
 197                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
 198                    ).flatten()
 199                    if natoms.shape[0] > 1:
 200                        mask_p12 = np.broadcast_to(
 201                            mask_p12[:, None],
 202                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
 203                        ).flatten()
 204                else:
 205                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
 206
 207                    ## get pbc shifts specific to each box
 208                    cell_shift_pbc = np.broadcast_to(
 209                        cell_shift_pbc[None, :, :],
 210                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
 211                    )
 212                    mask = np.all(
 213                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
 214                    ).flatten()
 215                    idx = np.nonzero(mask)[0]
 216                    nshifts = idx.shape[0]
 217                    nshifts_prev = state.get("nshifts_pbc", 0)
 218                    if nshifts > nshifts_prev or add_margin:
 219                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
 220                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
 221                        new_state["nshifts_pbc"] = nshifts_new
 222
 223                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
 224                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
 225
 226                    ## get batch shift in the dvec_filter array
 227                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
 228                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
 229
 230                    ## compute vectors
 231                    batch_index_vec = batch_index[p1]
 232                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
 233                    vec = vec.repeat(nrep_vec, axis=0)
 234                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
 235                    nvec_pbc_prev = state.get("nvec_pbc", 0)
 236                    if nvec_pbc > nvec_pbc_prev or add_margin:
 237                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
 238                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
 239                        new_state["nvec_pbc"] = nvec_pbc_new
 240
 241                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
 242                    ## get shift index
 243                    dshift = np.concatenate(
 244                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
 245                    ).repeat(nrep_vec)
 246                    # ishift = np.arange(dshift.shape[0])-dshift
 247                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
 248                    icellshift = (
 249                        np.arange(dshift.shape[0])
 250                        - dshift
 251                        + bshift[batch_index_vec].repeat(nrep_vec)
 252                    )
 253                    # shift vectors
 254                    vec = vec + dvec_filter[icellshift]
 255                    pbc_shifts = cell_shift_pbc_filter[icellshift]
 256
 257                    p1 = np.repeat(p1, nrep_vec)
 258                    p2 = np.repeat(p2, nrep_vec)
 259                    if natoms.shape[0] > 1:
 260                        mask_p12 = np.repeat(mask_p12, nrep_vec)
 261
 262        ## compute distances
 263        d12 = (vec**2).sum(axis=-1)
 264        if natoms.shape[0] > 1:
 265            d12 = np.where(mask_p12, d12, cutoff_skin**2)
 266
 267        ## filter pairs
 268        max_pairs = state.get("npairs", 1)
 269        mask = d12 < cutoff_skin**2
 270        if include_self_image:
 271            mask_self = np.logical_or(p1 != p2, d12 > 1.e-3)
 272            mask = np.logical_and(mask, mask_self)
 273        idx = np.nonzero(mask)[0]
 274        npairs = idx.shape[0]
 275        if npairs > max_pairs or add_margin:
 276            prev_max_pairs = max_pairs
 277            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 278            state_up["npairs"] = (max_pairs, prev_max_pairs)
 279            new_state["npairs"] = max_pairs
 280
 281        nat = coords.shape[0]
 282        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 283        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 284        d12_ = np.full(max_pairs, cutoff_skin**2)
 285        edge_src[:npairs] = p1[idx]
 286        edge_dst[:npairs] = p2[idx]
 287        d12_[:npairs] = d12[idx]
 288        d12 = d12_
 289
 290        if apply_pbc:
 291            pbc_shifts_ = np.zeros((max_pairs, 3), dtype=np.int32)
 292            pbc_shifts_[:npairs] = pbc_shifts[idx]
 293            pbc_shifts = pbc_shifts_
 294            if not minimage:
 295                pbc_shifts[:npairs] = (
 296                    pbc_shifts[:npairs]
 297                    + at_shifts[edge_dst[:npairs]]
 298                    - at_shifts[edge_src[:npairs]]
 299                )
 300
 301
 302        ## symmetrize
 303        if include_self_image:
 304            mask_noself = edge_src != edge_dst
 305            idx_noself = np.nonzero(mask_noself)[0]
 306            npairs_noself = idx_noself.shape[0]
 307            max_noself = state.get("npairs_noself", 1)
 308            if npairs_noself > max_noself or add_margin:
 309                prev_max_noself = max_noself
 310                max_noself = int(mult_size * max(npairs_noself, max_noself)) + 1
 311                state_up["npairs_noself"] = (max_noself, prev_max_noself)
 312                new_state["npairs_noself"] = max_noself
 313
 314            edge_src_noself = np.full(max_noself, nat, dtype=np.int32)
 315            edge_dst_noself = np.full(max_noself, nat, dtype=np.int32)
 316            d12_noself = np.full(max_noself, cutoff_skin**2)
 317            pbc_shifts_noself = np.zeros((max_noself, 3), dtype=np.int32)
 318
 319            edge_dst_noself[:npairs_noself] = edge_dst[idx_noself]
 320            edge_src_noself[:npairs_noself] = edge_src[idx_noself]
 321            d12_noself[:npairs_noself] = d12[idx_noself]
 322            pbc_shifts_noself[:npairs_noself] = pbc_shifts[idx_noself]
 323            edge_src = np.concatenate((edge_src, edge_dst_noself))
 324            edge_dst = np.concatenate((edge_dst, edge_src_noself))
 325            d12 = np.concatenate((d12, d12_noself))
 326            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts_noself))
 327        else:
 328            edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
 329                (edge_dst, edge_src)
 330            )
 331            d12 = np.concatenate((d12, d12))
 332            if apply_pbc:
 333                pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
 334
 335        if "nblist_skin" in state:
 336            edge_src_skin = edge_src
 337            edge_dst_skin = edge_dst
 338            if apply_pbc:
 339                pbc_shifts_skin = pbc_shifts
 340            max_pairs_skin = state.get("npairs_skin", 1)
 341            mask = d12 < self.cutoff**2
 342            idx = np.nonzero(mask)[0]
 343            npairs_skin = idx.shape[0]
 344            if npairs_skin > max_pairs_skin or add_margin:
 345                prev_max_pairs_skin = max_pairs_skin
 346                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
 347                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
 348                new_state["npairs_skin"] = max_pairs_skin
 349            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
 350            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
 351            d12_ = np.full(max_pairs_skin, self.cutoff**2)
 352            edge_src[:npairs_skin] = edge_src_skin[idx]
 353            edge_dst[:npairs_skin] = edge_dst_skin[idx]
 354            d12_[:npairs_skin] = d12[idx]
 355            d12 = d12_
 356            if apply_pbc:
 357                pbc_shifts = np.zeros((max_pairs_skin, 3), dtype=np.int32)
 358                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
 359
 360        graph = inputs.get(self.graph_key, {})
 361        graph_out = {
 362            **graph,
 363            "edge_src": edge_src,
 364            "edge_dst": edge_dst,
 365            "d12": d12,
 366            "overflow": False,
 367            "pbc_shifts": pbc_shifts,
 368        }
 369        if "nblist_skin" in state:
 370            graph_out["edge_src_skin"] = edge_src_skin
 371            graph_out["edge_dst_skin"] = edge_dst_skin
 372            if apply_pbc:
 373                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
 374
 375        if self.k_space and apply_pbc:
 376            if "k_points" not in graph:
 377                ks, _, _, bewald = get_reciprocal_space_parameters(
 378                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
 379                )
 380            graph_out["k_points"] = ks
 381            graph_out["b_ewald"] = bewald
 382
 383        output = {**inputs, self.graph_key: graph_out}
 384
 385        if return_state_update:
 386            return FrozenDict(new_state), output, state_up
 387        return FrozenDict(new_state), output
 388
 389    def check_reallocate(self, state, inputs, parent_overflow=False):
 390        """check for overflow and reallocate nblist if necessary"""
 391        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 392        if not overflow:
 393            return state, {}, inputs, False
 394
 395        add_margin = inputs[self.graph_key].get("overflow", False)
 396        state, inputs, state_up = self(
 397            state, inputs, return_state_update=True, add_margin=add_margin
 398        )
 399        return state, state_up, inputs, True
 400
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build a nblist on accelerator with jax and precomputed shapes.

        Jitted twin of `__call__`: all buffer sizes are taken from `state`
        (a static argument, so a new state triggers recompilation). Instead
        of reallocating on overflow, it sets an `overflow` flag in the output
        graph, which `check_reallocate` uses to fall back to the CPU routine.
        """
        # reuse a precomputed graph if the caller asked to keep it
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        apply_pbc = "cells" in inputs
        minimage = "minimum_image" in inputs.get("flags", {})
        include_self_image = apply_pbc and not minimage

        # with self-images (general PBC), keep the diagonal i==i pairs so an
        # atom can interact with its own periodic images
        shift = 0 if include_self_image else 1
        # numpy (not jnp) here: max_nat is static, so indices are trace-time constants
        p1, p2 = np.triu_indices(max_nat, shift)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            # minimage = state.get("minimum_image", True)
            minimage = "minimum_image" in inputs.get("flags", {})

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # map to fractional coordinates, wrap ("round" for minimum image,
                # "floor" to put atoms in the central box), and map back
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc).astype(jnp.int32)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc).astype(jnp.int32)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of repeats from the state (set by the CPU routine)
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    # padding systems get zero repeats
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                # flag overflow if the actual repeats exceed the static allocation
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # all candidate cell shifts for the static num_repeats
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=np.int32,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    # componentwise masked compaction into fixed-size (max_shifts) buffers
                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1, dtype=jnp.int32)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    # total_repeat_length keeps a static output shape for jit
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    # degenerate case: no pairs at all (empty trace-time size)
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)

                    # invalidate the padding introduced by total_repeat_length
                    mask_valid = jnp.arange(nvec_max) < nvec
                    mask_p12 = jnp.where(mask_valid, mask_p12, False)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs get d12 == cutoff_skin**2 so they fail the filter below
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        if include_self_image:
            # drop only the true self-pair (i==i in the central cell, d12 ~ 0)
            mask_self = jnp.logical_or(p1 != p2, d12 > 1.e-3)
            mask = jnp.logical_and(mask, mask_self)
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # express shifts relative to the original (unwrapped) coordinates
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0)
                    - at_shifts.at[edge_src].get(fill_value=0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        ## symmetrize
        if include_self_image:
            # self-image pairs (i -> own image) must not be duplicated in reverse
            mask_noself = edge_src != edge_dst
            max_noself = state.get("npairs_noself", 1)
            (edge_src_noself, edge_dst_noself, d12_noself, pbc_shifts_noself), scatter_idx, npairs_noself = mask_filter_1d(
                mask_noself,
                max_noself,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, cutoff_skin**2),
                (pbc_shifts, jnp.zeros((3,), dtype=pbc_shifts.dtype)),
            )
            overflow = overflow | (npairs_noself > max_noself)
            edge_src = jnp.concatenate((edge_src, edge_dst_noself))
            edge_dst = jnp.concatenate((edge_dst, edge_src_noself))
            d12 = jnp.concatenate((d12, d12_noself))
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts_noself))
        else:
            edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
                (edge_dst, edge_src)
            )
            d12 = jnp.concatenate((d12, d12))
            if "cells" in inputs:
                pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if "nblist_skin" in state:
            # keep the skin (cutoff+skin) nblist and extract the strict-cutoff one
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.zeros((max_pairs_skin, 3), dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}
 698
 699    @partial(jax.jit, static_argnums=(0,))
 700    def update_skin(self, inputs):
 701        """update the nblist without recomputing the full nblist"""
 702        graph = inputs[self.graph_key]
 703
 704        edge_src_skin = graph["edge_src_skin"]
 705        edge_dst_skin = graph["edge_dst_skin"]
 706        coords = inputs["coordinates"]
 707        vec = coords.at[edge_dst_skin].get(
 708            mode="fill", fill_value=self.cutoff
 709        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
 710
 711        if "cells" in inputs:
 712            pbc_shifts_skin = graph["pbc_shifts_skin"]
 713            cells = inputs["cells"]
 714            if cells.shape[0] == 1:
 715                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
 716            else:
 717                batch_index_vec = inputs["batch_index"][edge_src_skin]
 718                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
 719
 720        nat = coords.shape[0]
 721        d12 = jnp.sum(vec**2, axis=-1)
 722        mask = d12 < self.cutoff**2
 723        max_pairs = graph["edge_src"].shape[0]
 724        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
 725            mask,
 726            max_pairs,
 727            (edge_src_skin, nat),
 728            (edge_dst_skin, nat),
 729            (d12, self.cutoff**2),
 730        )
 731        if "cells" in inputs:
 732            pbc_shifts = (
 733                jnp.zeros((max_pairs, 3), dtype=pbc_shifts_skin.dtype)
 734                .at[scatter_idx]
 735                .set(pbc_shifts_skin)
 736            )
 737
 738        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 739        graph_out = {
 740            **graph,
 741            "edge_src": edge_src,
 742            "edge_dst": edge_dst,
 743            "d12": d12,
 744            "overflow": overflow,
 745        }
 746        if "cells" in inputs:
 747            graph_out["pbc_shifts"] = pbc_shifts
 748
 749        if self.k_space and "cells" in inputs:
 750            if "k_points" not in graph:
 751                raise NotImplementedError(
 752                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 753                )
 754
 755        return {**inputs, self.graph_key: graph_out}
 756
 757
 758class GraphProcessor(nn.Module):
 759    """Process a pre-generated graph
 760
 761    The pre-generated graph should contain the following keys:
 762    - edge_src: source indices of the edges
 763    - edge_dst: destination indices of the edges
 764    - pbcs_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)
 765
 766    This module is automatically added to a FENNIX model when a GraphGenerator is used.
 767
 768    """
 769
 770    cutoff: float
 771    """Cutoff distance for the graph."""
 772    graph_key: str = "graph"
 773    """Key of the graph in the outputs."""
 774    switch_params: dict = dataclasses.field(default_factory=dict)
 775    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 776
 777    @nn.compact
 778    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 779        graph = inputs[self.graph_key]
 780        coords = inputs["coordinates"]
 781        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 782        # edge_mask = edge_src < coords.shape[0]
 783        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
 784            edge_src
 785        ].get(mode="fill", fill_value=0.0)
 786        if "cells" in inputs:
 787            cells = inputs["cells"]
 788            if cells.shape[0] == 1:
 789                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
 790            else:
 791                batch_index_vec = inputs["batch_index"][edge_src]
 792                vec = vec + jax.vmap(jnp.dot)(
 793                    graph["pbc_shifts"], cells[batch_index_vec]
 794                )
 795
 796        d2  = jnp.sum(vec**2, axis=-1)
 797        distances = safe_sqrt(d2)
 798        edge_mask = distances < self.cutoff
 799
 800        switch = SwitchFunction(
 801            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
 802        )((distances, edge_mask))
 803
 804        graph_out = {
 805            **graph,
 806            "vec": vec,
 807            "distances": distances,
 808            "switch": switch,
 809            "edge_mask": edge_mask,
 810        }
 811
 812        if "alch_group" in inputs:
 813            alch_group = inputs["alch_group"]
 814            lambda_e = inputs["alch_elambda"]
 815            lambda_v = inputs["alch_vlambda"]
 816            mask = alch_group[edge_src] == alch_group[edge_dst]
 817            graph_out["switch_raw"] = switch
 818            graph_out["switch"] = jnp.where(
 819                mask,
 820                switch,
 821                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
 822            )
 823            graph_out["distances_raw"] = distances
 824            if "alch_softcore_e" in inputs:
 825                alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2
 826            else:
 827                alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2
 828
 829            graph_out["distances"] = jnp.where(
 830                mask,
 831                distances,
 832                safe_sqrt(alch_alpha + d2 * (1. - alch_alpha/self.cutoff**2))
 833            )  
 834
 835
 836        return {**inputs, self.graph_key: graph_out}
 837
 838
 839@dataclasses.dataclass(frozen=True)
 840class GraphFilter:
 841    """Filter a graph based on a cutoff distance
 842
 843    FPID: GRAPH_FILTER
 844    """
 845
 846    cutoff: float
 847    """Cutoff distance for the filtering."""
 848    parent_graph: str
 849    """Key of the parent graph in the inputs."""
 850    graph_key: str
 851    """Key of the filtered graph in the outputs."""
 852    remove_hydrogens: int = False
 853    """Remove edges where the source is a hydrogen atom."""
 854    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
 855    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 856    k_space: bool = False
 857    """Generate k-space information for the graph."""
 858    kmax: int = 30
 859    """Maximum number of k-points to consider."""
 860    kthr: float = 1e-6
 861    """Threshold for k-point filtering."""
 862    mult_size: float = 1.05
 863    """Multiplicative factor for resizing the nblist."""
 864
 865    FPID: ClassVar[str] = "GRAPH_FILTER"
 866
 867    def init(self):
 868        return FrozenDict(
 869            {
 870                "npairs": 1,
 871                "nblist_mult_size": self.mult_size,
 872            }
 873        )
 874
 875    def get_processor(self) -> Tuple[nn.Module, Dict]:
 876        return GraphFilterProcessor, {
 877            "cutoff": self.cutoff,
 878            "graph_key": self.graph_key,
 879            "parent_graph": self.parent_graph,
 880            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
 881            "switch_params": self.switch_params,
 882        }
 883
 884    def get_graph_properties(self):
 885        return {
 886            self.graph_key: {
 887                "cutoff": self.cutoff,
 888                "directed": True,
 889                "parent_graph": self.parent_graph,
 890            }
 891        }
 892
 893    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 894        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 895        graph_in = inputs[self.parent_graph]
 896        nat = inputs["species"].shape[0]
 897
 898        new_state = {**state}
 899        state_up = {}
 900        mult_size = state.get("nblist_mult_size", self.mult_size)
 901        assert mult_size >= 1., "nblist_mult_size should be >= 1."
 902
 903        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
 904        d12 = np.array(graph_in["d12"], dtype=np.float32)
 905        if self.remove_hydrogens:
 906            species = inputs["species"]
 907            src_idx = (edge_src < nat).nonzero()[0]
 908            mask = np.zeros(edge_src.shape[0], dtype=bool)
 909            mask[src_idx] = (species > 1)[edge_src[src_idx]]
 910            d12 = np.where(mask, d12, self.cutoff**2)
 911        mask = d12 < self.cutoff**2
 912
 913        max_pairs = state.get("npairs", 1)
 914        idx = np.nonzero(mask)[0]
 915        npairs = idx.shape[0]
 916        if npairs > max_pairs or add_margin:
 917            prev_max_pairs = max_pairs
 918            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 919            state_up["npairs"] = (max_pairs, prev_max_pairs)
 920            new_state["npairs"] = max_pairs
 921
 922        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
 923        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 924        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 925        d12_ = np.full(max_pairs, self.cutoff**2)
 926        filter_indices[:npairs] = idx
 927        edge_src[:npairs] = graph_in["edge_src"][idx]
 928        edge_dst[:npairs] = graph_in["edge_dst"][idx]
 929        d12_[:npairs] = d12[idx]
 930        d12 = d12_
 931
 932        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 933        graph_out = {
 934            **graph,
 935            "edge_src": edge_src,
 936            "edge_dst": edge_dst,
 937            "filter_indices": filter_indices,
 938            "d12": d12,
 939            "overflow": False,
 940        }
 941        if "cells" in inputs:
 942            pbc_shifts = np.zeros((max_pairs, 3), dtype=np.int32)
 943            pbc_shifts[:npairs] = graph_in["pbc_shifts"][idx]
 944            graph_out["pbc_shifts"] = pbc_shifts
 945
 946            if self.k_space:
 947                if "k_points" not in graph:
 948                    ks, _, _, bewald = get_reciprocal_space_parameters(
 949                        inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
 950                    )
 951                graph_out["k_points"] = ks
 952                graph_out["b_ewald"] = bewald
 953
 954        output = {**inputs, self.graph_key: graph_out}
 955        if return_state_update:
 956            return FrozenDict(new_state), output, state_up
 957        return FrozenDict(new_state), output
 958
 959    def check_reallocate(self, state, inputs, parent_overflow=False):
 960        """check for overflow and reallocate nblist if necessary"""
 961        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 962        if not overflow:
 963            return state, {}, inputs, False
 964
 965        add_margin = inputs[self.graph_key].get("overflow", False)
 966        state, inputs, state_up = self(
 967            state, inputs, return_state_update=True, add_margin=add_margin
 968        )
 969        return state, state_up, inputs, True
 970
 971    @partial(jax.jit, static_argnums=(0, 1))
 972    def process(self, state, inputs):
 973        """filter a nblist on accelerator with jax and precomputed shapes"""
 974        graph_in = inputs[self.parent_graph]
 975        if state is None:
 976            # skin update mode
 977            graph = inputs[self.graph_key]
 978            max_pairs = graph["edge_src"].shape[0]
 979        else:
 980            max_pairs = state.get("npairs", 1)
 981
 982        max_pairs_in = graph_in["edge_src"].shape[0]
 983        nat = inputs["species"].shape[0]
 984
 985        edge_src = graph_in["edge_src"]
 986        d12 = graph_in["d12"]
 987        if self.remove_hydrogens:
 988            species = inputs["species"]
 989            mask = (species > 1)[edge_src]
 990            d12 = jnp.where(mask, d12, self.cutoff**2)
 991        mask = d12 < self.cutoff**2
 992
 993        (edge_src, edge_dst, d12, filter_indices), scatter_idx, npairs = mask_filter_1d(
 994            mask,
 995            max_pairs,
 996            (edge_src, nat),
 997            (graph_in["edge_dst"], nat),
 998            (d12, self.cutoff**2),
 999            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
1000        )
1001
1002        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
1003        overflow = graph.get("overflow", False) | (npairs > max_pairs)
1004        graph_out = {
1005            **graph,
1006            "edge_src": edge_src,
1007            "edge_dst": edge_dst,
1008            "filter_indices": filter_indices,
1009            "d12": d12,
1010            "overflow": overflow,
1011        }
1012
1013        if "cells" in inputs:
1014            pbc_shifts = graph_in["pbc_shifts"]
1015            pbc_shifts = (
1016                jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype)
1017                .at[scatter_idx].set(pbc_shifts, mode="drop")
1018            )
1019            graph_out["pbc_shifts"] = pbc_shifts
1020            if self.k_space:
1021                if "k_points" not in graph:
1022                    raise NotImplementedError(
1023                        "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
1024                    )
1025
1026        return {**inputs, self.graph_key: graph_out}
1027
1028    @partial(jax.jit, static_argnums=(0,))
1029    def update_skin(self, inputs):
1030        return self.process(None, inputs)
1031
1032
class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        """Gather edge vectors/distances from the parent graph for the filtered
        edges and compute the switching function (with optional alchemical
        modifications)."""
        graph_in = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        # NOTE(review): the key is chosen by inspecting the *filtered* graph but
        # the values are read from the *parent* graph below — confirm this is
        # intended (it presumably relies on both graphs carrying "distances_raw"
        # under the same alchemical inputs).
        d_key = "distances_raw" if "distances_raw" in graph else "distances"

        if graph_in["vec"].shape[0] == 0:
            # parent graph has no edges: propagate the empty arrays as-is
            vec = graph_in["vec"]
            distances = graph_in[d_key]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            filter_indices = graph["filter_indices"]
            # gather parent edge data for the kept edges; out-of-range (padding)
            # indices read fill_value=cutoff so they fail the edge mask below
            vec = (
                graph_in["vec"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )
            distances = (
                graph_in[d_key]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            # alchemical decoupling: scale the switch and soften the distances
            # for edges whose endpoints belong to different alchemical groups
            edge_src=graph["edge_src"]
            edge_dst=graph["edge_dst"]
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            # True for intra-group edges, which keep their raw values
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
            )

            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2
            else:
                alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2

            # soft-core distance: finite at r=0, equals the raw distance at the cutoff
            graph_out["distances"] = jnp.where(
                mask,
                distances,
                safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha/self.cutoff**2))
            )

        return {**inputs, self.graph_key: graph_out}
1114
1115
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION

    For each central atom, builds the list of all pairs of its edges
    (triplets `angle_src`/`angle_dst` indexing into the edge list) so that
    angular terms can be computed downstream.
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        """Initial preprocessing state: allocated angle count and neighbor sizes."""
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the processor module class and its constructor kwargs."""
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        """Static properties advertised for the graph."""
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        # one extra bin collects padded edges (edge_src == nat); it is
        # excluded from the max below
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy array whose static shape carries max_neigh into the jitted
        # process() path (see "__max_neigh_array" below)
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        # (the tile is computed once, in whichever branch applies)
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        # keep only real (non-padding) edges when scattering into the dense table
        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        # dense (nat, max_count) table of edge indices, padded with nedge
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles (both edges of the triplet are real)
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation (padded with nedge)
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        # central atom of each triplet (padded with out-of-range nat)
        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # add margin only when this extension itself overflowed
        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        # padded edges (edge_src out of range) are dropped by mode="drop"
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes (static values, needed for fixed shapes under jit)
        if state is None:
            # skin update mode: recover the static sizes from the graph itself
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        # dense (nat, max_neigh) table of edge indices, padded with nedge;
        # out-of-bounds scatters from padded edges are dropped
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles (both edges of the triplet are real)
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Rebuild the angle list using the already-allocated sizes."""
        return self.process(None, inputs)
1351
1352
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        """Compute the angle between the two edges of each triplet in the graph."""
        graph = inputs[self.graph_key]
        # prefer raw distances when alchemical soft-core distances were applied
        distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # unit vectors along each edge; the clip avoids division by ~0 for
        # padded edges (renamed from `dir`, which shadowed the builtin)
        directions = vec / jnp.clip(distances[:, None], min=1.0e-5)
        # cosine of the angle between the two edges of each triplet; padding
        # triplets gather fill_value=0.5 in both factors
        cos_angles = (
            directions.at[angle_src].get(mode="fill", fill_value=0.5)
            * directions.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        # the 0.95 factor keeps the argument strictly inside (-1, 1) so arccos
        # remains differentiable at aligned/anti-aligned edges (note that it
        # also slightly compresses the angle range)
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }
1388
1389
1390@dataclasses.dataclass(frozen=True)
1391class SpeciesIndexer:
1392    """Build an index that splits atomic arrays by species.
1393
1394    FPID: SPECIES_INDEXER
1395
1396    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
1397    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
1398
1399    """
1400
1401    output_key: str = "species_index"
1402    """Key for the output dictionary."""
1403    species_order: Optional[str] = None
1404    """Comma separated list of species in the order they should be indexed."""
1405    add_atoms: int = 0
1406    """Additional atoms to add to the sizes."""
1407    add_atoms_margin: int = 10
1408    """Additional atoms to add to the sizes when adding margin."""
1409
1410    FPID: ClassVar[str] = "SPECIES_INDEXER"
1411
1412    def init(self):
1413        return FrozenDict(
1414            {
1415                "sizes": {},
1416            }
1417        )
1418
    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """CPU/numpy path: build the species index with dynamic shapes.

        Counts atoms per species and grows the per-species capacities stored in
        `state["sizes"]` when needed (adding `add_atoms` slack, plus
        `add_atoms_margin` extra when `add_margin` is set). The resulting index
        gathers the atom positions of each species; unused slots point to `nat`
        (one past the last atom), which selects padding after `AtomPadding`.

        Returns (new_state, outputs) or, with `return_state_update=True`,
        (new_state, outputs, state_update) where `state_update` maps changed
        state keys to (new_value, old_value) pairs.
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            # species <= 0 are padding atoms and are not indexed
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                # capacity for this species is too small -> grow it with slack
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            # dense (len(species_order), max_size) index whose rows follow
            # the user-provided species order
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size
                max_size_prev = max_size

            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict.keys():
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            # dict {element symbol: index array of capacity c}; unused slots = nat
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output
1480
1481    def check_reallocate(self, state, inputs, parent_overflow=False):
1482        """check for overflow and reallocate nblist if necessary"""
1483        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1484        if not overflow:
1485            return state, {}, inputs, False
1486
1487        add_margin = inputs[self.output_key + "_overflow"]
1488        state, inputs, state_up = self(
1489            state, inputs, return_state_update=True, add_margin=add_margin
1490        )
1491        return state, state_up, inputs, True
1492        # return state, {}, inputs, parent_overflow
1493
1494    @partial(jax.jit, static_argnums=(0, 1))
1495    def process(self, state, inputs):
1496        # assert (
1497        #     self.output_key in inputs
1498        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1499
1500        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1501        if self.output_key in inputs and not recompute_species_index:
1502            return inputs
1503
1504        if state is None:
1505            raise ValueError("Species Indexer state must be provided on accelerator.")
1506
1507        species = inputs["species"]
1508        nat = species.shape[0]
1509
1510        sizes = state["sizes"]
1511
1512        if self.species_order is not None:
1513            species_order = [el.strip() for el in self.species_order.split(",")]
1514            max_size = state["max_size"]
1515
1516            species_index = jnp.full(
1517                (len(species_order), max_size), nat, dtype=jnp.int32
1518            )
1519            for i, el in enumerate(species_order):
1520                s = PERIODIC_TABLE_REV_IDX[el]
1521                if s in sizes.keys():
1522                    c = sizes[s]
1523                    species_index = species_index.at[i, :].set(
1524                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
1525                    )
1526                # if s in counts_dict.keys():
1527                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1528        else:
1529            # species_index = {
1530            # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1531            # for s, c in sizes.items()
1532            # }
1533            species_index = {}
1534            overflow = False
1535            natcount = 0
1536            for s, c in sizes.items():
1537                mask = species == s
1538                new_size = jnp.sum(mask)
1539                natcount = natcount + new_size
1540                overflow = overflow | (new_size > c)  # check if sizes are correct
1541                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
1542                    species == s, size=c, fill_value=nat
1543                )[0]
1544
1545            mask = species <= 0
1546            new_size = jnp.sum(mask)
1547            natcount = natcount + new_size
1548            overflow = overflow | (
1549                natcount < species.shape[0]
1550            )  # check if any species missing
1551
1552        return {
1553            **inputs,
1554            self.output_key: species_index,
1555            self.output_key + "_overflow": overflow,
1556        }
1557
    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Jitted skin update: delegates to `process` without a state.

        The species index must already be present in `inputs`; `process`
        returns them unchanged in that case.
        """
        return self.process(None, inputs)
1561
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary keyed by chemical-block names; each value is an
    index array gathering the atoms of that block (unused slots are filled
    with `nat`, i.e. one past the last atom, so they select padding). Blocks
    never seen keep a value of None.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    # if True, C, N, O, P, S and Se are pulled out of their chemical blocks
    # and each gets its own dedicated single-element block
    split_CNOPSSe: bool = False

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        """Return the initial (frozen) state with an empty size table."""
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        """Return (block names, species -> block-index lookup table).

        When `split_CNOPSSe` is set, the extra per-element blocks are appended
        after the original names, so `len(CHEMICAL_BLOCKS_NAMES)` (the length
        of the *original*, unmodified list) is the index of the first appended
        block.
        """
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"])
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            # remap the split elements (atomic numbers 6, 7, 8, 15, 16, 34)
            # to block 1 (renamed "C") and the appended per-element blocks
            _CHEMICAL_BLOCKS[6] = 1
            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1
            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2
            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3
            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """CPU/numpy path: build the block index with dynamic shapes.

        Counts atoms per block and grows the per-block capacities stored in
        `state["sizes"]` when needed (adding `add_atoms` slack, plus
        `add_atoms_margin` extra when `add_margin` is set).
        """
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        # map atomic numbers to block indices.
        # NOTE(review): padding species (-1) wrap to the last entry of the
        # lookup table; this assumes that entry maps to a negative block so
        # padding is skipped below — confirm against CHEMICAL_BLOCKS in
        # ..utils.periodic_table
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            # negative block index -> padding / unassigned atoms, not indexed
            if s < 0:
                continue
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                # capacity for this block is too small -> grow it with slack
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        # one entry per block name; blocks with no allocated size stay None
        block_index = {n:None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_,n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        # block_index = {
            # n: np.full(c, nat, dtype=np.int32)
            # for (_,n), c in new_sizes.items()
        # }
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # add extra margin only when this layer itself overflowed
        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True
        # return state, {}, inputs, parent_overflow

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Accelerator path: rebuild the block index with fixed shapes.

        Uses the static capacities in `state["sizes"]` and sets
        `<output_key>_overflow` when they are too small (or when some atoms
        belong to no allocated block) so `check_reallocate` can rebuild on CPU.
        """
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
        # assert (
        #     self.output_key in inputs
        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        # species_index = {
        # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
        # for s, c in sizes.items()
        # }
        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s,name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(
                mask, size=c, fill_value=nat
            )[0]

        # atoms with a negative block (padding) also count toward the total
        mask = blocks < 0
        new_size = jnp.sum(mask)
        natcount = natcount + new_size
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Jitted skin update: delegates to `process` without a state.

        The block index must already be present in `inputs`; `process`
        returns them unchanged in that case.
        """
        return self.process(None, inputs)
1722
1723
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size.

    Padding atoms get species -1 and are assigned to an extra padding system
    appended at the end of the batch. Boolean masks `true_atoms` / `true_sys`
    mark the real entries so the padding can be stripped by `atom_unpadding`.
    """

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    # number of extra systems to add on top of the batch size when the
    # system capacity grows

    def init(self):
        """Initial state: no atom or system capacity allocated yet."""
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]:
        """Pad per-atom and per-system arrays up to the stored capacities.

        Atom capacity grows geometrically (`mult_size`) and system capacity
        linearly (`add_sys`); the updated capacities are kept in the returned
        (frozen) state so subsequent calls reuse the same padded shapes.
        Note: system arrays are only padded when atom padding is needed
        (i.e. inside the `add_atoms > 0` branch).
        """
        species = inputs["species"]
        nat = species.shape[0]

        prev_nat = state.get("prev_nat", 0)
        prev_nat_ = prev_nat
        if nat > prev_nat_:
            # grow atom capacity geometrically to amortize shape changes
            prev_nat_ = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        prev_nsys = state.get("prev_nsys", 0)
        prev_nsys_ = prev_nsys
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        # always at least one extra system so padded atoms have a system to
        # point to
        add_sys = prev_nsys_ - nsys  + 1
        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if isinstance(v, np.ndarray) or isinstance(v, jax.Array):
                    if v.shape[0] == nat:
                        # per-atom array: pad with zeros
                        output[k] = np.append(
                            v,
                            np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                            axis=0,
                        )
                    elif v.shape[0] == nsys:
                        if k == "cells":
                            # padding systems get a large dummy cell (not zeros,
                            # which would be singular)
                            output[k] = np.append(
                                v,
                                1000
                                * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                                    add_sys, axis=0
                                ),
                                axis=0,
                            )
                        else:
                            output[k] = np.append(
                                v,
                                np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
                                axis=0,
                            )
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            # padding atoms are marked with species -1
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            # all padding atoms belong to the last (padding) system
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype),
                )

        # masks identifying real (non-padding) atoms and systems
        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output
1800
1801
1802def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
1803    """Remove padding from atomic arrays."""
1804    if "true_atoms" not in inputs:
1805        return inputs
1806
1807    species = np.asarray(inputs["species"])
1808    true_atoms = np.asarray(inputs["true_atoms"])
1809    true_sys = np.asarray(inputs["true_sys"])
1810    natall = species.shape[0]
1811    nat = np.argmax(species <= 0)
1812    if nat == 0:
1813        return inputs
1814
1815    natoms = inputs["natoms"]
1816    nsysall = len(natoms)
1817
1818    output = {**inputs}
1819    for k, v in inputs.items():
1820        if isinstance(v, jax.Array) or isinstance(v, np.ndarray):
1821            v = np.asarray(v)
1822            if v.ndim == 0:
1823                output[k] = v
1824            elif v.shape[0] == natall:
1825                output[k] = v[true_atoms]
1826            elif v.shape[0] == nsysall:
1827                output[k] = v[true_sys]
1828    del output["true_sys"]
1829    del output["true_atoms"]
1830    return output
1831
1832
def check_input(inputs):
    """Check the input dictionary for required keys and normalize its content.

    Requires `species` and `coordinates`; fills in default `natoms` (single
    system) and `batch_index`. Cells and their inverses are promoted to
    batched (nsys, 3, 3) arrays. Padding atoms (species <= 0) must be
    contiguous at the end of the arrays, as produced by `AtomPadding` and
    assumed by `atom_unpadding`.
    """
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"
    species = inputs["species"].astype(np.int32)
    ifake = np.argmax(species <= 0)
    if ifake > 0:
        # BUGFIX: the previous check `np.all(species[:ifake] > 0)` was vacuous
        # (argmax returns the FIRST index where species <= 0, so everything
        # before it is positive by construction). The intended contract is
        # that no real atom appears after the first padding atom.
        assert np.all(
            species[ifake:] <= 0
        ), "padding atoms (species<=0) must come after all real atoms"
    nat = inputs["species"].shape[0]

    natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32)
    batch_index = inputs.get(
        "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    ).astype(np.int32)
    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
    if "cells" in inputs:
        cells = inputs["cells"]
        if "reciprocal_cells" not in inputs:
            # reciprocal cell = matrix inverse of the (possibly batched) cell
            reciprocal_cells = np.linalg.inv(cells)
        else:
            reciprocal_cells = inputs["reciprocal_cells"]
        if cells.ndim == 2:
            cells = cells[None, :, :]
        if reciprocal_cells.ndim == 2:
            reciprocal_cells = reciprocal_cells[None, :, :]
        output["cells"] = cells
        output["reciprocal_cells"] = reciprocal_cells

    return output
1862
1863
def convert_to_jax(data):
    """Convert numpy arrays to jax arrays in a pytree."""

    def _to_jax(leaf):
        # only numpy arrays are converted; every other leaf passes through
        return jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf

    return jax.tree_util.tree_map(_to_jax, data)
1875
1876
class JaxConverter(nn.Module):
    """Convert numpy arrays to jax arrays in a pytree."""

    def __call__(self, data):
        """Return `data` with every numpy array leaf replaced by a jax array."""
        return convert_to_jax(data)
1882
def convert_to_numpy(data):
    """Convert jax arrays to numpy arrays in a pytree."""

    def _to_numpy(leaf):
        # only jax arrays are converted; every other leaf passes through
        if not isinstance(leaf, jax.Array):
            return leaf
        return np.array(leaf)

    return jax.tree_util.tree_map(_to_numpy, data)
1892
1893
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers.

    Each layer exposes the same protocol: a numpy `__call__` (dynamic shapes,
    updates its own state), a jitted `process` (fixed shapes), a
    `check_reallocate` overflow check and an `update_skin` fast path. The
    chain threads `inputs` and the per-layer states through all of them.
    """

    layers: Tuple[Callable[..., Dict[str, Any]]]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        # fail fast at construction if `layers` is missing or not a sequence
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            raise ValueError(f"Error: no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """CPU path: validate inputs, optionally pad, then run every layer.

        Returns the updated (frozen) chain state and the processed inputs
        converted to jax arrays.
        """
        do_check_input = state.get("check_input", True)
        if do_check_input:
            inputs = check_input(inputs)
        new_state = {**state}
        if self.use_atom_padding:
            s, inputs = self.atom_padder(state["padder_state"], inputs)
            new_state["padder_state"] = s
        layer_state = state["layers_state"]
        new_layer_state = []
        for i,layer in enumerate(self.layers):
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_layer_state.append(s)
        new_state["layers_state"] = tuple(new_layer_state)
        return FrozenDict(new_state), convert_to_jax(inputs)

    def check_reallocate(self, state, inputs):
        """Run every layer's overflow check, propagating the overflow flag.

        Returns (state, per-layer state updates, inputs, overflowed). The
        chain state is only replaced when at least one layer overflowed.
        """
        new_state = []
        state_up = []
        layer_state = state["layers_state"]
        parent_overflow = False
        for i,layer in enumerate(self.layers):
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        """Apply only the atom padder; no-op when padding is disabled."""
        if self.use_atom_padding:
            padder_state,inputs = self.atom_padder(state["padder_state"], inputs)
            return FrozenDict({**state,"padder_state": padder_state}), inputs
        return state, inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Accelerator path: run each layer's jitted `process` in order."""
        layer_state = state["layers_state"]
        for i,layer in enumerate(self.layers):
            inputs = layer.process(layer_state[i], inputs)
        return inputs

    @partial(jax.jit, static_argnums=(0))
    def update_skin(self, inputs):
        """Accelerator fast path: run each layer's `update_skin` in order."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Build the initial chain state (padder state + one state per layer)."""
        state = {"check_input": True}
        if self.use_atom_padding:
            state["padder_state"] = self.atom_padder.init()
        layer_state = []
        for layer in self.layers:
            layer_state.append(layer.init())
        state["layers_state"] = tuple(layer_state)
        return FrozenDict(state)

    def init_with_output(self, inputs):
        """Initialize the chain state and immediately run the CPU path."""
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        """Collect the processors declared by layers via `get_processor`."""
        processors = []
        for layer in self.layers:
            if hasattr(layer, "get_processor"):
                processors.append(layer.get_processor())
        return processors

    def get_graphs_properties(self):
        """Merge the graph-property dicts declared by layers via `get_graph_properties`."""
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties
1996
1997
1998# PREPROCESSING = {
1999#     "GRAPH": GraphGenerator,
2000#     # "GRAPH_FIXED": GraphGeneratorFixed,
2001#     "GRAPH_FILTER": GraphFilter,
2002#     "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension,
2003#     # "GRAPH_DENSE_EXTENSION": GraphDenseExtension,
2004#     "SPECIES_INDEXER": SpeciesIndexer,
2005# }
@dataclasses.dataclass(frozen=True)
class GraphGenerator:
 22@dataclasses.dataclass(frozen=True)
 23class GraphGenerator:
 24    """Generate a graph from a set of coordinates
 25
 26    FPID: GRAPH
 27
 28    For now, we generate all pairs of atoms and filter based on cutoff.
 29    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
 30    """
 31
 32    cutoff: float
 33    """Cutoff distance for the graph."""
 34    graph_key: str = "graph"
 35    """Key of the graph in the outputs."""
 36    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
 37    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 38    kmax: int = 30
 39    """Maximum number of k-points to consider."""
 40    kthr: float = 1e-6
 41    """Threshold for k-point filtering."""
 42    k_space: bool = False
 43    """Whether to generate k-space information for the graph."""
 44    mult_size: float = 1.05
 45    """Multiplicative factor for resizing the nblist."""
 46    # covalent_cutoff: bool = False
 47
 48    FPID: ClassVar[str] = "GRAPH"
 49
 50    def init(self):
 51        return FrozenDict(
 52            {
 53                "max_nat": 1,
 54                "npairs": 1,
 55                "nblist_mult_size": self.mult_size,
 56            }
 57        )
 58
 59    def get_processor(self) -> Tuple[nn.Module, Dict]:
 60        return GraphProcessor, {
 61            "cutoff": self.cutoff,
 62            "graph_key": self.graph_key,
 63            "switch_params": self.switch_params,
 64            "name": f"{self.graph_key}_Processor",
 65        }
 66
 67    def get_graph_properties(self):
 68        return {
 69            self.graph_key: {
 70                "cutoff": self.cutoff,
 71                "directed": True,
 72            }
 73        }
 74
 75    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 76        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 77        if self.graph_key in inputs:
 78            graph = inputs[self.graph_key]
 79            if "keep_graph" in graph:
 80                return state, inputs
 81
 82        coords = np.array(inputs["coordinates"], dtype=np.float32)
 83        natoms = np.array(inputs["natoms"], dtype=np.int32)
 84        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
 85
 86        new_state = {**state}
 87        state_up = {}
 88
 89        mult_size = state.get("nblist_mult_size", self.mult_size)
 90        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
 91
 92        if natoms.shape[0] == 1:
 93            max_nat = coords.shape[0]
 94            true_max_nat = max_nat
 95        else:
 96            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
 97            true_max_nat = int(np.max(natoms))
 98            if true_max_nat > max_nat:
 99                add_atoms = state.get("add_atoms", 0)
100                new_maxnat = true_max_nat + add_atoms
101                state_up["max_nat"] = (new_maxnat, max_nat)
102                new_state["max_nat"] = new_maxnat
103
104        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
105
106        ### compute indices of all pairs
107        apply_pbc = "cells" in inputs
108        minimage = "minimum_image" in inputs.get("flags", {})
109        include_self_image = apply_pbc and not minimage
110
111        shift = 0 if include_self_image else 1
112        p1, p2 = np.triu_indices(true_max_nat, shift)
113        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
114        pbc_shifts = None
115        if natoms.shape[0] > 1:
116            ## batching => mask irrelevant pairs
117            mask_p12 = (
118                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
119            ).flatten()
120            shift = np.concatenate(
121                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
122            )
123            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
124            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
125
126        if not apply_pbc:
127            ### NO PBC
128            vec = coords[p2] - coords[p1]
129        else:
130            cells = np.array(inputs["cells"], dtype=np.float32)
131            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
132            if minimage:
133                ## MINIMUM IMAGE CONVENTION
134                vec = coords[p2] - coords[p1]
135                if cells.shape[0] == 1:
136                    vecpbc = np.dot(vec, reciprocal_cells[0])
137                    pbc_shifts = -np.round(vecpbc).astype(np.int32)
138                    vec = vec + np.dot(pbc_shifts, cells[0])
139                else:
140                    batch_index_vec = batch_index[p1]
141                    vecpbc = np.einsum(
142                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
143                    )
144                    pbc_shifts = -np.round(vecpbc).astype(np.int32)
145                    vec = vec + np.einsum(
146                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
147                    )
148            else:
149                ### GENERAL PBC
150                ## put all atoms in central box
151                if cells.shape[0] == 1:
152                    coords_pbc = np.dot(coords, reciprocal_cells[0])
153                    at_shifts = -np.floor(coords_pbc).astype(np.int32)
154                    coords_pbc = coords + np.dot(at_shifts, cells[0])
155                else:
156                    coords_pbc = np.einsum(
157                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
158                    )
159                    at_shifts = -np.floor(coords_pbc).astype(np.int32)
160                    coords_pbc = coords + np.einsum(
161                        "aj,aji->ai", at_shifts, cells[batch_index]
162                    )
163                vec = coords_pbc[p2] - coords_pbc[p1]
164
165                ## compute maximum number of repeats
166                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
167                cdinv = cutoff_skin * inv_distances
168                num_repeats_all = np.ceil(cdinv).astype(np.int32)
169                if "true_sys" in inputs:
170                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
171                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
172                num_repeats = np.max(num_repeats_all, axis=0)
173                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
174                if np.any(num_repeats > num_repeats_prev):
175                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
176                    state_up["num_repeats_pbc"] = (
177                        tuple(num_repeats_new),
178                        tuple(num_repeats_prev),
179                    )
180                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
181                ## build all possible shifts
182                cell_shift_pbc = np.array(
183                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
184                    dtype=np.int32,
185                ).T.reshape(-1, 3)
186                ## shift applied to vectors
187                if cells.shape[0] == 1:
188                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
189                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
190                    pbc_shifts = np.broadcast_to(
191                        cell_shift_pbc[None, :, :],
192                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
193                    ).reshape(-1, 3)
194                    p1 = np.broadcast_to(
195                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
196                    ).flatten()
197                    p2 = np.broadcast_to(
198                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
199                    ).flatten()
200                    if natoms.shape[0] > 1:
201                        mask_p12 = np.broadcast_to(
202                            mask_p12[:, None],
203                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
204                        ).flatten()
205                else:
206                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
207
208                    ## get pbc shifts specific to each box
209                    cell_shift_pbc = np.broadcast_to(
210                        cell_shift_pbc[None, :, :],
211                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
212                    )
213                    mask = np.all(
214                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
215                    ).flatten()
216                    idx = np.nonzero(mask)[0]
217                    nshifts = idx.shape[0]
218                    nshifts_prev = state.get("nshifts_pbc", 0)
219                    if nshifts > nshifts_prev or add_margin:
220                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
221                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
222                        new_state["nshifts_pbc"] = nshifts_new
223
224                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
225                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
226
227                    ## get batch shift in the dvec_filter array
228                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
229                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
230
231                    ## compute vectors
232                    batch_index_vec = batch_index[p1]
233                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
234                    vec = vec.repeat(nrep_vec, axis=0)
235                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
236                    nvec_pbc_prev = state.get("nvec_pbc", 0)
237                    if nvec_pbc > nvec_pbc_prev or add_margin:
238                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
239                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
240                        new_state["nvec_pbc"] = nvec_pbc_new
241
242                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
243                    ## get shift index
244                    dshift = np.concatenate(
245                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
246                    ).repeat(nrep_vec)
247                    # ishift = np.arange(dshift.shape[0])-dshift
248                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
249                    icellshift = (
250                        np.arange(dshift.shape[0])
251                        - dshift
252                        + bshift[batch_index_vec].repeat(nrep_vec)
253                    )
254                    # shift vectors
255                    vec = vec + dvec_filter[icellshift]
256                    pbc_shifts = cell_shift_pbc_filter[icellshift]
257
258                    p1 = np.repeat(p1, nrep_vec)
259                    p2 = np.repeat(p2, nrep_vec)
260                    if natoms.shape[0] > 1:
261                        mask_p12 = np.repeat(mask_p12, nrep_vec)
262
263        ## compute distances
264        d12 = (vec**2).sum(axis=-1)
265        if natoms.shape[0] > 1:
266            d12 = np.where(mask_p12, d12, cutoff_skin**2)
267
268        ## filter pairs
269        max_pairs = state.get("npairs", 1)
270        mask = d12 < cutoff_skin**2
271        if include_self_image:
272            mask_self = np.logical_or(p1 != p2, d12 > 1.e-3)
273            mask = np.logical_and(mask, mask_self)
274        idx = np.nonzero(mask)[0]
275        npairs = idx.shape[0]
276        if npairs > max_pairs or add_margin:
277            prev_max_pairs = max_pairs
278            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
279            state_up["npairs"] = (max_pairs, prev_max_pairs)
280            new_state["npairs"] = max_pairs
281
282        nat = coords.shape[0]
283        edge_src = np.full(max_pairs, nat, dtype=np.int32)
284        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
285        d12_ = np.full(max_pairs, cutoff_skin**2)
286        edge_src[:npairs] = p1[idx]
287        edge_dst[:npairs] = p2[idx]
288        d12_[:npairs] = d12[idx]
289        d12 = d12_
290
291        if apply_pbc:
292            pbc_shifts_ = np.zeros((max_pairs, 3), dtype=np.int32)
293            pbc_shifts_[:npairs] = pbc_shifts[idx]
294            pbc_shifts = pbc_shifts_
295            if not minimage:
296                pbc_shifts[:npairs] = (
297                    pbc_shifts[:npairs]
298                    + at_shifts[edge_dst[:npairs]]
299                    - at_shifts[edge_src[:npairs]]
300                )
301
302
303        ## symmetrize
304        if include_self_image:
305            mask_noself = edge_src != edge_dst
306            idx_noself = np.nonzero(mask_noself)[0]
307            npairs_noself = idx_noself.shape[0]
308            max_noself = state.get("npairs_noself", 1)
309            if npairs_noself > max_noself or add_margin:
310                prev_max_noself = max_noself
311                max_noself = int(mult_size * max(npairs_noself, max_noself)) + 1
312                state_up["npairs_noself"] = (max_noself, prev_max_noself)
313                new_state["npairs_noself"] = max_noself
314
315            edge_src_noself = np.full(max_noself, nat, dtype=np.int32)
316            edge_dst_noself = np.full(max_noself, nat, dtype=np.int32)
317            d12_noself = np.full(max_noself, cutoff_skin**2)
318            pbc_shifts_noself = np.zeros((max_noself, 3), dtype=np.int32)
319
320            edge_dst_noself[:npairs_noself] = edge_dst[idx_noself]
321            edge_src_noself[:npairs_noself] = edge_src[idx_noself]
322            d12_noself[:npairs_noself] = d12[idx_noself]
323            pbc_shifts_noself[:npairs_noself] = pbc_shifts[idx_noself]
324            edge_src = np.concatenate((edge_src, edge_dst_noself))
325            edge_dst = np.concatenate((edge_dst, edge_src_noself))
326            d12 = np.concatenate((d12, d12_noself))
327            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts_noself))
328        else:
329            edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
330                (edge_dst, edge_src)
331            )
332            d12 = np.concatenate((d12, d12))
333            if apply_pbc:
334                pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
335
336        if "nblist_skin" in state:
337            edge_src_skin = edge_src
338            edge_dst_skin = edge_dst
339            if apply_pbc:
340                pbc_shifts_skin = pbc_shifts
341            max_pairs_skin = state.get("npairs_skin", 1)
342            mask = d12 < self.cutoff**2
343            idx = np.nonzero(mask)[0]
344            npairs_skin = idx.shape[0]
345            if npairs_skin > max_pairs_skin or add_margin:
346                prev_max_pairs_skin = max_pairs_skin
347                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
348                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
349                new_state["npairs_skin"] = max_pairs_skin
350            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
351            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
352            d12_ = np.full(max_pairs_skin, self.cutoff**2)
353            edge_src[:npairs_skin] = edge_src_skin[idx]
354            edge_dst[:npairs_skin] = edge_dst_skin[idx]
355            d12_[:npairs_skin] = d12[idx]
356            d12 = d12_
357            if apply_pbc:
358                pbc_shifts = np.zeros((max_pairs_skin, 3), dtype=np.int32)
359                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
360
361        graph = inputs.get(self.graph_key, {})
362        graph_out = {
363            **graph,
364            "edge_src": edge_src,
365            "edge_dst": edge_dst,
366            "d12": d12,
367            "overflow": False,
368            "pbc_shifts": pbc_shifts,
369        }
370        if "nblist_skin" in state:
371            graph_out["edge_src_skin"] = edge_src_skin
372            graph_out["edge_dst_skin"] = edge_dst_skin
373            if apply_pbc:
374                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
375
376        if self.k_space and apply_pbc:
377            if "k_points" not in graph:
378                ks, _, _, bewald = get_reciprocal_space_parameters(
379                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
380                )
381            graph_out["k_points"] = ks
382            graph_out["b_ewald"] = bewald
383
384        output = {**inputs, self.graph_key: graph_out}
385
386        if return_state_update:
387            return FrozenDict(new_state), output, state_up
388        return FrozenDict(new_state), output
389
390    def check_reallocate(self, state, inputs, parent_overflow=False):
391        """check for overflow and reallocate nblist if necessary"""
392        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
393        if not overflow:
394            return state, {}, inputs, False
395
396        add_margin = inputs[self.graph_key].get("overflow", False)
397        state, inputs, state_up = self(
398            state, inputs, return_state_update=True, add_margin=add_margin
399        )
400        return state, state_up, inputs, True
401
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes

        Constructs the edge lists (``edge_src``/``edge_dst``), squared
        distances ``d12`` and, under PBC, integer cell shifts, using fixed
        buffer sizes read from ``state`` so the whole function is
        jit-compilable. When a fixed size is exceeded, an ``overflow`` flag
        is set in the output graph instead of resizing; `check_reallocate`
        is then expected to fall back to the numpy routine.

        NOTE(review): `state` is a static argument (static_argnums=(0, 1)),
        so any change of buffer sizes triggers recompilation.
        """
        # if a graph is already present and marked as kept, do not rebuild it
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        # maximum number of atoms per system (static, from state) — used to
        # enumerate candidate pairs with a fixed shape
        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        # pairs are collected up to cutoff+skin so `update_skin` can re-filter
        # without a full rebuild
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        apply_pbc = "cells" in inputs
        minimage = "minimum_image" in inputs.get("flags", {})
        include_self_image = apply_pbc and not minimage

        # with self images (general PBC), also keep i==i pairs (diagonal)
        shift = 0 if include_self_image else 1
        p1, p2 = np.triu_indices(max_nat, shift)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            # pairs where either index exceeds the system's atom count are
            # masked out; valid pairs are offset into the flat atom array
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            # no PBC: plain pair vectors
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            # minimage = state.get("minimum_image", True)
            minimage = "minimum_image" in inputs.get("flags", {})

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # wrap `vec` into the cell: "round" maps to the minimum image,
                # "floor" maps into the central box; also returns the integer
                # shifts that were applied
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc).astype(jnp.int32)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc).astype(jnp.int32)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    # one cell per system: map each pair to its system's cell
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of periodic repeats per axis (from state);
                # used to build a fixed-size set of candidate cell shifts
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    # padded (fake) systems need no periodic images
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                # flag overflow if the current cells require more repeats than
                # the static allocation covers
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # all integer shift triplets within the static repeat range
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=np.int32,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # replicate every pair vector once per candidate shift
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    # per-system shift vectors in Cartesian coordinates
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    # keep only shifts actually needed by each system
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    # compact the valid shifts into fixed-size buffers
                    # (component-wise because mask_filter_1d works on 1D arrays)
                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1, dtype=jnp.int32)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    # each pair is repeated once per shift of its own system;
                    # masked pairs get zero repeats
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    # icellshift maps each repeated pair to its shift row in
                    # dvec: position within the pair's repeat run plus the
                    # system's base offset
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)

                    # entries beyond the true repeat count are padding
                    mask_valid = jnp.arange(nvec_max) < nvec
                    mask_p12 = jnp.where(mask_valid, mask_p12, False)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs get d12 == cutoff_skin**2 so they fail the filter
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        if include_self_image:
            # drop the trivial self pair at zero shift (d12 ~ 0) but keep
            # genuine periodic self-images
            mask_self = jnp.logical_or(p1 != p2, d12 > 1.e-3)
            mask = jnp.logical_and(mask, mask_self)
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # scatter surviving shifts into a fixed-size buffer; out-of-range
            # indices (filtered-out pairs) are dropped
            pbc_shifts = (
                jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # account for the per-atom wrapping applied above so the
                # shifts refer to the original (unwrapped) coordinates
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0)
                    - at_shifts.at[edge_src].get(fill_value=0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        ## symmetrize
        # edges are built directed both ways; self-image pairs (i==i under a
        # shift) must not duplicate the i->i edge, hence the separate path
        if include_self_image:
            mask_noself = edge_src != edge_dst
            max_noself = state.get("npairs_noself", 1)
            (edge_src_noself, edge_dst_noself, d12_noself, pbc_shifts_noself), scatter_idx, npairs_noself = mask_filter_1d(
                mask_noself,
                max_noself,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, cutoff_skin**2),
                (pbc_shifts, jnp.zeros((3,), dtype=pbc_shifts.dtype)),
            )
            overflow = overflow | (npairs_noself > max_noself)
            edge_src = jnp.concatenate((edge_src, edge_dst_noself))
            edge_dst = jnp.concatenate((edge_dst, edge_src_noself))
            d12 = jnp.concatenate((d12, d12_noself))
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts_noself))
        else:
            edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
                (edge_dst, edge_src)
            )
            d12 = jnp.concatenate((d12, d12))
            if "cells" in inputs:
                # reversed edges carry the opposite shift
                pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if "nblist_skin" in state:
            # keep the full cutoff+skin list for later `update_skin` calls,
            # then re-filter the working list down to the actual cutoff
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.zeros((max_pairs_skin, 3), dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            # k-space data cannot be generated inside jit; it must already be
            # in the graph (produced by the numpy routine)
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}
699
700    @partial(jax.jit, static_argnums=(0,))
701    def update_skin(self, inputs):
702        """update the nblist without recomputing the full nblist"""
703        graph = inputs[self.graph_key]
704
705        edge_src_skin = graph["edge_src_skin"]
706        edge_dst_skin = graph["edge_dst_skin"]
707        coords = inputs["coordinates"]
708        vec = coords.at[edge_dst_skin].get(
709            mode="fill", fill_value=self.cutoff
710        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
711
712        if "cells" in inputs:
713            pbc_shifts_skin = graph["pbc_shifts_skin"]
714            cells = inputs["cells"]
715            if cells.shape[0] == 1:
716                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
717            else:
718                batch_index_vec = inputs["batch_index"][edge_src_skin]
719                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
720
721        nat = coords.shape[0]
722        d12 = jnp.sum(vec**2, axis=-1)
723        mask = d12 < self.cutoff**2
724        max_pairs = graph["edge_src"].shape[0]
725        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
726            mask,
727            max_pairs,
728            (edge_src_skin, nat),
729            (edge_dst_skin, nat),
730            (d12, self.cutoff**2),
731        )
732        if "cells" in inputs:
733            pbc_shifts = (
734                jnp.zeros((max_pairs, 3), dtype=pbc_shifts_skin.dtype)
735                .at[scatter_idx]
736                .set(pbc_shifts_skin)
737            )
738
739        overflow = graph.get("overflow", False) | (npairs > max_pairs)
740        graph_out = {
741            **graph,
742            "edge_src": edge_src,
743            "edge_dst": edge_dst,
744            "d12": d12,
745            "overflow": overflow,
746        }
747        if "cells" in inputs:
748            graph_out["pbc_shifts"] = pbc_shifts
749
750        if self.k_space and "cells" in inputs:
751            if "k_points" not in graph:
752                raise NotImplementedError(
753                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
754                )
755
756        return {**inputs, self.graph_key: graph_out}

Generate a graph from a set of coordinates

FPID: GRAPH

For now, we generate all pairs of atoms and filter based on cutoff. If a nblist_skin is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the update_skin method to update the original graph without recomputing the full nblist.

GraphGenerator(cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, kmax: int = 30, kthr: float = 1e-06, k_space: bool = False, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

k_space: bool = False

Whether to generate k-space information for the graph.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH'
def init(self):
50    def init(self):
51        return FrozenDict(
52            {
53                "max_nat": 1,
54                "npairs": 1,
55                "nblist_mult_size": self.mult_size,
56            }
57        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
59    def get_processor(self) -> Tuple[nn.Module, Dict]:
60        return GraphProcessor, {
61            "cutoff": self.cutoff,
62            "graph_key": self.graph_key,
63            "switch_params": self.switch_params,
64            "name": f"{self.graph_key}_Processor",
65        }
def get_graph_properties(self):
67    def get_graph_properties(self):
68        return {
69            self.graph_key: {
70                "cutoff": self.cutoff,
71                "directed": True,
72            }
73        }
def check_reallocate(self, state, inputs, parent_overflow=False):
390    def check_reallocate(self, state, inputs, parent_overflow=False):
391        """check for overflow and reallocate nblist if necessary"""
392        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
393        if not overflow:
394            return state, {}, inputs, False
395
396        add_margin = inputs[self.graph_key].get("overflow", False)
397        state, inputs, state_up = self(
398            state, inputs, return_state_update=True, add_margin=add_margin
399        )
400        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes

        Builds the full O(max_nat^2) candidate pair list, applies PBC if cells
        are provided, filters pairs by the (skin-extended) cutoff into
        fixed-size padded arrays, and symmetrizes the edge list. Overflow of
        any preallocated buffer is reported in the output graph under
        "overflow" so that `check_reallocate` can rebuild on host.
        """
        # if a graph is already present and flagged to be kept, do nothing
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            # batched systems: per-system atom capacity from state (or the mean)
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        apply_pbc = "cells" in inputs
        minimage = "minimum_image" in inputs.get("flags", {})
        # with general PBC (no minimum image) an atom may interact with its own
        # periodic images, so self-pairs (diagonal) must stay in the candidates
        include_self_image = apply_pbc and not minimage

        shift = 0 if include_self_image else 1
        p1, p2 = np.triu_indices(max_nat, shift)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # per-system offset of the first atom in the flat coordinate array
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            # minimage = state.get("minimum_image", True)
            minimage = "minimum_image" in inputs.get("flags", {})

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # fractional coordinates of vec in the cell basis
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    # wrap to the nearest image (minimum image convention)
                    pbc_shifts = -jnp.round(vecpbc).astype(jnp.int32)
                elif mode == "floor":
                    # wrap into the central box [0, 1)^3
                    pbc_shifts = -jnp.floor(vecpbc).astype(jnp.int32)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    # one cell per system: map each pair to its system's cell
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    # padding systems get zero repeats
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # all integer cell translations within the stored repeat bounds
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=np.int32,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # replicate every candidate pair over every cell translation
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    # per-system translation vectors for every candidate shift
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    # compact the valid shifts component-wise into fixed-size buffers
                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1, dtype=jnp.int32)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    # index of the translation vector for each repeated pair
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)

                    # entries beyond nvec are padding from the fixed repeat length
                    mask_valid = jnp.arange(nvec_max) < nvec
                    mask_p12 = jnp.where(mask_valid, mask_p12, False)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs get d12 == cutoff_skin**2 so the strict < below drops them
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        if include_self_image:
            # drop the zero-shift self pair (d12 ~ 0) but keep true image pairs
            mask_self = jnp.logical_or(p1 != p2, d12 > 1.e-3)
            mask = jnp.logical_and(mask, mask_self)
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # scatter the kept pbc shifts into the padded edge buffer
            pbc_shifts = (
                jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # account for the per-atom wrapping applied above
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0)
                    - at_shifts.at[edge_src].get(fill_value=0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        ## symmetrize
        if include_self_image:
            # self-image pairs are already symmetric; only mirror the others
            mask_noself = edge_src != edge_dst
            max_noself = state.get("npairs_noself", 1)
            (edge_src_noself, edge_dst_noself, d12_noself, pbc_shifts_noself), scatter_idx, npairs_noself = mask_filter_1d(
                mask_noself,
                max_noself,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, cutoff_skin**2),
                (pbc_shifts, jnp.zeros((3,), dtype=pbc_shifts.dtype)),
            )
            overflow = overflow | (npairs_noself > max_noself)
            edge_src = jnp.concatenate((edge_src, edge_dst_noself))
            edge_dst = jnp.concatenate((edge_dst, edge_src_noself))
            d12 = jnp.concatenate((d12, d12_noself))
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts_noself))
        else:
            # append the reversed edges to make the graph directed both ways
            edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
                (edge_dst, edge_src)
            )
            d12 = jnp.concatenate((d12, d12))
            if "cells" in inputs:
                pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if "nblist_skin" in state:
            # edge_mask_skin = edge_mask
            # keep the skin-extended list for cheap updates via `update_skin`
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            # re-filter with the bare cutoff for the working edge list
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.zeros((max_pairs_skin, 3), dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

build a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """update the nblist without recomputing the full nblist

        Recomputes pair vectors only for the stored skin-extended edge list
        (built by `process` when a `nblist_skin` is in the state) and
        re-filters them with the bare cutoff, so the expensive all-pairs
        construction is skipped.
        """
        graph = inputs[self.graph_key]

        edge_src_skin = graph["edge_src_skin"]
        edge_dst_skin = graph["edge_dst_skin"]
        coords = inputs["coordinates"]
        # out-of-range (padding) indices fetch fill values so padded edges get a
        # vector of norm >= cutoff and are dropped by the filter below
        vec = coords.at[edge_dst_skin].get(
            mode="fill", fill_value=self.cutoff
        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)

        if "cells" in inputs:
            pbc_shifts_skin = graph["pbc_shifts_skin"]
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
            else:
                # one cell per system: map each edge to its system's cell
                batch_index_vec = inputs["batch_index"][edge_src_skin]
                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])

        nat = coords.shape[0]
        d12 = jnp.sum(vec**2, axis=-1)
        mask = d12 < self.cutoff**2
        # reuse the existing buffer size; overflow is flagged instead of resizing
        max_pairs = graph["edge_src"].shape[0]
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src_skin, nat),
            (edge_dst_skin, nat),
            (d12, self.cutoff**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.zeros((max_pairs, 3), dtype=pbc_shifts_skin.dtype)
                .at[scatter_idx]
                .set(pbc_shifts_skin)
            )

        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
        }
        if "cells" in inputs:
            graph_out["pbc_shifts"] = pbc_shifts

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

update the nblist without recomputing the full nblist

class GraphProcessor(flax.linen.module.Module):
class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        coords = inputs["coordinates"]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # edge_mask = edge_src < coords.shape[0]
        # out-of-range (padding) indices fetch fill values so padded edges get a
        # vector of norm >= cutoff and are masked out by edge_mask below
        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
            edge_src
        ].get(mode="fill", fill_value=0.0)
        if "cells" in inputs:
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
            else:
                # one cell per system: map each edge to its system's cell
                batch_index_vec = inputs["batch_index"][edge_src]
                vec = vec + jax.vmap(jnp.dot)(
                    graph["pbc_shifts"], cells[batch_index_vec]
                )

        d2  = jnp.sum(vec**2, axis=-1)
        distances = safe_sqrt(d2)
        edge_mask = distances < self.cutoff

        # smooth switching function in [0, 1] that goes to 0 at the cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            # alchemical transformation: edges between atoms of different
            # alchemical groups are scaled/softened by the lambda parameters
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            # cross-group edges: switch scaled by a cosine ramp in lambda_e
            # (0 at lambda_e=0, 1 at lambda_e=1)
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
            )
            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2
            else:
                alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2

            # soft-core-like distance for cross-group edges: shifted so that
            # d=0 maps to sqrt(alch_alpha) while d=cutoff stays at cutoff
            graph_out["distances"] = jnp.where(
                mask,
                distances,
                safe_sqrt(alch_alpha + d2 * (1. - alch_alpha/self.cutoff**2))
            )


        return {**inputs, self.graph_key: graph_out}

Process a pre-generated graph

The pre-generated graph should contain the following keys:

  • edge_src: source indices of the edges
  • edge_dst: destination indices of the edges
  • pbc_shifts: pbc shifts for the edges (only if cells are present in the inputs)

This module is automatically added to a FENNIX model when a GraphGenerator is used.

GraphProcessor( cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class GraphFilter:
 840@dataclasses.dataclass(frozen=True)
 841class GraphFilter:
 842    """Filter a graph based on a cutoff distance
 843
 844    FPID: GRAPH_FILTER
 845    """
 846
 847    cutoff: float
 848    """Cutoff distance for the filtering."""
 849    parent_graph: str
 850    """Key of the parent graph in the inputs."""
 851    graph_key: str
 852    """Key of the filtered graph in the outputs."""
 853    remove_hydrogens: int = False
 854    """Remove edges where the source is a hydrogen atom."""
 855    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
 856    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 857    k_space: bool = False
 858    """Generate k-space information for the graph."""
 859    kmax: int = 30
 860    """Maximum number of k-points to consider."""
 861    kthr: float = 1e-6
 862    """Threshold for k-point filtering."""
 863    mult_size: float = 1.05
 864    """Multiplicative factor for resizing the nblist."""
 865
 866    FPID: ClassVar[str] = "GRAPH_FILTER"
 867
 868    def init(self):
 869        return FrozenDict(
 870            {
 871                "npairs": 1,
 872                "nblist_mult_size": self.mult_size,
 873            }
 874        )
 875
 876    def get_processor(self) -> Tuple[nn.Module, Dict]:
 877        return GraphFilterProcessor, {
 878            "cutoff": self.cutoff,
 879            "graph_key": self.graph_key,
 880            "parent_graph": self.parent_graph,
 881            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
 882            "switch_params": self.switch_params,
 883        }
 884
 885    def get_graph_properties(self):
 886        return {
 887            self.graph_key: {
 888                "cutoff": self.cutoff,
 889                "directed": True,
 890                "parent_graph": self.parent_graph,
 891            }
 892        }
 893
 894    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 895        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 896        graph_in = inputs[self.parent_graph]
 897        nat = inputs["species"].shape[0]
 898
 899        new_state = {**state}
 900        state_up = {}
 901        mult_size = state.get("nblist_mult_size", self.mult_size)
 902        assert mult_size >= 1., "nblist_mult_size should be >= 1."
 903
 904        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
 905        d12 = np.array(graph_in["d12"], dtype=np.float32)
 906        if self.remove_hydrogens:
 907            species = inputs["species"]
 908            src_idx = (edge_src < nat).nonzero()[0]
 909            mask = np.zeros(edge_src.shape[0], dtype=bool)
 910            mask[src_idx] = (species > 1)[edge_src[src_idx]]
 911            d12 = np.where(mask, d12, self.cutoff**2)
 912        mask = d12 < self.cutoff**2
 913
 914        max_pairs = state.get("npairs", 1)
 915        idx = np.nonzero(mask)[0]
 916        npairs = idx.shape[0]
 917        if npairs > max_pairs or add_margin:
 918            prev_max_pairs = max_pairs
 919            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 920            state_up["npairs"] = (max_pairs, prev_max_pairs)
 921            new_state["npairs"] = max_pairs
 922
 923        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
 924        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 925        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 926        d12_ = np.full(max_pairs, self.cutoff**2)
 927        filter_indices[:npairs] = idx
 928        edge_src[:npairs] = graph_in["edge_src"][idx]
 929        edge_dst[:npairs] = graph_in["edge_dst"][idx]
 930        d12_[:npairs] = d12[idx]
 931        d12 = d12_
 932
 933        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 934        graph_out = {
 935            **graph,
 936            "edge_src": edge_src,
 937            "edge_dst": edge_dst,
 938            "filter_indices": filter_indices,
 939            "d12": d12,
 940            "overflow": False,
 941        }
 942        if "cells" in inputs:
 943            pbc_shifts = np.zeros((max_pairs, 3), dtype=np.int32)
 944            pbc_shifts[:npairs] = graph_in["pbc_shifts"][idx]
 945            graph_out["pbc_shifts"] = pbc_shifts
 946
 947            if self.k_space:
 948                if "k_points" not in graph:
 949                    ks, _, _, bewald = get_reciprocal_space_parameters(
 950                        inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
 951                    )
 952                graph_out["k_points"] = ks
 953                graph_out["b_ewald"] = bewald
 954
 955        output = {**inputs, self.graph_key: graph_out}
 956        if return_state_update:
 957            return FrozenDict(new_state), output, state_up
 958        return FrozenDict(new_state), output
 959
 960    def check_reallocate(self, state, inputs, parent_overflow=False):
 961        """check for overflow and reallocate nblist if necessary"""
 962        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 963        if not overflow:
 964            return state, {}, inputs, False
 965
 966        add_margin = inputs[self.graph_key].get("overflow", False)
 967        state, inputs, state_up = self(
 968            state, inputs, return_state_update=True, add_margin=add_margin
 969        )
 970        return state, state_up, inputs, True
 971
 972    @partial(jax.jit, static_argnums=(0, 1))
 973    def process(self, state, inputs):
 974        """filter a nblist on accelerator with jax and precomputed shapes"""
 975        graph_in = inputs[self.parent_graph]
 976        if state is None:
 977            # skin update mode
 978            graph = inputs[self.graph_key]
 979            max_pairs = graph["edge_src"].shape[0]
 980        else:
 981            max_pairs = state.get("npairs", 1)
 982
 983        max_pairs_in = graph_in["edge_src"].shape[0]
 984        nat = inputs["species"].shape[0]
 985
 986        edge_src = graph_in["edge_src"]
 987        d12 = graph_in["d12"]
 988        if self.remove_hydrogens:
 989            species = inputs["species"]
 990            mask = (species > 1)[edge_src]
 991            d12 = jnp.where(mask, d12, self.cutoff**2)
 992        mask = d12 < self.cutoff**2
 993
 994        (edge_src, edge_dst, d12, filter_indices), scatter_idx, npairs = mask_filter_1d(
 995            mask,
 996            max_pairs,
 997            (edge_src, nat),
 998            (graph_in["edge_dst"], nat),
 999            (d12, self.cutoff**2),
1000            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
1001        )
1002
1003        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
1004        overflow = graph.get("overflow", False) | (npairs > max_pairs)
1005        graph_out = {
1006            **graph,
1007            "edge_src": edge_src,
1008            "edge_dst": edge_dst,
1009            "filter_indices": filter_indices,
1010            "d12": d12,
1011            "overflow": overflow,
1012        }
1013
1014        if "cells" in inputs:
1015            pbc_shifts = graph_in["pbc_shifts"]
1016            pbc_shifts = (
1017                jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype)
1018                .at[scatter_idx].set(pbc_shifts, mode="drop")
1019            )
1020            graph_out["pbc_shifts"] = pbc_shifts
1021            if self.k_space:
1022                if "k_points" not in graph:
1023                    raise NotImplementedError(
1024                        "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
1025                    )
1026
1027        return {**inputs, self.graph_key: graph_out}
1028
1029    @partial(jax.jit, static_argnums=(0,))
1030    def update_skin(self, inputs):
1031        return self.process(None, inputs)

Filter a graph based on a cutoff distance

FPID: GRAPH_FILTER

GraphFilter( cutoff: float, parent_graph: str, graph_key: str, remove_hydrogens: int = False, switch_params: flax.core.frozen_dict.FrozenDict = <factory>, k_space: bool = False, kmax: int = 30, kthr: float = 1e-06, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the filtering.

parent_graph: str

Key of the parent graph in the inputs.

graph_key: str

Key of the filtered graph in the outputs.

remove_hydrogens: int = False

Remove edges where the source is a hydrogen atom.

switch_params: flax.core.frozen_dict.FrozenDict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

k_space: bool = False

Generate k-space information for the graph.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH_FILTER'
def init(self):
868    def init(self):
869        return FrozenDict(
870            {
871                "npairs": 1,
872                "nblist_mult_size": self.mult_size,
873            }
874        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
876    def get_processor(self) -> Tuple[nn.Module, Dict]:
877        return GraphFilterProcessor, {
878            "cutoff": self.cutoff,
879            "graph_key": self.graph_key,
880            "parent_graph": self.parent_graph,
881            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
882            "switch_params": self.switch_params,
883        }
def get_graph_properties(self):
885    def get_graph_properties(self):
886        return {
887            self.graph_key: {
888                "cutoff": self.cutoff,
889                "directed": True,
890                "parent_graph": self.parent_graph,
891            }
892        }
def check_reallocate(self, state, inputs, parent_overflow=False):
960    def check_reallocate(self, state, inputs, parent_overflow=False):
961        """check for overflow and reallocate nblist if necessary"""
962        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
963        if not overflow:
964            return state, {}, inputs, False
965
966        add_margin = inputs[self.graph_key].get("overflow", False)
967        state, inputs, state_up = self(
968            state, inputs, return_state_update=True, add_margin=add_margin
969        )
970        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
 972    @partial(jax.jit, static_argnums=(0, 1))
 973    def process(self, state, inputs):
 974        """filter a nblist on accelerator with jax and precomputed shapes"""
 975        graph_in = inputs[self.parent_graph]
 976        if state is None:
 977            # skin update mode
 978            graph = inputs[self.graph_key]
 979            max_pairs = graph["edge_src"].shape[0]
 980        else:
 981            max_pairs = state.get("npairs", 1)
 982
 983        max_pairs_in = graph_in["edge_src"].shape[0]
 984        nat = inputs["species"].shape[0]
 985
 986        edge_src = graph_in["edge_src"]
 987        d12 = graph_in["d12"]
 988        if self.remove_hydrogens:
 989            species = inputs["species"]
 990            mask = (species > 1)[edge_src]
 991            d12 = jnp.where(mask, d12, self.cutoff**2)
 992        mask = d12 < self.cutoff**2
 993
 994        (edge_src, edge_dst, d12, filter_indices), scatter_idx, npairs = mask_filter_1d(
 995            mask,
 996            max_pairs,
 997            (edge_src, nat),
 998            (graph_in["edge_dst"], nat),
 999            (d12, self.cutoff**2),
1000            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
1001        )
1002
1003        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
1004        overflow = graph.get("overflow", False) | (npairs > max_pairs)
1005        graph_out = {
1006            **graph,
1007            "edge_src": edge_src,
1008            "edge_dst": edge_dst,
1009            "filter_indices": filter_indices,
1010            "d12": d12,
1011            "overflow": overflow,
1012        }
1013
1014        if "cells" in inputs:
1015            pbc_shifts = graph_in["pbc_shifts"]
1016            pbc_shifts = (
1017                jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype)
1018                .at[scatter_idx].set(pbc_shifts, mode="drop")
1019            )
1020            graph_out["pbc_shifts"] = pbc_shifts
1021            if self.k_space:
1022                if "k_points" not in graph:
1023                    raise NotImplementedError(
1024                        "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
1025                    )
1026
1027        return {**inputs, self.graph_key: graph_out}

filter a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1029    @partial(jax.jit, static_argnums=(0,))
1030    def update_skin(self, inputs):
1031        return self.process(None, inputs)
class GraphFilterProcessor(flax.linen.module.Module):
class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        parent = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        # NOTE(review): the key is chosen by inspecting the *filtered* graph
        # but the values are read from the *parent* graph — confirm this
        # asymmetry is intended.
        d_key = "distances_raw" if "distances_raw" in graph else "distances"

        if parent["vec"].shape[0] == 0:
            # empty parent neighbor list: nothing to gather
            vec = parent["vec"]
            distances = parent[d_key]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            # gather the parent's edge data for the filtered edges;
            # out-of-range (padding) indices are filled with the cutoff so
            # they are masked out below
            filter_indices = graph["filter_indices"]

            def gather(arr):
                return arr.at[filter_indices].get(
                    mode="fill", fill_value=self.cutoff
                )

            vec = gather(parent["vec"])
            distances = gather(parent[d_key])

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            # alchemical decoupling: edges crossing the alchemical group
            # boundary get a scaled switch and softcore-modified distances
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            src = graph["edge_src"]
            dst = graph["edge_dst"]
            same_group = alch_group[src] == alch_group[dst]

            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                same_group,
                switch,
                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
            )

            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"] ** 2
            else:
                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5) ** 2

            # softcore distance: smoothly lifts cross-group distances away
            # from zero while matching the true distance at the cutoff
            graph_out["distances"] = jnp.where(
                same_group,
                distances,
                safe_sqrt(
                    alch_alpha + distances**2 * (1.0 - alch_alpha / self.cutoff**2)
                ),
            )

        return {**inputs, self.graph_key: graph_out}

Filter processing for a pre-generated graph

This module is automatically added to a FENNIX model when a GraphFilter is used.

GraphFilterProcessor( cutoff: float, graph_key: str, parent_graph: str, switch_params: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the filtering.

graph_key: str

Key of the filtered graph in the inputs.

parent_graph: str

Key of the parent graph in the inputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        """Initial preprocessing state (sizes used for static shapes)."""
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the flax module (and its kwargs) that computes angles on-model."""
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        ### count number of neighbors
        # one extra slot so that out-of-range padding sources (edge_src >= nat,
        # masked out below) do not contribute to the per-atom maximum
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            # grow the neighbor capacity with a margin to limit recompilations
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy array whose static shape carries max_neigh into the jitted
        # `process` when it is called with state=None (skin-update mode)
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        # dense (nat, max_count) table of edge indices; `nedge` marks empty slots
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation (padded with `nedge`)
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            # skin-update mode: recover static sizes from the previous graph
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow (triggers a cpu reallocation via check_reallocate)
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Re-run `process` with state=None, taking shapes from the graph."""
        return self.process(None, inputs)

Add angles list to a graph

FPID: GRAPH_ANGULAR_EXTENSION

GraphAngularExtension( mult_size: float = 1.05, add_neigh: int = 5, graph_key: str = 'graph')
mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

add_neigh: int = 5

Additional neighbors to add to the nblist when resizing.

graph_key: str = 'graph'

Key of the graph in the inputs.

FPID: ClassVar[str] = 'GRAPH_ANGULAR_EXTENSION'
def init(self):
1133    def init(self):
1134        return FrozenDict(
1135            {
1136                "nangles": 0,
1137                "nblist_mult_size": self.mult_size,
1138                "max_neigh": self.add_neigh,
1139                "add_neigh": self.add_neigh,
1140            }
1141        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
1143    def get_processor(self) -> Tuple[nn.Module, Dict]:
1144        return GraphAngleProcessor, {
1145            "graph_key": self.graph_key,
1146            "name": f"{self.graph_key}_AngleProcessor",
1147        }
def get_graph_properties(self):
1149    def get_graph_properties(self):
1150        return {
1151            self.graph_key: {
1152                "has_angles": True,
1153            }
1154        }
def check_reallocate(self, state, inputs, parent_overflow=False):
1254    def check_reallocate(self, state, inputs, parent_overflow=False):
1255        """check for overflow and reallocate nblist if necessary"""
1256        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
1257        if not overflow:
1258            return state, {}, inputs, False
1259
1260        add_margin = inputs[self.graph_key]["angle_overflow"]
1261        state, inputs, state_up = self(
1262            state, inputs, return_state_update=True, add_margin=add_margin
1263        )
1264        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1266    @partial(jax.jit, static_argnums=(0, 1))
1267    def process(self, state, inputs):
1268        """build angle nblist on accelerator with jax and precomputed shapes"""
1269        graph = inputs[self.graph_key]
1270        edge_src = graph["edge_src"]
1271
1272        ### count number of neighbors
1273        nat = inputs["species"].shape[0]
1274        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
1275        max_count = jnp.max(count)
1276
1277        ### get sizes
1278        if state is None:
1279            max_neigh_arr = graph["__max_neigh_array"]
1280            max_neigh = max_neigh_arr.shape[0]
1281            prev_nangles = graph["angle_src"].shape[0]
1282        else:
1283            max_neigh = state.get("max_neigh", self.add_neigh)
1284            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
1285            prev_nangles = state.get("nangles", 0)
1286
1287        nedge = edge_src.shape[0]
1288
1289        ### sort edge_src
1290        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
1291        edge_src_sorted = edge_src[idx_sort]
1292
1293        ### map sparse to dense nblist
1294        if max_neigh * nat < nedge:
1295            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
1296        offset = jnp.asarray(
1297            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
1298        )
1299        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
1300        indices = edge_src_sorted * max_neigh + offset
1301        edge_idx = (
1302            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
1303            .at[indices]
1304            .set(idx_sort, mode="drop")
1305            .reshape(nat, max_neigh)
1306        )
1307
1308        ### find all triplet for each atom center
1309        local_src, local_dst = np.triu_indices(max_neigh, 1)
1310        angle_src = edge_idx[:, local_src].flatten()
1311        angle_dst = edge_idx[:, local_dst].flatten()
1312
1313        ### mask for valid angles
1314        mask1 = angle_src < nedge
1315        mask2 = angle_dst < nedge
1316        angle_mask = mask1 & mask2
1317
1318        ## filter angles to sparse representation
1319        (angle_src, angle_dst), _, nangles = mask_filter_1d(
1320            angle_mask,
1321            prev_nangles,
1322            (angle_src, nedge),
1323            (angle_dst, nedge),
1324        )
1325        ## find central atom
1326        central_atom = edge_src[angle_src]
1327
1328        ## check for overflow
1329        angle_overflow = nangles > prev_nangles
1330        neigh_overflow = max_count > max_neigh
1331        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow
1332
1333        ## update graph
1334        output = {
1335            **inputs,
1336            self.graph_key: {
1337                **graph,
1338                "angle_src": angle_src,
1339                "angle_dst": angle_dst,
1340                "central_atom": central_atom,
1341                "angle_overflow": overflow,
1342                # "max_neigh": max_neigh,
1343                "__max_neigh_array": max_neigh_arr,
1344            },
1345        }
1346
1347        return output

build angle nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1349    @partial(jax.jit, static_argnums=(0,))
1350    def update_skin(self, inputs):
1351        return self.process(None, inputs)
class GraphAngleProcessor(flax.linen.module.Module):
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        # prefer the raw distances when present (set by alchemical filtering)
        if "distances_raw" in graph:
            distances = graph["distances_raw"]
        else:
            distances = graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # unit bond vectors; clip the norm to avoid division by zero on
        # padded edges
        unit = vec / jnp.clip(distances[:, None], min=1.0e-5)
        u_src = unit.at[angle_src].get(mode="fill", fill_value=0.5)
        u_dst = unit.at[angle_dst].get(mode="fill", fill_value=0.5)
        cos_angles = (u_src * u_dst).sum(axis=-1)

        # the 0.95 factor keeps |argument| < 1 so the arccos gradient
        # remains finite at aligned/anti-aligned bonds
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }

Process a pre-generated graph to compute angles

This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

GraphAngleProcessor( graph_key: str, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str

Key of the graph in the inputs.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        """Initial preprocessing state (per-species allocated sizes)."""
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the species index on cpu with numpy and dynamic shapes.

        Updates the per-species allocated sizes in the state when the current
        counts exceed them (optionally with an extra margin).
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                # non-positive species are padding/ghost atoms: not indexed
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            # guard against an empty sizes dict (e.g. all-padding input)
            max_size = max(new_sizes.values(), default=0)
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size
                max_size_prev = max_size

            # dense index: one padded row (fill value nat) per requested species
            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict.keys():
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            # dict of per-species index arrays, each padded with nat
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Recompute the species index on accelerator with static shapes.

        Returns inputs unchanged when the index is already present (unless the
        `recompute_species_index` flag is set). Sets `<output_key>_overflow`
        when the allocated sizes are too small for the current species counts.
        """
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]

        # BUGFIX: `overflow` was previously only defined in the dict branch,
        # raising a NameError when species_order was set.
        overflow = False
        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in sizes.keys():
                    # flag truncation: more atoms of species s than one row holds
                    overflow = overflow | (jnp.sum(species == s) > max_size)
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                    )
        else:
            species_index = {}
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    species == s, size=c, fill_value=nat
                )[0]

            mask = species <= 0
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Re-run `process` with state=None; raises if the index is absent."""
        return self.process(None, inputs)

Build an index that splits atomic arrays by species.

FPID: SPECIES_INDEXER

If species_order is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. If species_order is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

SpeciesIndexer( output_key: str = 'species_index', species_order: Optional[str] = None, add_atoms: int = 0, add_atoms_margin: int = 10)
output_key: str = 'species_index'

Key for the output dictionary.

species_order: Optional[str] = None

Comma separated list of species in the order they should be indexed.

add_atoms: int = 0

Additional atoms to add to the sizes.

add_atoms_margin: int = 10

Additional atoms to add to the sizes when adding margin.

FPID: ClassVar[str] = 'SPECIES_INDEXER'
def init(self):
1413    def init(self):
1414        return FrozenDict(
1415            {
1416                "sizes": {},
1417            }
1418        )
def check_reallocate(self, state, inputs, parent_overflow=False):
1482    def check_reallocate(self, state, inputs, parent_overflow=False):
1483        """check for overflow and reallocate nblist if necessary"""
1484        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1485        if not overflow:
1486            return state, {}, inputs, False
1487
1488        add_margin = inputs[self.output_key + "_overflow"]
1489        state, inputs, state_up = self(
1490            state, inputs, return_state_update=True, add_margin=add_margin
1491        )
1492        return state, state_up, inputs, True
1493        # return state, {}, inputs, parent_overflow

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1495    @partial(jax.jit, static_argnums=(0, 1))
1496    def process(self, state, inputs):
1497        # assert (
1498        #     self.output_key in inputs
1499        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1500
1501        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1502        if self.output_key in inputs and not recompute_species_index:
1503            return inputs
1504
1505        if state is None:
1506            raise ValueError("Species Indexer state must be provided on accelerator.")
1507
1508        species = inputs["species"]
1509        nat = species.shape[0]
1510
1511        sizes = state["sizes"]
1512
1513        if self.species_order is not None:
1514            species_order = [el.strip() for el in self.species_order.split(",")]
1515            max_size = state["max_size"]
1516
1517            species_index = jnp.full(
1518                (len(species_order), max_size), nat, dtype=jnp.int32
1519            )
1520            for i, el in enumerate(species_order):
1521                s = PERIODIC_TABLE_REV_IDX[el]
1522                if s in sizes.keys():
1523                    c = sizes[s]
1524                    species_index = species_index.at[i, :].set(
1525                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
1526                    )
1527                # if s in counts_dict.keys():
1528                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1529        else:
1530            # species_index = {
1531            # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1532            # for s, c in sizes.items()
1533            # }
1534            species_index = {}
1535            overflow = False
1536            natcount = 0
1537            for s, c in sizes.items():
1538                mask = species == s
1539                new_size = jnp.sum(mask)
1540                natcount = natcount + new_size
1541                overflow = overflow | (new_size > c)  # check if sizes are correct
1542                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
1543                    species == s, size=c, fill_value=nat
1544                )[0]
1545
1546            mask = species <= 0
1547            new_size = jnp.sum(mask)
1548            natcount = natcount + new_size
1549            overflow = overflow | (
1550                natcount < species.shape[0]
1551            )  # check if any species missing
1552
1553        return {
1554            **inputs,
1555            self.output_key: species_index,
1556            self.output_key + "_overflow": overflow,
1557        }
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1559    @partial(jax.jit, static_argnums=(0,))
1560    def update_skin(self, inputs):
1561        return self.process(None, inputs)
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output (under `output_key`) is a dictionary mapping chemical-block
    names to an index that filters atomic arrays down to the atoms of that
    block. Blocks with no preallocated storage map to None.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """If True, C, N, O, P, S and Se each get their own dedicated block."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        """Return (block names, species->block mapping), optionally splitting CNOPSSe."""
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS_NAMES.extend(["N", "O", "P", "S", "Se"])
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            # carbon reuses block slot 1; N/O/P/S/Se were appended starting at
            # index len(CHEMICAL_BLOCKS_NAMES) of the extended name list
            _CHEMICAL_BLOCKS[6] = 1
            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES) + 1
            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES) + 2
            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES) + 3
            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES) + 4
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """CPU (numpy) routine: size the per-block storage and build the index."""
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue  # negative block = padding / unassigned species
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        # allocate, then fill the first c slots of each block; the rest stay
        # at `nat` which points at the padding atom
        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_, n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and reallocate the block index if necessary."""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # an overflow of our own output triggers extra margin on reallocation
        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Accelerator (jax) routine: rebuild the index with fixed sizes, flag overflow."""
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s, name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=c, fill_value=nat)[0]

        # padding atoms (block < 0) also count toward the total
        natcount = natcount + jnp.sum(blocks < 0)
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Skin update: species are unchanged, so just rebuild/validate the index."""
        return self.process(None, inputs)

Build an index that splits atomic arrays by chemical blocks.

FPID: BLOCK_INDEXER

The output is a dictionary with chemical-block names as keys and, for each block, an index that filters atomic arrays down to the atoms of that block.

BlockIndexer( output_key: str = 'block_index', add_atoms: int = 0, add_atoms_margin: int = 10, split_CNOPSSe: bool = False)
output_key: str = 'block_index'

Key for the output dictionary.

add_atoms: int = 0

Additional atoms to add to the sizes.

add_atoms_margin: int = 10

Additional atoms to add to the sizes when adding margin.

split_CNOPSSe: bool = False
FPID: ClassVar[str] = 'BLOCK_INDEXER'
def init(self):
1584    def init(self):
1585        return FrozenDict(
1586            {
1587                "sizes": {},
1588            }
1589        )
def build_chemical_blocks(self):
1591    def build_chemical_blocks(self):
1592        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
1593        if self.split_CNOPSSe:
1594            _CHEMICAL_BLOCKS_NAMES[1] = "C"
1595            _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"])
1596        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
1597        if self.split_CNOPSSe:
1598            _CHEMICAL_BLOCKS[6] = 1
1599            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
1600            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1
1601            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2
1602            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3
1603            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4
1604        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS
def check_reallocate(self, state, inputs, parent_overflow=False):
1658    def check_reallocate(self, state, inputs, parent_overflow=False):
1659        """check for overflow and reallocate nblist if necessary"""
1660        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1661        if not overflow:
1662            return state, {}, inputs, False
1663
1664        add_margin = inputs[self.output_key + "_overflow"]
1665        state, inputs, state_up = self(
1666            state, inputs, return_state_update=True, add_margin=add_margin
1667        )
1668        return state, state_up, inputs, True
1669        # return state, {}, inputs, parent_overflow

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1671    @partial(jax.jit, static_argnums=(0, 1))
1672    def process(self, state, inputs):
1673        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1674        # assert (
1675        #     self.output_key in inputs
1676        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1677
1678        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1679        if self.output_key in inputs and not recompute_species_index:
1680            return inputs
1681
1682        if state is None:
1683            raise ValueError("Block Indexer state must be provided on accelerator.")
1684
1685        species = inputs["species"]
1686        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
1687        nat = species.shape[0]
1688
1689        sizes = state["sizes"]
1690
1691        # species_index = {
1692        # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1693        # for s, c in sizes.items()
1694        # }
1695        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
1696        overflow = False
1697        natcount = 0
1698        for (s,name), c in sizes.items():
1699            mask = blocks == s
1700            new_size = jnp.sum(mask)
1701            natcount = natcount + new_size
1702            overflow = overflow | (new_size > c)  # check if sizes are correct
1703            block_index[name] = jnp.nonzero(
1704                mask, size=c, fill_value=nat
1705            )[0]
1706
1707        mask = blocks < 0
1708        new_size = jnp.sum(mask)
1709        natcount = natcount + new_size
1710        overflow = overflow | (
1711            natcount < species.shape[0]
1712        )  # check if any species missing
1713
1714        return {
1715            **inputs,
1716            self.output_key: block_index,
1717            self.output_key + "_overflow": overflow,
1718        }
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1720    @partial(jax.jit, static_argnums=(0,))
1721    def update_skin(self, inputs):
1722        return self.process(None, inputs)
@dataclasses.dataclass(frozen=True)
class AtomPadding:
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size."""

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Number of extra systems to preallocate."""

    def init(self):
        """Initial state: no capacity allocated yet."""
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]:
        """Pad per-atom and per-system arrays so their shapes change rarely."""
        species = inputs["species"]
        nat = species.shape[0]
        nsys = len(inputs["natoms"])

        # target capacities: grow multiplicatively when the current ones overflow
        nat_cap = state.get("prev_nat", 0)
        if nat > nat_cap:
            nat_cap = int(self.mult_size * nat) + 1
        nsys_cap = state.get("prev_nsys", 0)
        if nsys > nsys_cap:
            nsys_cap = nsys + self.add_sys

        add_atoms = nat_cap - nat
        add_sys = nsys_cap - nsys + 1  # one extra system hosts the padding atoms
        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if not isinstance(v, (np.ndarray, jax.Array)):
                    continue
                if v.shape[0] == nat:
                    # per-atom array: pad with zeros
                    pad = np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype)
                    output[k] = np.append(v, pad, axis=0)
                elif v.shape[0] == nsys:
                    if k == "cells":
                        # padded systems get a large dummy box
                        pad = 1000 * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                            add_sys, axis=0
                        )
                    else:
                        pad = np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype)
                    output[k] = np.append(v, pad, axis=0)
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            # padded atoms are flagged with species -1 ...
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            # ... and all assigned to the last (fake) system
            last_sys = output["natoms"].shape[0] - 1
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array([last_sys] * add_atoms, dtype=inputs["batch_index"].dtype),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array([last_sys] * add_sys, dtype=inputs["system_index"].dtype),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        new_state = {**state, "prev_nat": nat_cap, "prev_nsys": nsys_cap}
        return FrozenDict(new_state), output

Pad atomic arrays to a fixed size.

AtomPadding(mult_size: float = 1.2, add_sys: int = 0)
mult_size: float = 1.2

Multiplicative factor for resizing the atomic arrays.

add_sys: int = 0
def init(self):
1733    def init(self):
1734        return {"prev_nat": 0, "prev_nsys": 0}
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding from atomic arrays."""
    if "true_atoms" not in inputs:
        # arrays were never padded: nothing to strip
        return inputs

    species = np.asarray(inputs["species"])
    true_atoms = np.asarray(inputs["true_atoms"])
    true_sys = np.asarray(inputs["true_sys"])
    natall = species.shape[0]
    # index of the first padding atom (0 means none found)
    first_pad = np.argmax(species <= 0)
    if first_pad == 0:
        return inputs

    nsysall = len(inputs["natoms"])

    output = {**inputs}
    for key, value in inputs.items():
        if not isinstance(value, (jax.Array, np.ndarray)):
            continue
        value = np.asarray(value)
        if value.ndim == 0:
            output[key] = value
        elif value.shape[0] == natall:
            output[key] = value[true_atoms]  # per-atom array
        elif value.shape[0] == nsysall:
            output[key] = value[true_sys]  # per-system array
    del output["true_sys"]
    del output["true_atoms"]
    return output

Remove padding from atomic arrays.

def check_input(inputs):
def check_input(inputs):
    """Check the input dictionary for required keys and types."""
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"
    species = inputs["species"].astype(np.int32)
    ifake = np.argmax(species <= 0)
    if ifake > 0:
        # all real atoms must come before the padding section
        assert np.all(species[:ifake] > 0), "species must be positive"
    nat = inputs["species"].shape[0]

    # defaults: a single system containing every atom
    natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32)
    batch_index = inputs.get(
        "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    ).astype(np.int32)
    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
    if "cells" in inputs:
        cells = inputs["cells"]
        reciprocal_cells = (
            inputs["reciprocal_cells"]
            if "reciprocal_cells" in inputs
            else np.linalg.inv(cells)
        )
        # promote single-system (3,3) cells to batched (1,3,3) form
        if cells.ndim == 2:
            cells = cells[None, :, :]
        if reciprocal_cells.ndim == 2:
            reciprocal_cells = reciprocal_cells[None, :, :]
        output["cells"] = cells
        output["reciprocal_cells"] = reciprocal_cells

    return output

Check the input dictionary for required keys and types.

def convert_to_jax(data):
def convert_to_jax(data):
    """Convert numpy arrays to jax arrays in a pytree."""

    def _to_jax(leaf):
        return jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf

    return jax.tree_util.tree_map(_to_jax, data)

Convert numpy arrays to jax arrays in a pytree.

class JaxConverter(flax.linen.module.Module):
class JaxConverter(nn.Module):
    """Convert numpy arrays to jax arrays in a pytree."""

    def __call__(self, data):
        """Apply `convert_to_jax` to the given pytree."""
        return convert_to_jax(data)

Convert numpy arrays to jax arrays in a pytree.

JaxConverter( parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
def convert_to_numpy(data):
def convert_to_numpy(data):
    """Convert jax arrays to numpy arrays in a pytree."""

    def _to_numpy(leaf):
        return np.array(leaf) if isinstance(leaf, jax.Array) else leaf

    return jax.tree_util.tree_map(_to_numpy, data)

Convert jax arrays to numpy arrays in a pytree.

@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers."""

    layers: Tuple[Callable[..., Dict[str, Any]]]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            # plain string: nothing to interpolate (was a placeholder-less f-string)
            raise ValueError("Error: no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the full (CPU/numpy) preprocessing chain and convert the result to jax."""
        do_check_input = state.get("check_input", True)
        if do_check_input:
            inputs = check_input(inputs)
        new_state = {**state}
        if self.use_atom_padding:
            s, inputs = self.atom_padder(state["padder_state"], inputs)
            new_state["padder_state"] = s
        layer_state = state["layers_state"]
        new_layer_state = []
        for i, layer in enumerate(self.layers):
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_layer_state.append(s)
        new_state["layers_state"] = tuple(new_layer_state)
        return FrozenDict(new_state), convert_to_jax(inputs)

    def check_reallocate(self, state, inputs):
        """Check every layer for overflow; reallocate and report state updates."""
        new_state = []
        state_up = []
        layer_state = state["layers_state"]
        parent_overflow = False
        for i, layer in enumerate(self.layers):
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        """Apply atom padding only (no other preprocessing)."""
        if self.use_atom_padding:
            padder_state, inputs = self.atom_padder(state["padder_state"], inputs)
            return FrozenDict({**state, "padder_state": padder_state}), inputs
        return state, inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Run the accelerator-side (jit-compiled) part of every layer."""
        layer_state = state["layers_state"]
        for i, layer in enumerate(self.layers):
            inputs = layer.process(layer_state[i], inputs)
        return inputs

    # NOTE: static_argnums was `(0)` (an int, not a tuple); normalized to `(0,)`
    # for consistency with the other jitted methods.
    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Fast neighbor-list update reusing the skin graphs of every layer."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Initialize the chain state (padder state plus one state per layer)."""
        state = {"check_input": True}
        if self.use_atom_padding:
            state["padder_state"] = self.atom_padder.init()
        state["layers_state"] = tuple(layer.init() for layer in self.layers)
        return FrozenDict(state)

    def init_with_output(self, inputs):
        """Initialize the state and immediately run the chain on `inputs`."""
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        """Collect the graph processors exported by the layers."""
        return [
            layer.get_processor()
            for layer in self.layers
            if hasattr(layer, "get_processor")
        ]

    def get_graphs_properties(self):
        """Merge the static graph properties declared by the layers."""
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties

Chain of preprocessing layers.

PreprocessingChain( layers: Tuple[Callable[..., Dict[str, Any]]], use_atom_padding: bool = False, atom_padder: AtomPadding = AtomPadding(mult_size=1.2, add_sys=0))
layers: Tuple[Callable[..., Dict[str, Any]]]

Preprocessing layers.

use_atom_padding: bool = False

Add an AtomPadding layer at the beginning of the chain.

atom_padder: AtomPadding = AtomPadding(mult_size=1.2, add_sys=0)

AtomPadding layer.

def check_reallocate(self, state, inputs):
1930    def check_reallocate(self, state, inputs):
1931        new_state = []
1932        state_up = []
1933        layer_state = state["layers_state"]
1934        parent_overflow = False
1935        for i,layer in enumerate(self.layers):
1936            s, s_up, inputs, parent_overflow = layer.check_reallocate(
1937                layer_state[i], inputs, parent_overflow
1938            )
1939            new_state.append(s)
1940            state_up.append(s_up)
1941
1942        if not parent_overflow:
1943            return state, {}, inputs, False
1944        return (
1945            FrozenDict({**state, "layers_state": tuple(new_state)}),
1946            state_up,
1947            inputs,
1948            True,
1949        )
def atom_padding(self, state, inputs):
1951    def atom_padding(self, state, inputs):
1952        if self.use_atom_padding:
1953            padder_state,inputs = self.atom_padder(state["padder_state"], inputs)
1954            return FrozenDict({**state,"padder_state": padder_state}), inputs
1955        return state, inputs
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1957    @partial(jax.jit, static_argnums=(0, 1))
1958    def process(self, state, inputs):
1959        layer_state = state["layers_state"]
1960        for i,layer in enumerate(self.layers):
1961            inputs = layer.process(layer_state[i], inputs)
1962        return inputs
@partial(jax.jit, static_argnums=0)
def update_skin(self, inputs):
1964    @partial(jax.jit, static_argnums=(0))
1965    def update_skin(self, inputs):
1966        for layer in self.layers:
1967            inputs = layer.update_skin(inputs)
1968        return inputs
def init(self):
1970    def init(self):
1971        state = {"check_input": True}
1972        if self.use_atom_padding:
1973            state["padder_state"] = self.atom_padder.init()
1974        layer_state = []
1975        for layer in self.layers:
1976            layer_state.append(layer.init())
1977        state["layers_state"] = tuple(layer_state)
1978        return FrozenDict(state)
def init_with_output(self, inputs):
1980    def init_with_output(self, inputs):
1981        state = self.init()
1982        return self(state, inputs)
def get_processors(self):
1984    def get_processors(self):
1985        processors = []
1986        for layer in self.layers:
1987            if hasattr(layer, "get_processor"):
1988                processors.append(layer.get_processor())
1989        return processors
def get_graphs_properties(self):
1991    def get_graphs_properties(self):
1992        properties = {}
1993        for layer in self.layers:
1994            if hasattr(layer, "get_graph_properties"):
1995                properties = deep_update(properties, layer.get_graph_properties())
1996        return properties