fennol.models.preprocessing
1import flax.linen as nn 2from typing import Sequence, Callable, Union, Dict, Any, ClassVar 3import jax.numpy as jnp 4import jax 5import numpy as np 6from typing import Optional, Tuple 7import numba 8import dataclasses 9from functools import partial 10 11from flax.core.frozen_dict import FrozenDict 12 13 14from ..utils.activations import chain,safe_sqrt 15from ..utils import deep_update, mask_filter_1d 16from ..utils.kspace import get_reciprocal_space_parameters 17from .misc.misc import SwitchFunction 18from ..utils.periodic_table import PERIODIC_TABLE, PERIODIC_TABLE_REV_IDX,CHEMICAL_BLOCKS,CHEMICAL_BLOCKS_NAMES 19 20 21@dataclasses.dataclass(frozen=True) 22class GraphGenerator: 23 """Generate a graph from a set of coordinates 24 25 FPID: GRAPH 26 27 For now, we generate all pairs of atoms and filter based on cutoff. 28 If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist. 29 """ 30 31 cutoff: float 32 """Cutoff distance for the graph.""" 33 graph_key: str = "graph" 34 """Key of the graph in the outputs.""" 35 switch_params: dict = dataclasses.field(default_factory=dict, hash=False) 36 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 37 kmax: int = 30 38 """Maximum number of k-points to consider.""" 39 kthr: float = 1e-6 40 """Threshold for k-point filtering.""" 41 k_space: bool = False 42 """Whether to generate k-space information for the graph.""" 43 mult_size: float = 1.05 44 """Multiplicative factor for resizing the nblist.""" 45 # covalent_cutoff: bool = False 46 47 FPID: ClassVar[str] = "GRAPH" 48 49 def init(self): 50 return FrozenDict( 51 { 52 "max_nat": 1, 53 "npairs": 1, 54 "nblist_mult_size": self.mult_size, 55 } 56 ) 57 58 def get_processor(self) -> Tuple[nn.Module, Dict]: 59 return GraphProcessor, { 60 "cutoff": self.cutoff, 61 "graph_key": self.graph_key, 62 "switch_params": self.switch_params, 63 "name": f"{self.graph_key}_Processor", 64 } 65 66 def get_graph_properties(self): 67 return { 68 self.graph_key: { 69 "cutoff": self.cutoff, 70 "directed": True, 71 } 72 } 73 74 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 75 """build a nblist on cpu with numpy and dynamic shapes + store max shapes""" 76 if self.graph_key in inputs: 77 graph = inputs[self.graph_key] 78 if "keep_graph" in graph: 79 return state, inputs 80 81 coords = np.array(inputs["coordinates"], dtype=np.float32) 82 natoms = np.array(inputs["natoms"], dtype=np.int32) 83 batch_index = np.array(inputs["batch_index"], dtype=np.int32) 84 85 new_state = {**state} 86 state_up = {} 87 88 mult_size = state.get("nblist_mult_size", self.mult_size) 89 assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0" 90 91 if natoms.shape[0] == 1: 92 max_nat = coords.shape[0] 93 true_max_nat = max_nat 94 else: 95 max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0])) 96 true_max_nat = int(np.max(natoms)) 97 if true_max_nat > max_nat: 98 add_atoms = state.get("add_atoms", 0) 99 new_maxnat = true_max_nat + add_atoms 100 state_up["max_nat"] = (new_maxnat, max_nat) 101 new_state["max_nat"] = new_maxnat 102 103 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 104 105 ### compute indices of all pairs 106 apply_pbc = "cells" in inputs 107 minimage = "minimum_image" in inputs.get("flags", {}) 108 include_self_image = apply_pbc and not minimage 109 110 shift = 0 if include_self_image else 1 111 p1, p2 = np.triu_indices(true_max_nat, shift) 112 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 113 pbc_shifts = None 114 if natoms.shape[0] > 1: 115 ## batching => mask irrelevant pairs 116 mask_p12 = ( 117 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 118 ).flatten() 119 shift = np.concatenate( 120 (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32)) 121 ) 122 p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 123 p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 124 125 if not apply_pbc: 126 ### NO PBC 127 vec = coords[p2] - coords[p1] 128 else: 129 cells = np.array(inputs["cells"], dtype=np.float32) 130 reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32) 131 if minimage: 132 ## MINIMUM IMAGE CONVENTION 133 vec = coords[p2] - coords[p1] 134 if cells.shape[0] == 1: 135 vecpbc = np.dot(vec, reciprocal_cells[0]) 136 pbc_shifts = -np.round(vecpbc).astype(np.int32) 137 vec = vec + np.dot(pbc_shifts, cells[0]) 138 else: 139 batch_index_vec = batch_index[p1] 140 vecpbc = np.einsum( 141 "aj,aji->ai", vec, reciprocal_cells[batch_index_vec] 142 ) 143 pbc_shifts = -np.round(vecpbc).astype(np.int32) 144 vec = vec + np.einsum( 145 "aj,aji->ai", pbc_shifts, cells[batch_index_vec] 146 ) 147 else: 148 ### GENERAL PBC 149 ## put all atoms in central box 150 if cells.shape[0] == 1: 151 coords_pbc = np.dot(coords, reciprocal_cells[0]) 152 at_shifts = -np.floor(coords_pbc).astype(np.int32) 153 coords_pbc = coords + np.dot(at_shifts, cells[0]) 154 else: 155 coords_pbc = np.einsum( 156 "aj,aji->ai", coords, reciprocal_cells[batch_index] 157 ) 158 at_shifts = -np.floor(coords_pbc).astype(np.int32) 159 coords_pbc = coords + np.einsum( 160 "aj,aji->ai", at_shifts, cells[batch_index] 161 ) 162 vec = coords_pbc[p2] - coords_pbc[p1] 163 164 ## compute maximum number of repeats 165 inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5 166 cdinv = cutoff_skin * inv_distances 167 num_repeats_all = np.ceil(cdinv).astype(np.int32) 168 if "true_sys" in inputs: 169 num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0) 170 # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all) 171 num_repeats = np.max(num_repeats_all, axis=0) 172 num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0))) 173 if np.any(num_repeats > num_repeats_prev): 174 num_repeats_new = np.maximum(num_repeats, num_repeats_prev) 175 state_up["num_repeats_pbc"] = ( 176 tuple(num_repeats_new), 177 tuple(num_repeats_prev), 178 ) 179 new_state["num_repeats_pbc"] = tuple(num_repeats_new) 180 ## build all possible shifts 181 cell_shift_pbc = np.array( 182 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 183 dtype=np.int32, 184 ).T.reshape(-1, 3) 185 ## shift applied to vectors 186 if cells.shape[0] == 1: 187 dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :] 188 vec = (vec[:, None, :] + dvec).reshape(-1, 3) 189 pbc_shifts = np.broadcast_to( 190 cell_shift_pbc[None, :, :], 191 (p1.shape[0], cell_shift_pbc.shape[0], 3), 192 ).reshape(-1, 3) 193 p1 = np.broadcast_to( 194 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 195 ).flatten() 196 p2 = np.broadcast_to( 197 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 198 ).flatten() 199 if natoms.shape[0] > 1: 200 mask_p12 = np.broadcast_to( 201 mask_p12[:, None], 202 (mask_p12.shape[0], cell_shift_pbc.shape[0]), 203 ).flatten() 204 else: 205 dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells) 206 207 ## get pbc shifts specific to each box 208 cell_shift_pbc = np.broadcast_to( 209 cell_shift_pbc[None, :, :], 210 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 211 ) 212 mask = np.all( 213 np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 214 ).flatten() 215 idx = np.nonzero(mask)[0] 216 nshifts = idx.shape[0] 217 nshifts_prev = state.get("nshifts_pbc", 0) 218 if nshifts > nshifts_prev or add_margin: 219 nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1 220 state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev) 221 new_state["nshifts_pbc"] = nshifts_new 222 223 dvec_filter = dvec.reshape(-1, 3)[idx, :] 224 cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :] 225 226 ## get batch shift in the dvec_filter array 227 nrep = np.prod(2 * num_repeats_all + 1, axis=1) 228 bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1])) 229 230 ## compute vectors 231 batch_index_vec = batch_index[p1] 232 nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0) 233 vec = vec.repeat(nrep_vec, axis=0) 234 nvec_pbc = nrep_vec.sum() #vec.shape[0] 235 nvec_pbc_prev = state.get("nvec_pbc", 0) 236 if nvec_pbc > nvec_pbc_prev or add_margin: 237 nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1 238 state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev) 239 new_state["nvec_pbc"] = nvec_pbc_new 240 241 # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev) 242 ## get shift index 243 dshift = np.concatenate( 244 (np.array([0]), np.cumsum(nrep_vec)[:-1]) 245 ).repeat(nrep_vec) 246 # ishift = np.arange(dshift.shape[0])-dshift 247 # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec) 248 icellshift = ( 249 np.arange(dshift.shape[0]) 250 - dshift 251 + bshift[batch_index_vec].repeat(nrep_vec) 252 ) 253 # shift vectors 254 vec = vec + dvec_filter[icellshift] 255 pbc_shifts = cell_shift_pbc_filter[icellshift] 256 257 p1 = np.repeat(p1, nrep_vec) 258 p2 = np.repeat(p2, nrep_vec) 259 if natoms.shape[0] > 1: 260 mask_p12 = np.repeat(mask_p12, nrep_vec) 261 262 ## compute distances 263 d12 = (vec**2).sum(axis=-1) 264 if natoms.shape[0] > 1: 265 d12 = np.where(mask_p12, d12, cutoff_skin**2) 266 267 ## filter pairs 268 max_pairs = state.get("npairs", 1) 269 mask = d12 < cutoff_skin**2 270 if include_self_image: 271 mask_self = np.logical_or(p1 != p2, d12 > 1.e-3) 272 mask = np.logical_and(mask, mask_self) 273 idx = np.nonzero(mask)[0] 274 npairs = idx.shape[0] 275 if npairs > max_pairs or add_margin: 276 prev_max_pairs = max_pairs 277 max_pairs = int(mult_size * max(npairs, max_pairs)) + 1 278 state_up["npairs"] = (max_pairs, prev_max_pairs) 279 new_state["npairs"] = max_pairs 280 281 nat = coords.shape[0] 282 edge_src = np.full(max_pairs, nat, dtype=np.int32) 283 edge_dst = np.full(max_pairs, nat, dtype=np.int32) 284 d12_ = np.full(max_pairs, cutoff_skin**2) 285 edge_src[:npairs] = p1[idx] 286 edge_dst[:npairs] = p2[idx] 287 d12_[:npairs] = d12[idx] 288 d12 = d12_ 289 290 if apply_pbc: 291 pbc_shifts_ = np.zeros((max_pairs, 3), dtype=np.int32) 292 pbc_shifts_[:npairs] = pbc_shifts[idx] 293 pbc_shifts = pbc_shifts_ 294 if not minimage: 295 pbc_shifts[:npairs] = ( 296 pbc_shifts[:npairs] 297 + at_shifts[edge_dst[:npairs]] 298 - at_shifts[edge_src[:npairs]] 299 ) 300 301 302 ## symmetrize 303 if include_self_image: 304 mask_noself = edge_src != edge_dst 305 idx_noself = np.nonzero(mask_noself)[0] 306 npairs_noself = idx_noself.shape[0] 307 max_noself = state.get("npairs_noself", 1) 308 if npairs_noself > max_noself or add_margin: 309 prev_max_noself = max_noself 310 max_noself = int(mult_size * max(npairs_noself, max_noself)) + 1 311 state_up["npairs_noself"] = (max_noself, prev_max_noself) 312 new_state["npairs_noself"] = max_noself 313 314 edge_src_noself = np.full(max_noself, nat, dtype=np.int32) 315 edge_dst_noself = np.full(max_noself, nat, dtype=np.int32) 316 d12_noself = np.full(max_noself, cutoff_skin**2) 317 pbc_shifts_noself = np.zeros((max_noself, 3), dtype=np.int32) 318 319 edge_dst_noself[:npairs_noself] = edge_dst[idx_noself] 320 edge_src_noself[:npairs_noself] = edge_src[idx_noself] 321 d12_noself[:npairs_noself] = d12[idx_noself] 322 pbc_shifts_noself[:npairs_noself] = pbc_shifts[idx_noself] 323 edge_src = np.concatenate((edge_src, edge_dst_noself)) 324 edge_dst = np.concatenate((edge_dst, edge_src_noself)) 325 d12 = np.concatenate((d12, d12_noself)) 326 pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts_noself)) 327 else: 328 edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate( 329 (edge_dst, edge_src) 330 ) 331 d12 = np.concatenate((d12, d12)) 332 if apply_pbc: 333 pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts)) 334 335 if "nblist_skin" in state: 336 edge_src_skin = edge_src 337 edge_dst_skin = edge_dst 338 if apply_pbc: 339 pbc_shifts_skin = pbc_shifts 340 max_pairs_skin = state.get("npairs_skin", 1) 341 mask = d12 < self.cutoff**2 342 idx = np.nonzero(mask)[0] 343 npairs_skin = idx.shape[0] 344 if npairs_skin > max_pairs_skin or add_margin: 345 prev_max_pairs_skin = max_pairs_skin 346 max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1 347 state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin) 348 new_state["npairs_skin"] = max_pairs_skin 349 edge_src = np.full(max_pairs_skin, nat, dtype=np.int32) 350 edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32) 351 d12_ = np.full(max_pairs_skin, self.cutoff**2) 352 edge_src[:npairs_skin] = edge_src_skin[idx] 353 edge_dst[:npairs_skin] = edge_dst_skin[idx] 354 d12_[:npairs_skin] = d12[idx] 355 d12 = d12_ 356 if apply_pbc: 357 pbc_shifts = np.zeros((max_pairs_skin, 3), dtype=np.int32) 358 pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx] 359 360 graph = inputs.get(self.graph_key, {}) 361 graph_out = { 362 **graph, 363 "edge_src": edge_src, 364 "edge_dst": edge_dst, 365 "d12": d12, 366 "overflow": False, 367 "pbc_shifts": pbc_shifts, 368 } 369 if "nblist_skin" in state: 370 graph_out["edge_src_skin"] = edge_src_skin 371 graph_out["edge_dst_skin"] = edge_dst_skin 372 if apply_pbc: 373 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 374 375 if self.k_space and apply_pbc: 376 if "k_points" not in graph: 377 ks, _, _, bewald = get_reciprocal_space_parameters( 378 reciprocal_cells, self.cutoff, self.kmax, self.kthr 379 ) 380 graph_out["k_points"] = ks 381 graph_out["b_ewald"] = bewald 382 383 output = {**inputs, self.graph_key: graph_out} 384 385 if return_state_update: 386 return FrozenDict(new_state), output, state_up 387 return FrozenDict(new_state), output 388 389 def check_reallocate(self, state, inputs, parent_overflow=False): 390 """check for overflow and reallocate nblist if necessary""" 391 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 392 if not overflow: 393 return state, {}, inputs, False 394 395 add_margin = inputs[self.graph_key].get("overflow", False) 396 state, inputs, state_up = self( 397 state, inputs, return_state_update=True, add_margin=add_margin 398 ) 399 return state, state_up, inputs, True 400 401 @partial(jax.jit, static_argnums=(0, 1)) 402 def process(self, state, inputs): 403 """build a nblist on accelerator with jax and precomputed shapes""" 404 if self.graph_key in inputs: 405 graph = inputs[self.graph_key] 406 if "keep_graph" in graph: 407 return inputs 408 coords = inputs["coordinates"] 409 natoms = inputs["natoms"] 410 batch_index = inputs["batch_index"] 411 412 if natoms.shape[0] == 1: 413 max_nat = coords.shape[0] 414 else: 415 max_nat = state.get( 416 "max_nat", int(round(coords.shape[0] / natoms.shape[0])) 417 ) 418 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 419 420 ### compute indices of all pairs 421 apply_pbc = "cells" in inputs 422 minimage = "minimum_image" in inputs.get("flags", {}) 423 include_self_image = apply_pbc and not minimage 424 425 shift = 0 if include_self_image else 1 426 p1, p2 = np.triu_indices(max_nat, shift) 427 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 428 pbc_shifts = None 429 if natoms.shape[0] > 1: 430 ## batching => mask irrelevant pairs 431 mask_p12 = ( 432 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 433 ).flatten() 434 shift = jnp.concatenate( 435 (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1])) 436 ) 437 p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 438 p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 439 440 ## compute vectors 441 overflow_repeats = jnp.asarray(False, dtype=bool) 442 if "cells" not in inputs: 443 vec = coords[p2] - coords[p1] 444 else: 445 cells = inputs["cells"] 446 reciprocal_cells = inputs["reciprocal_cells"] 447 # minimage = state.get("minimum_image", True) 448 minimage = "minimum_image" in inputs.get("flags", {}) 449 450 def compute_pbc(vec, reciprocal_cell, cell, mode="round"): 451 vecpbc = jnp.dot(vec, reciprocal_cell) 452 if mode == "round": 453 pbc_shifts = -jnp.round(vecpbc).astype(jnp.int32) 454 elif mode == "floor": 455 pbc_shifts = -jnp.floor(vecpbc).astype(jnp.int32) 456 else: 457 raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.") 458 return vec + jnp.dot(pbc_shifts, cell), pbc_shifts 459 460 if minimage: 461 ## minimum image convention 462 vec = coords[p2] - coords[p1] 463 464 if cells.shape[0] == 1: 465 vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0]) 466 else: 467 batch_index_vec = batch_index[p1] 468 vec, pbc_shifts = jax.vmap(compute_pbc)( 469 vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec] 470 ) 471 else: 472 ### general PBC only for single cell yet 473 # if cells.shape[0] > 1: 474 # raise NotImplementedError( 475 # "General PBC not implemented for batches on accelerator." 476 # ) 477 # cell = cells[0] 478 # reciprocal_cell = reciprocal_cells[0] 479 480 ## put all atoms in central box 481 if cells.shape[0] == 1: 482 coords_pbc, at_shifts = compute_pbc( 483 coords, reciprocal_cells[0], cells[0], mode="floor" 484 ) 485 else: 486 coords_pbc, at_shifts = jax.vmap( 487 partial(compute_pbc, mode="floor") 488 )(coords, reciprocal_cells[batch_index], cells[batch_index]) 489 vec = coords_pbc[p2] - coords_pbc[p1] 490 num_repeats = state.get("num_repeats_pbc", (0, 0, 0)) 491 # if num_repeats is None: 492 # raise ValueError( 493 # "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first." 494 # ) 495 # check if num_repeats is larger than previous 496 inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1) 497 cdinv = cutoff_skin * inv_distances 498 num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32) 499 if "true_sys" in inputs: 500 num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0) 501 num_repeats_new = jnp.max(num_repeats_all, axis=0) 502 overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats)) 503 504 cell_shift_pbc = jnp.asarray( 505 np.array( 506 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 507 dtype=np.int32, 508 ).T.reshape(-1, 3) 509 ) 510 511 if cells.shape[0] == 1: 512 vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3) 513 pbc_shifts = jnp.broadcast_to( 514 cell_shift_pbc[None, :, :], 515 (p1.shape[0], cell_shift_pbc.shape[0], 3), 516 ).reshape(-1, 3) 517 p1 = jnp.broadcast_to( 518 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 519 ).flatten() 520 p2 = jnp.broadcast_to( 521 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 522 ).flatten() 523 if natoms.shape[0] > 1: 524 mask_p12 = jnp.broadcast_to( 525 mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0]) 526 ).flatten() 527 else: 528 dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3) 529 530 ## get pbc shifts specific to each box 531 cell_shift_pbc = jnp.broadcast_to( 532 cell_shift_pbc[None, :, :], 533 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 534 ) 535 mask = jnp.all( 536 jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 537 ).flatten() 538 max_shifts = state.get("nshifts_pbc", 1) 539 540 cell_shift_pbc = cell_shift_pbc.reshape(-1,3) 541 shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2] 542 dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2] 543 (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d( 544 mask, 545 max_shifts, 546 (dvecx, 0.), 547 (dvecy, 0.), 548 (dvecz, 0.), 549 (shiftx, 0), 550 (shifty, 0), 551 (shiftz, 0), 552 ) 553 dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1) 554 cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1, dtype=jnp.int32) 555 overflow_repeats = overflow_repeats | (nshifts > max_shifts) 556 557 ## get batch shift in the dvec_filter array 558 nrep = jnp.prod(2 * num_repeats_all + 1, axis=1) 559 bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1])) 560 561 ## repeat vectors 562 nvec_max = state.get("nvec_pbc", 1) 563 batch_index_vec = batch_index[p1] 564 nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0) 565 nvec = nrep_vec.sum() 566 overflow_repeats = overflow_repeats | (nvec > nvec_max) 567 vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max) 568 # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts)) 569 570 ## get shift index 571 dshift = jnp.concatenate( 572 (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1]) 573 ) 574 if nrep_vec.size == 0: 575 dshift = jnp.array([],dtype=jnp.int32) 576 dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max) 577 bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max) 578 icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift 579 vec = vec + dvec[icellshift] 580 pbc_shifts = cell_shift_pbc[icellshift] 581 p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max) 582 p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max) 583 mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max) 584 585 mask_valid = jnp.arange(nvec_max) < nvec 586 mask_p12 = jnp.where(mask_valid, mask_p12, False) 587 588 589 ## compute distances 590 d12 = (vec**2).sum(axis=-1) 591 if natoms.shape[0] > 1: 592 d12 = jnp.where(mask_p12, d12, cutoff_skin**2) 593 594 ## filter pairs 595 max_pairs = state.get("npairs", 1) 596 mask = d12 < cutoff_skin**2 597 if include_self_image: 598 mask_self = jnp.logical_or(p1 != p2, d12 > 1.e-3) 599 mask = jnp.logical_and(mask, mask_self) 600 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 601 mask, 602 max_pairs, 603 (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]), 604 (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]), 605 (d12, cutoff_skin**2), 606 ) 607 if "cells" in inputs: 608 pbc_shifts = ( 609 jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype) 610 .at[scatter_idx] 611 .set(pbc_shifts, mode="drop") 612 ) 613 if not minimage: 614 pbc_shifts = ( 615 pbc_shifts 616 + at_shifts.at[edge_dst].get(fill_value=0) 617 - at_shifts.at[edge_src].get(fill_value=0) 618 ) 619 620 ## check for overflow 621 if natoms.shape[0] == 1: 622 true_max_nat = coords.shape[0] 623 else: 624 true_max_nat = jnp.max(natoms) 625 overflow_count = npairs > max_pairs 626 overflow_at = true_max_nat > max_nat 627 overflow = overflow_count | overflow_at | overflow_repeats 628 629 ## symmetrize 630 if include_self_image: 631 mask_noself = edge_src != edge_dst 632 max_noself = state.get("npairs_noself", 1) 633 (edge_src_noself, edge_dst_noself, d12_noself, pbc_shifts_noself), scatter_idx, npairs_noself = mask_filter_1d( 634 mask_noself, 635 max_noself, 636 (edge_src, coords.shape[0]), 637 (edge_dst, coords.shape[0]), 638 (d12, cutoff_skin**2), 639 (pbc_shifts, jnp.zeros((3,), dtype=pbc_shifts.dtype)), 640 ) 641 overflow = overflow | (npairs_noself > max_noself) 642 edge_src = jnp.concatenate((edge_src, edge_dst_noself)) 643 edge_dst = jnp.concatenate((edge_dst, edge_src_noself)) 644 d12 = jnp.concatenate((d12, d12_noself)) 645 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts_noself)) 646 else: 647 edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate( 648 (edge_dst, edge_src) 649 ) 650 d12 = jnp.concatenate((d12, d12)) 651 if "cells" in inputs: 652 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts)) 653 654 if "nblist_skin" in state: 655 # edge_mask_skin = edge_mask 656 edge_src_skin = edge_src 657 edge_dst_skin = edge_dst 658 if "cells" in inputs: 659 pbc_shifts_skin = pbc_shifts 660 max_pairs_skin = state.get("npairs_skin", 1) 661 mask = d12 < self.cutoff**2 662 (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d( 663 mask, 664 max_pairs_skin, 665 (edge_src, coords.shape[0]), 666 (edge_dst, coords.shape[0]), 667 (d12, self.cutoff**2), 668 ) 669 if "cells" in inputs: 670 pbc_shifts = ( 671 jnp.zeros((max_pairs_skin, 3), dtype=pbc_shifts.dtype) 672 .at[scatter_idx] 673 .set(pbc_shifts, mode="drop") 674 ) 675 overflow = overflow | (npairs_skin > max_pairs_skin) 676 677 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 678 graph_out = { 679 **graph, 680 "edge_src": edge_src, 681 "edge_dst": edge_dst, 682 "d12": d12, 683 "overflow": overflow, 684 "pbc_shifts": pbc_shifts, 685 } 686 if "nblist_skin" in state: 687 graph_out["edge_src_skin"] = edge_src_skin 688 graph_out["edge_dst_skin"] = edge_dst_skin 689 if "cells" in inputs: 690 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 691 692 if self.k_space and "cells" in inputs: 693 if "k_points" not in graph: 694 raise NotImplementedError( 695 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 696 ) 697 return {**inputs, self.graph_key: graph_out} 698 699 @partial(jax.jit, static_argnums=(0,)) 700 def update_skin(self, inputs): 701 """update the nblist without recomputing the full nblist""" 702 graph = inputs[self.graph_key] 703 704 edge_src_skin = graph["edge_src_skin"] 705 edge_dst_skin = graph["edge_dst_skin"] 706 coords = inputs["coordinates"] 707 vec = coords.at[edge_dst_skin].get( 708 mode="fill", fill_value=self.cutoff 709 ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0) 710 711 if "cells" in inputs: 712 pbc_shifts_skin = graph["pbc_shifts_skin"] 713 cells = inputs["cells"] 714 if cells.shape[0] == 1: 715 vec = vec + jnp.dot(pbc_shifts_skin, cells[0]) 716 else: 717 batch_index_vec = inputs["batch_index"][edge_src_skin] 718 vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec]) 719 720 nat = coords.shape[0] 721 d12 = jnp.sum(vec**2, axis=-1) 722 mask = d12 < self.cutoff**2 723 max_pairs = graph["edge_src"].shape[0] 724 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 725 mask, 726 max_pairs, 727 (edge_src_skin, nat), 728 (edge_dst_skin, nat), 729 (d12, self.cutoff**2), 730 ) 731 if "cells" in inputs: 732 pbc_shifts = ( 733 jnp.zeros((max_pairs, 3), dtype=pbc_shifts_skin.dtype) 734 .at[scatter_idx] 735 .set(pbc_shifts_skin) 736 ) 737 738 overflow = graph.get("overflow", False) | (npairs > max_pairs) 739 graph_out = { 740 **graph, 741 "edge_src": edge_src, 742 "edge_dst": edge_dst, 743 "d12": d12, 744 "overflow": overflow, 745 } 746 if "cells" in inputs: 747 graph_out["pbc_shifts"] = pbc_shifts 748 749 if self.k_space and "cells" in inputs: 750 if "k_points" not in graph: 751 raise NotImplementedError( 752 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 753 ) 754 755 return {**inputs, self.graph_key: graph_out} 756 757 758class GraphProcessor(nn.Module): 759 """Process a pre-generated graph 760 761 The pre-generated graph should contain the following keys: 762 - edge_src: source indices of the edges 763 - edge_dst: destination indices of the edges 764 - pbcs_shifts: pbc shifts for the edges (only if `cells` are present in the inputs) 765 766 This module is automatically added to a FENNIX model when a GraphGenerator is used. 767 768 """ 769 770 cutoff: float 771 """Cutoff distance for the graph.""" 772 graph_key: str = "graph" 773 """Key of the graph in the outputs.""" 774 switch_params: dict = dataclasses.field(default_factory=dict) 775 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 776 777 @nn.compact 778 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 779 graph = inputs[self.graph_key] 780 coords = inputs["coordinates"] 781 edge_src, edge_dst = graph["edge_src"], graph["edge_dst"] 782 # edge_mask = edge_src < coords.shape[0] 783 vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[ 784 edge_src 785 ].get(mode="fill", fill_value=0.0) 786 if "cells" in inputs: 787 cells = inputs["cells"] 788 if cells.shape[0] == 1: 789 vec = vec + jnp.dot(graph["pbc_shifts"], cells[0]) 790 else: 791 batch_index_vec = inputs["batch_index"][edge_src] 792 vec = vec + jax.vmap(jnp.dot)( 793 graph["pbc_shifts"], cells[batch_index_vec] 794 ) 795 796 d2 = jnp.sum(vec**2, axis=-1) 797 distances = safe_sqrt(d2) 798 edge_mask = distances < self.cutoff 799 800 switch = SwitchFunction( 801 **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None} 802 )((distances, edge_mask)) 803 804 graph_out = { 805 **graph, 806 "vec": vec, 807 "distances": distances, 808 "switch": switch, 809 "edge_mask": edge_mask, 810 } 811 812 if "alch_group" in inputs: 813 alch_group = inputs["alch_group"] 814 lambda_e = inputs["alch_elambda"] 815 lambda_v = inputs["alch_vlambda"] 816 mask = alch_group[edge_src] == alch_group[edge_dst] 817 graph_out["switch_raw"] = switch 818 graph_out["switch"] = jnp.where( 819 mask, 820 switch, 821 0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch , 822 ) 823 graph_out["distances_raw"] = distances 824 if "alch_softcore_e" in inputs: 825 alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2 826 else: 827 alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2 828 829 graph_out["distances"] = jnp.where( 830 mask, 831 distances, 832 safe_sqrt(alch_alpha + d2 * (1. - alch_alpha/self.cutoff**2)) 833 ) 834 835 836 return {**inputs, self.graph_key: graph_out} 837 838 839@dataclasses.dataclass(frozen=True) 840class GraphFilter: 841 """Filter a graph based on a cutoff distance 842 843 FPID: GRAPH_FILTER 844 """ 845 846 cutoff: float 847 """Cutoff distance for the filtering.""" 848 parent_graph: str 849 """Key of the parent graph in the inputs.""" 850 graph_key: str 851 """Key of the filtered graph in the outputs.""" 852 remove_hydrogens: int = False 853 """Remove edges where the source is a hydrogen atom.""" 854 switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict) 855 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 856 k_space: bool = False 857 """Generate k-space information for the graph.""" 858 kmax: int = 30 859 """Maximum number of k-points to consider.""" 860 kthr: float = 1e-6 861 """Threshold for k-point filtering.""" 862 mult_size: float = 1.05 863 """Multiplicative factor for resizing the nblist.""" 864 865 FPID: ClassVar[str] = "GRAPH_FILTER" 866 867 def init(self): 868 return FrozenDict( 869 { 870 "npairs": 1, 871 "nblist_mult_size": self.mult_size, 872 } 873 ) 874 875 def get_processor(self) -> Tuple[nn.Module, Dict]: 876 return GraphFilterProcessor, { 877 "cutoff": self.cutoff, 878 "graph_key": self.graph_key, 879 "parent_graph": self.parent_graph, 880 "name": f"{self.graph_key}_Filter_{self.parent_graph}", 881 "switch_params": self.switch_params, 882 } 883 884 def get_graph_properties(self): 885 return { 886 self.graph_key: { 887 "cutoff": self.cutoff, 888 "directed": True, 889 "parent_graph": self.parent_graph, 890 } 891 } 892 893 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 894 """filter a nblist on cpu with numpy and dynamic shapes + store max shapes""" 895 graph_in = inputs[self.parent_graph] 896 nat = inputs["species"].shape[0] 897 898 new_state = {**state} 899 state_up = {} 900 mult_size = state.get("nblist_mult_size", self.mult_size) 901 assert mult_size >= 1., "nblist_mult_size should be >= 1." 902 903 edge_src = np.array(graph_in["edge_src"], dtype=np.int32) 904 d12 = np.array(graph_in["d12"], dtype=np.float32) 905 if self.remove_hydrogens: 906 species = inputs["species"] 907 src_idx = (edge_src < nat).nonzero()[0] 908 mask = np.zeros(edge_src.shape[0], dtype=bool) 909 mask[src_idx] = (species > 1)[edge_src[src_idx]] 910 d12 = np.where(mask, d12, self.cutoff**2) 911 mask = d12 < self.cutoff**2 912 913 max_pairs = state.get("npairs", 1) 914 idx = np.nonzero(mask)[0] 915 npairs = idx.shape[0] 916 if npairs > max_pairs or add_margin: 917 prev_max_pairs = max_pairs 918 max_pairs = int(mult_size * max(npairs, max_pairs)) + 1 919 state_up["npairs"] = (max_pairs, prev_max_pairs) 920 new_state["npairs"] = max_pairs 921 922 filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32) 923 edge_src = np.full(max_pairs, nat, dtype=np.int32) 924 edge_dst = np.full(max_pairs, nat, dtype=np.int32) 925 d12_ = np.full(max_pairs, self.cutoff**2) 926 filter_indices[:npairs] = idx 927 edge_src[:npairs] = graph_in["edge_src"][idx] 928 edge_dst[:npairs] = graph_in["edge_dst"][idx] 929 d12_[:npairs] = d12[idx] 930 d12 = d12_ 931 932 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 933 graph_out = { 934 **graph, 935 "edge_src": edge_src, 936 "edge_dst": edge_dst, 937 "filter_indices": filter_indices, 938 "d12": d12, 939 "overflow": False, 940 } 941 if "cells" in inputs: 942 pbc_shifts = np.zeros((max_pairs, 3), dtype=np.int32) 943 pbc_shifts[:npairs] = graph_in["pbc_shifts"][idx] 944 graph_out["pbc_shifts"] = pbc_shifts 945 946 if self.k_space: 947 if "k_points" not in graph: 948 ks, _, _, bewald = get_reciprocal_space_parameters( 949 inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr 950 ) 951 graph_out["k_points"] = ks 952 graph_out["b_ewald"] = bewald 953 954 output = {**inputs, self.graph_key: graph_out} 955 if return_state_update: 956 return FrozenDict(new_state), output, state_up 957 return FrozenDict(new_state), output 958 959 def check_reallocate(self, state, inputs, parent_overflow=False): 960 """check for overflow and reallocate nblist if necessary""" 961 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 962 if not overflow: 963 return state, {}, inputs, False 964 965 add_margin = inputs[self.graph_key].get("overflow", False) 966 state, inputs, state_up = self( 967 state, inputs, return_state_update=True, add_margin=add_margin 968 ) 969 return state, state_up, inputs, True 970 971 @partial(jax.jit, static_argnums=(0, 1)) 972 def process(self, state, inputs): 973 """filter a nblist on accelerator with jax and precomputed shapes""" 974 graph_in = inputs[self.parent_graph] 975 if state is None: 976 # skin update mode 977 graph = inputs[self.graph_key] 978 max_pairs = graph["edge_src"].shape[0] 979 else: 980 max_pairs = state.get("npairs", 1) 981 982 max_pairs_in = graph_in["edge_src"].shape[0] 983 nat = inputs["species"].shape[0] 984 985 edge_src = graph_in["edge_src"] 986 d12 = graph_in["d12"] 987 if self.remove_hydrogens: 988 species = inputs["species"] 989 mask = (species > 1)[edge_src] 990 d12 = jnp.where(mask, d12, self.cutoff**2) 991 mask = d12 < self.cutoff**2 992 993 (edge_src, edge_dst, d12, filter_indices), scatter_idx, npairs = mask_filter_1d( 994 mask, 995 max_pairs, 996 (edge_src, nat), 997 (graph_in["edge_dst"], nat), 998 (d12, self.cutoff**2), 999 (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in), 1000 ) 1001 1002 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 1003 overflow = graph.get("overflow", False) | (npairs > max_pairs) 1004 graph_out = { 1005 **graph, 1006 "edge_src": edge_src, 1007 "edge_dst": edge_dst, 1008 "filter_indices": filter_indices, 1009 "d12": d12, 1010 "overflow": overflow, 1011 } 1012 1013 if "cells" in inputs: 1014 pbc_shifts = graph_in["pbc_shifts"] 1015 pbc_shifts = ( 1016 jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype) 1017 .at[scatter_idx].set(pbc_shifts, mode="drop") 1018 ) 1019 graph_out["pbc_shifts"] = pbc_shifts 1020 if self.k_space: 1021 if "k_points" not in graph: 1022 raise NotImplementedError( 1023 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 1024 ) 1025 1026 return {**inputs, self.graph_key: graph_out} 1027 1028 @partial(jax.jit, static_argnums=(0,)) 1029 def update_skin(self, inputs): 1030 return self.process(None, inputs) 1031 1032 1033class GraphFilterProcessor(nn.Module): 1034 """Filter processing for a pre-generated graph 1035 1036 This module is automatically added to a FENNIX model when a GraphFilter is used. 1037 """ 1038 1039 cutoff: float 1040 """Cutoff distance for the filtering.""" 1041 graph_key: str 1042 """Key of the filtered graph in the inputs.""" 1043 parent_graph: str 1044 """Key of the parent graph in the inputs.""" 1045 switch_params: dict = dataclasses.field(default_factory=dict) 1046 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 1047 1048 @nn.compact 1049 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 1050 graph_in = inputs[self.parent_graph] 1051 graph = inputs[self.graph_key] 1052 1053 d_key = "distances_raw" if "distances_raw" in graph else "distances" 1054 1055 if graph_in["vec"].shape[0] == 0: 1056 vec = graph_in["vec"] 1057 distances = graph_in[d_key] 1058 filter_indices = jnp.asarray([], dtype=jnp.int32) 1059 else: 1060 filter_indices = graph["filter_indices"] 1061 vec = ( 1062 graph_in["vec"] 1063 .at[filter_indices] 1064 .get(mode="fill", fill_value=self.cutoff) 1065 ) 1066 distances = ( 1067 graph_in[d_key] 1068 .at[filter_indices] 1069 .get(mode="fill", fill_value=self.cutoff) 1070 ) 1071 1072 edge_mask = distances < self.cutoff 1073 switch = SwitchFunction( 1074 **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None} 1075 )((distances, edge_mask)) 1076 1077 graph_out = { 1078 **graph, 1079 "vec": vec, 1080 "distances": distances, 1081 "switch": switch, 1082 "filter_indices": filter_indices, 1083 "edge_mask": edge_mask, 1084 } 1085 1086 if "alch_group" in inputs: 1087 edge_src=graph["edge_src"] 1088 edge_dst=graph["edge_dst"] 1089 alch_group = inputs["alch_group"] 1090 lambda_e = inputs["alch_elambda"] 1091 lambda_v = inputs["alch_vlambda"] 1092 mask = alch_group[edge_src] == alch_group[edge_dst] 1093 graph_out["switch_raw"] = switch 1094 graph_out["switch"] = jnp.where( 1095 mask, 1096 switch, 1097 0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch , 1098 ) 1099 1100 graph_out["distances_raw"] = distances 1101 if "alch_softcore_e" in inputs: 1102 alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2 1103 else: 1104 alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2 1105 1106 graph_out["distances"] = jnp.where( 1107 mask, 1108 distances, 1109 safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha/self.cutoff**2)) 1110 ) 1111 1112 1113 return {**inputs, self.graph_key: graph_out} 1114 1115 1116@dataclasses.dataclass(frozen=True) 1117class GraphAngularExtension: 1118 """Add angles list to a graph 1119 1120 FPID: GRAPH_ANGULAR_EXTENSION 1121 """ 1122 1123 mult_size: float = 1.05 1124 """Multiplicative factor for resizing the nblist.""" 1125 add_neigh: int = 5 1126 """Additional neighbors to add to the nblist when resizing.""" 1127 graph_key: str = "graph" 1128 """Key of the graph in the inputs.""" 1129 1130 FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION" 1131 1132 def init(self): 1133 return FrozenDict( 1134 { 1135 "nangles": 0, 1136 "nblist_mult_size": self.mult_size, 1137 "max_neigh": self.add_neigh, 1138 "add_neigh": self.add_neigh, 1139 } 1140 ) 1141 1142 def get_processor(self) -> Tuple[nn.Module, Dict]: 1143 return GraphAngleProcessor, { 1144 "graph_key": self.graph_key, 1145 "name": f"{self.graph_key}_AngleProcessor", 1146 } 1147 1148 def get_graph_properties(self): 1149 return { 1150 self.graph_key: { 1151 "has_angles": True, 1152 } 1153 } 1154 1155 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1156 """build angle nblist on cpu with numpy and dynamic shapes + store max shapes""" 1157 graph = inputs[self.graph_key] 1158 edge_src = np.array(graph["edge_src"], dtype=np.int32) 1159 1160 new_state = {**state} 1161 state_up = {} 1162 mult_size = state.get("nblist_mult_size", self.mult_size) 1163 assert mult_size >= 1., "nblist_mult_size should be >= 1." 1164 1165 ### count number of neighbors 1166 nat = inputs["species"].shape[0] 1167 count = np.zeros(nat + 1, dtype=np.int32) 1168 np.add.at(count, edge_src, 1) 1169 max_count = int(np.max(count[:-1])) 1170 1171 ### get sizes 1172 max_neigh = state.get("max_neigh", self.add_neigh) 1173 nedge = edge_src.shape[0] 1174 if max_count > max_neigh or add_margin: 1175 prev_max_neigh = max_neigh 1176 max_neigh = max(max_count, max_neigh) + state.get( 1177 "add_neigh", self.add_neigh 1178 ) 1179 state_up["max_neigh"] = (max_neigh, prev_max_neigh) 1180 new_state["max_neigh"] = max_neigh 1181 1182 max_neigh_arr = np.empty(max_neigh, dtype=bool) 1183 1184 nedge = edge_src.shape[0] 1185 1186 ### sort edge_src 1187 idx_sort = np.argsort(edge_src) 1188 edge_src_sorted = edge_src[idx_sort] 1189 1190 ### map sparse to dense nblist 1191 offset = np.tile(np.arange(max_count), nat) 1192 if max_count * nat >= nedge: 1193 offset = np.tile(np.arange(max_count), nat)[:nedge] 1194 else: 1195 offset = np.zeros(nedge, dtype=np.int32) 1196 offset[: max_count * nat] = np.tile(np.arange(max_count), nat) 1197 1198 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1199 mask = edge_src_sorted < nat 1200 indices = edge_src_sorted * max_count + offset 1201 indices = indices[mask] 1202 idx_sort = idx_sort[mask] 1203 edge_idx = np.full(nat * max_count, nedge, dtype=np.int32) 1204 edge_idx[indices] = idx_sort 1205 edge_idx = edge_idx.reshape(nat, max_count) 1206 1207 ### find all triplet for each atom center 1208 local_src, local_dst = np.triu_indices(max_count, 1) 1209 angle_src = edge_idx[:, local_src].flatten() 1210 angle_dst = edge_idx[:, local_dst].flatten() 1211 1212 ### mask for valid angles 1213 mask1 = angle_src < nedge 1214 mask2 = angle_dst < nedge 1215 angle_mask = mask1 & mask2 1216 1217 max_angles = state.get("nangles", 0) 1218 idx = np.nonzero(angle_mask)[0] 1219 nangles = idx.shape[0] 1220 if nangles > max_angles or add_margin: 1221 max_angles_prev = max_angles 1222 max_angles = int(mult_size * max(nangles, max_angles)) + 1 1223 state_up["nangles"] = (max_angles, max_angles_prev) 1224 new_state["nangles"] = max_angles 1225 1226 ## filter angles to sparse representation 1227 angle_src_ = np.full(max_angles, nedge, dtype=np.int32) 1228 angle_dst_ = np.full(max_angles, nedge, dtype=np.int32) 1229 angle_src_[:nangles] = angle_src[idx] 1230 angle_dst_[:nangles] = angle_dst[idx] 1231 1232 central_atom = np.full(max_angles, nat, dtype=np.int32) 1233 central_atom[:nangles] = edge_src[angle_src_[:nangles]] 1234 1235 ## update graph 1236 output = { 1237 **inputs, 1238 self.graph_key: { 1239 **graph, 1240 "angle_src": angle_src_, 1241 "angle_dst": angle_dst_, 1242 "central_atom": central_atom, 1243 "angle_overflow": False, 1244 "max_neigh": max_neigh, 1245 "__max_neigh_array": max_neigh_arr, 1246 }, 1247 } 1248 1249 if return_state_update: 1250 return FrozenDict(new_state), output, state_up 1251 return FrozenDict(new_state), output 1252 1253 def check_reallocate(self, state, inputs, parent_overflow=False): 1254 """check for overflow and reallocate nblist if necessary""" 1255 overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"] 1256 if not overflow: 1257 return state, {}, inputs, False 1258 1259 add_margin = inputs[self.graph_key]["angle_overflow"] 1260 state, inputs, state_up = self( 1261 state, inputs, return_state_update=True, add_margin=add_margin 1262 ) 1263 return state, state_up, inputs, True 1264 1265 @partial(jax.jit, static_argnums=(0, 1)) 1266 def process(self, state, inputs): 1267 """build angle nblist on accelerator with jax and precomputed shapes""" 1268 graph = inputs[self.graph_key] 1269 edge_src = graph["edge_src"] 1270 1271 ### count number of neighbors 1272 nat = inputs["species"].shape[0] 1273 count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop") 1274 max_count = jnp.max(count) 1275 1276 ### get sizes 1277 if state is None: 1278 max_neigh_arr = graph["__max_neigh_array"] 1279 max_neigh = max_neigh_arr.shape[0] 1280 prev_nangles = graph["angle_src"].shape[0] 1281 else: 1282 max_neigh = state.get("max_neigh", self.add_neigh) 1283 max_neigh_arr = jnp.empty(max_neigh, dtype=bool) 1284 prev_nangles = state.get("nangles", 0) 1285 1286 nedge = edge_src.shape[0] 1287 1288 ### sort edge_src 1289 idx_sort = jnp.argsort(edge_src).astype(jnp.int32) 1290 edge_src_sorted = edge_src[idx_sort] 1291 1292 ### map sparse to dense nblist 1293 if max_neigh * nat < nedge: 1294 raise ValueError("Found max_neigh*nat < nedge. This should not happen.") 1295 offset = jnp.asarray( 1296 np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32 1297 ) 1298 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1299 indices = edge_src_sorted * max_neigh + offset 1300 edge_idx = ( 1301 jnp.full(nat * max_neigh, nedge, dtype=jnp.int32) 1302 .at[indices] 1303 .set(idx_sort, mode="drop") 1304 .reshape(nat, max_neigh) 1305 ) 1306 1307 ### find all triplet for each atom center 1308 local_src, local_dst = np.triu_indices(max_neigh, 1) 1309 angle_src = edge_idx[:, local_src].flatten() 1310 angle_dst = edge_idx[:, local_dst].flatten() 1311 1312 ### mask for valid angles 1313 mask1 = angle_src < nedge 1314 mask2 = angle_dst < nedge 1315 angle_mask = mask1 & mask2 1316 1317 ## filter angles to sparse representation 1318 (angle_src, angle_dst), _, nangles = mask_filter_1d( 1319 angle_mask, 1320 prev_nangles, 1321 (angle_src, nedge), 1322 (angle_dst, nedge), 1323 ) 1324 ## find central atom 1325 central_atom = edge_src[angle_src] 1326 1327 ## check for overflow 1328 angle_overflow = nangles > prev_nangles 1329 neigh_overflow = max_count > max_neigh 1330 overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow 1331 1332 ## update graph 1333 output = { 1334 **inputs, 1335 self.graph_key: { 1336 **graph, 1337 "angle_src": angle_src, 1338 "angle_dst": angle_dst, 1339 "central_atom": central_atom, 1340 "angle_overflow": overflow, 1341 # "max_neigh": max_neigh, 1342 "__max_neigh_array": max_neigh_arr, 1343 }, 1344 } 1345 1346 return output 1347 1348 @partial(jax.jit, static_argnums=(0,)) 1349 def update_skin(self, inputs): 1350 return self.process(None, inputs) 1351 1352 1353class GraphAngleProcessor(nn.Module): 1354 """Process a pre-generated graph to compute angles 1355 1356 This module is automatically added to a FENNIX model when a GraphAngularExtension is used. 1357 1358 """ 1359 1360 graph_key: str 1361 """Key of the graph in the inputs.""" 1362 1363 @nn.compact 1364 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 1365 graph = inputs[self.graph_key] 1366 distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"] 1367 vec = graph["vec"] 1368 angle_src = graph["angle_src"] 1369 angle_dst = graph["angle_dst"] 1370 1371 dir = vec / jnp.clip(distances[:, None], min=1.0e-5) 1372 cos_angles = ( 1373 dir.at[angle_src].get(mode="fill", fill_value=0.5) 1374 * dir.at[angle_dst].get(mode="fill", fill_value=0.5) 1375 ).sum(axis=-1) 1376 1377 angles = jnp.arccos(0.95 * cos_angles) 1378 1379 return { 1380 **inputs, 1381 self.graph_key: { 1382 **graph, 1383 # "cos_angles": cos_angles, 1384 "angles": angles, 1385 # "angle_mask": angle_mask, 1386 }, 1387 } 1388 1389 1390@dataclasses.dataclass(frozen=True) 1391class SpeciesIndexer: 1392 """Build an index that splits atomic arrays by species. 1393 1394 FPID: SPECIES_INDEXER 1395 1396 If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. 1397 If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values. 1398 1399 """ 1400 1401 output_key: str = "species_index" 1402 """Key for the output dictionary.""" 1403 species_order: Optional[str] = None 1404 """Comma separated list of species in the order they should be indexed.""" 1405 add_atoms: int = 0 1406 """Additional atoms to add to the sizes.""" 1407 add_atoms_margin: int = 10 1408 """Additional atoms to add to the sizes when adding margin.""" 1409 1410 FPID: ClassVar[str] = "SPECIES_INDEXER" 1411 1412 def init(self): 1413 return FrozenDict( 1414 { 1415 "sizes": {}, 1416 } 1417 ) 1418 1419 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1420 species = np.array(inputs["species"], dtype=np.int32) 1421 nat = species.shape[0] 1422 set_species, counts = np.unique(species, return_counts=True) 1423 1424 new_state = {**state} 1425 state_up = {} 1426 1427 sizes = state.get("sizes", FrozenDict({})) 1428 new_sizes = {**sizes} 1429 up_sizes = False 1430 counts_dict = {} 1431 for s, c in zip(set_species, counts): 1432 if s <= 0: 1433 continue 1434 counts_dict[s] = c 1435 if c > sizes.get(s, 0): 1436 up_sizes = True 1437 add_atoms = state.get("add_atoms", self.add_atoms) 1438 if add_margin: 1439 add_atoms += state.get("add_atoms_margin", self.add_atoms_margin) 1440 new_sizes[s] = c + add_atoms 1441 1442 new_sizes = FrozenDict(new_sizes) 1443 if up_sizes: 1444 state_up["sizes"] = (new_sizes, sizes) 1445 new_state["sizes"] = new_sizes 1446 1447 if self.species_order is not None: 1448 species_order = [el.strip() for el in self.species_order.split(",")] 1449 max_size_prev = state.get("max_size", 0) 1450 max_size = max(new_sizes.values()) 1451 if max_size > max_size_prev: 1452 state_up["max_size"] = (max_size, max_size_prev) 1453 new_state["max_size"] = max_size 1454 max_size_prev = max_size 1455 1456 species_index = np.full((len(species_order), max_size), nat, dtype=np.int32) 1457 for i, el in enumerate(species_order): 1458 s = PERIODIC_TABLE_REV_IDX[el] 1459 if s in counts_dict.keys(): 1460 species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1461 else: 1462 species_index = { 1463 PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32) 1464 for s, c in new_sizes.items() 1465 } 1466 for s, c in zip(set_species, counts): 1467 if s <= 0: 1468 continue 1469 species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0] 1470 1471 output = { 1472 **inputs, 1473 self.output_key: species_index, 1474 self.output_key + "_overflow": False, 1475 } 1476 1477 if return_state_update: 1478 return FrozenDict(new_state), output, state_up 1479 return FrozenDict(new_state), output 1480 1481 def check_reallocate(self, state, inputs, parent_overflow=False): 1482 """check for overflow and reallocate nblist if necessary""" 1483 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1484 if not overflow: 1485 return state, {}, inputs, False 1486 1487 add_margin = inputs[self.output_key + "_overflow"] 1488 state, inputs, state_up = self( 1489 state, inputs, return_state_update=True, add_margin=add_margin 1490 ) 1491 return state, state_up, inputs, True 1492 # return state, {}, inputs, parent_overflow 1493 1494 @partial(jax.jit, static_argnums=(0, 1)) 1495 def process(self, state, inputs): 1496 # assert ( 1497 # self.output_key in inputs 1498 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1499 1500 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1501 if self.output_key in inputs and not recompute_species_index: 1502 return inputs 1503 1504 if state is None: 1505 raise ValueError("Species Indexer state must be provided on accelerator.") 1506 1507 species = inputs["species"] 1508 nat = species.shape[0] 1509 1510 sizes = state["sizes"] 1511 1512 if self.species_order is not None: 1513 species_order = [el.strip() for el in self.species_order.split(",")] 1514 max_size = state["max_size"] 1515 1516 species_index = jnp.full( 1517 (len(species_order), max_size), nat, dtype=jnp.int32 1518 ) 1519 for i, el in enumerate(species_order): 1520 s = PERIODIC_TABLE_REV_IDX[el] 1521 if s in sizes.keys(): 1522 c = sizes[s] 1523 species_index = species_index.at[i, :].set( 1524 jnp.nonzero(species == s, size=max_size, fill_value=nat)[0] 1525 ) 1526 # if s in counts_dict.keys(): 1527 # species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1528 else: 1529 # species_index = { 1530 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1531 # for s, c in sizes.items() 1532 # } 1533 species_index = {} 1534 overflow = False 1535 natcount = 0 1536 for s, c in sizes.items(): 1537 mask = species == s 1538 new_size = jnp.sum(mask) 1539 natcount = natcount + new_size 1540 overflow = overflow | (new_size > c) # check if sizes are correct 1541 species_index[PERIODIC_TABLE[s]] = jnp.nonzero( 1542 species == s, size=c, fill_value=nat 1543 )[0] 1544 1545 mask = species <= 0 1546 new_size = jnp.sum(mask) 1547 natcount = natcount + new_size 1548 overflow = overflow | ( 1549 natcount < species.shape[0] 1550 ) # check if any species missing 1551 1552 return { 1553 **inputs, 1554 self.output_key: species_index, 1555 self.output_key + "_overflow": overflow, 1556 } 1557 1558 @partial(jax.jit, static_argnums=(0,)) 1559 def update_skin(self, inputs): 1560 return self.process(None, inputs) 1561 1562@dataclasses.dataclass(frozen=True) 1563class BlockIndexer: 1564 """Build an index that splits atomic arrays by chemical blocks. 1565 1566 FPID: BLOCK_INDEXER 1567 1568 If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. 1569 If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values. 1570 1571 """ 1572 1573 output_key: str = "block_index" 1574 """Key for the output dictionary.""" 1575 add_atoms: int = 0 1576 """Additional atoms to add to the sizes.""" 1577 add_atoms_margin: int = 10 1578 """Additional atoms to add to the sizes when adding margin.""" 1579 split_CNOPSSe: bool = False 1580 1581 FPID: ClassVar[str] = "BLOCK_INDEXER" 1582 1583 def init(self): 1584 return FrozenDict( 1585 { 1586 "sizes": {}, 1587 } 1588 ) 1589 1590 def build_chemical_blocks(self): 1591 _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy() 1592 if self.split_CNOPSSe: 1593 _CHEMICAL_BLOCKS_NAMES[1] = "C" 1594 _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"]) 1595 _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy() 1596 if self.split_CNOPSSe: 1597 _CHEMICAL_BLOCKS[6] = 1 1598 _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES) 1599 _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1 1600 _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2 1601 _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3 1602 _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4 1603 return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS 1604 1605 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1606 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1607 1608 species = np.array(inputs["species"], dtype=np.int32) 1609 blocks = _CHEMICAL_BLOCKS[species] 1610 nat = species.shape[0] 1611 set_blocks, counts = np.unique(blocks, return_counts=True) 1612 1613 new_state = {**state} 1614 state_up = {} 1615 1616 sizes = state.get("sizes", FrozenDict({})) 1617 new_sizes = {**sizes} 1618 up_sizes = False 1619 for s, c in zip(set_blocks, counts): 1620 if s < 0: 1621 continue 1622 key = (s, _CHEMICAL_BLOCKS_NAMES[s]) 1623 if c > sizes.get(key, 0): 1624 up_sizes = True 1625 add_atoms = state.get("add_atoms", self.add_atoms) 1626 if add_margin: 1627 add_atoms += state.get("add_atoms_margin", self.add_atoms_margin) 1628 new_sizes[key] = c + add_atoms 1629 1630 new_sizes = FrozenDict(new_sizes) 1631 if up_sizes: 1632 state_up["sizes"] = (new_sizes, sizes) 1633 new_state["sizes"] = new_sizes 1634 1635 block_index = {n:None for n in _CHEMICAL_BLOCKS_NAMES} 1636 for (_,n), c in new_sizes.items(): 1637 block_index[n] = np.full(c, nat, dtype=np.int32) 1638 # block_index = { 1639 # n: np.full(c, nat, dtype=np.int32) 1640 # for (_,n), c in new_sizes.items() 1641 # } 1642 for s, c in zip(set_blocks, counts): 1643 if s < 0: 1644 continue 1645 block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0] 1646 1647 output = { 1648 **inputs, 1649 self.output_key: block_index, 1650 self.output_key + "_overflow": False, 1651 } 1652 1653 if return_state_update: 1654 return FrozenDict(new_state), output, state_up 1655 return FrozenDict(new_state), output 1656 1657 def check_reallocate(self, state, inputs, parent_overflow=False): 1658 """check for overflow and reallocate nblist if necessary""" 1659 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1660 if not overflow: 1661 return state, {}, inputs, False 1662 1663 add_margin = inputs[self.output_key + "_overflow"] 1664 state, inputs, state_up = self( 1665 state, inputs, return_state_update=True, add_margin=add_margin 1666 ) 1667 return state, state_up, inputs, True 1668 # return state, {}, inputs, parent_overflow 1669 1670 @partial(jax.jit, static_argnums=(0, 1)) 1671 def process(self, state, inputs): 1672 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1673 # assert ( 1674 # self.output_key in inputs 1675 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1676 1677 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1678 if self.output_key in inputs and not recompute_species_index: 1679 return inputs 1680 1681 if state is None: 1682 raise ValueError("Block Indexer state must be provided on accelerator.") 1683 1684 species = inputs["species"] 1685 blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species] 1686 nat = species.shape[0] 1687 1688 sizes = state["sizes"] 1689 1690 # species_index = { 1691 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1692 # for s, c in sizes.items() 1693 # } 1694 block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES} 1695 overflow = False 1696 natcount = 0 1697 for (s,name), c in sizes.items(): 1698 mask = blocks == s 1699 new_size = jnp.sum(mask) 1700 natcount = natcount + new_size 1701 overflow = overflow | (new_size > c) # check if sizes are correct 1702 block_index[name] = jnp.nonzero( 1703 mask, size=c, fill_value=nat 1704 )[0] 1705 1706 mask = blocks < 0 1707 new_size = jnp.sum(mask) 1708 natcount = natcount + new_size 1709 overflow = overflow | ( 1710 natcount < species.shape[0] 1711 ) # check if any species missing 1712 1713 return { 1714 **inputs, 1715 self.output_key: block_index, 1716 self.output_key + "_overflow": overflow, 1717 } 1718 1719 @partial(jax.jit, static_argnums=(0,)) 1720 def update_skin(self, inputs): 1721 return self.process(None, inputs) 1722 1723 1724@dataclasses.dataclass(frozen=True) 1725class AtomPadding: 1726 """Pad atomic arrays to a fixed size.""" 1727 1728 mult_size: float = 1.2 1729 """Multiplicative factor for resizing the atomic arrays.""" 1730 add_sys: int = 0 1731 1732 def init(self): 1733 return {"prev_nat": 0, "prev_nsys": 0} 1734 1735 def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]: 1736 species = inputs["species"] 1737 nat = species.shape[0] 1738 1739 prev_nat = state.get("prev_nat", 0) 1740 prev_nat_ = prev_nat 1741 if nat > prev_nat_: 1742 prev_nat_ = int(self.mult_size * nat) + 1 1743 1744 nsys = len(inputs["natoms"]) 1745 prev_nsys = state.get("prev_nsys", 0) 1746 prev_nsys_ = prev_nsys 1747 if nsys > prev_nsys_: 1748 prev_nsys_ = nsys + self.add_sys 1749 1750 add_atoms = prev_nat_ - nat 1751 add_sys = prev_nsys_ - nsys + 1 1752 output = {**inputs} 1753 if add_atoms > 0: 1754 for k, v in inputs.items(): 1755 if isinstance(v, np.ndarray) or isinstance(v, jax.Array): 1756 if v.shape[0] == nat: 1757 output[k] = np.append( 1758 v, 1759 np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype), 1760 axis=0, 1761 ) 1762 elif v.shape[0] == nsys: 1763 if k == "cells": 1764 output[k] = np.append( 1765 v, 1766 1000 1767 * np.eye(3, dtype=v.dtype)[None, :, :].repeat( 1768 add_sys, axis=0 1769 ), 1770 axis=0, 1771 ) 1772 else: 1773 output[k] = np.append( 1774 v, 1775 np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype), 1776 axis=0, 1777 ) 1778 output["natoms"] = np.append( 1779 inputs["natoms"], np.zeros(add_sys, dtype=np.int32) 1780 ) 1781 output["species"] = np.append( 1782 species, -1 * np.ones(add_atoms, dtype=species.dtype) 1783 ) 1784 output["batch_index"] = np.append( 1785 inputs["batch_index"], 1786 np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype), 1787 ) 1788 if "system_index" in inputs: 1789 output["system_index"] = np.append( 1790 inputs["system_index"], 1791 np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype), 1792 ) 1793 1794 output["true_atoms"] = output["species"] > 0 1795 output["true_sys"] = np.arange(len(output["natoms"])) < nsys 1796 1797 state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_} 1798 1799 return FrozenDict(state), output 1800 1801 1802def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]: 1803 """Remove padding from atomic arrays.""" 1804 if "true_atoms" not in inputs: 1805 return inputs 1806 1807 species = np.asarray(inputs["species"]) 1808 true_atoms = np.asarray(inputs["true_atoms"]) 1809 true_sys = np.asarray(inputs["true_sys"]) 1810 natall = species.shape[0] 1811 nat = np.argmax(species <= 0) 1812 if nat == 0: 1813 return inputs 1814 1815 natoms = inputs["natoms"] 1816 nsysall = len(natoms) 1817 1818 output = {**inputs} 1819 for k, v in inputs.items(): 1820 if isinstance(v, jax.Array) or isinstance(v, np.ndarray): 1821 v = np.asarray(v) 1822 if v.ndim == 0: 1823 output[k] = v 1824 elif v.shape[0] == natall: 1825 output[k] = v[true_atoms] 1826 elif v.shape[0] == nsysall: 1827 output[k] = v[true_sys] 1828 del output["true_sys"] 1829 del output["true_atoms"] 1830 return output 1831 1832 1833def check_input(inputs): 1834 """Check the input dictionary for required keys and types.""" 1835 assert "species" in inputs, "species must be provided" 1836 assert "coordinates" in inputs, "coordinates must be provided" 1837 species = inputs["species"].astype(np.int32) 1838 ifake = np.argmax(species <= 0) 1839 if ifake > 0: 1840 assert np.all(species[:ifake] > 0), "species must be positive" 1841 nat = inputs["species"].shape[0] 1842 1843 natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32) 1844 batch_index = inputs.get( 1845 "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms) 1846 ).astype(np.int32) 1847 output = {**inputs, "natoms": natoms, "batch_index": batch_index} 1848 if "cells" in inputs: 1849 cells = inputs["cells"] 1850 if "reciprocal_cells" not in inputs: 1851 reciprocal_cells = np.linalg.inv(cells) 1852 else: 1853 reciprocal_cells = inputs["reciprocal_cells"] 1854 if cells.ndim == 2: 1855 cells = cells[None, :, :] 1856 if reciprocal_cells.ndim == 2: 1857 reciprocal_cells = reciprocal_cells[None, :, :] 1858 output["cells"] = cells 1859 output["reciprocal_cells"] = reciprocal_cells 1860 1861 return output 1862 1863 1864def convert_to_jax(data): 1865 """Convert a numpy arrays to jax arrays in a pytree.""" 1866 1867 def convert(x): 1868 if isinstance(x, np.ndarray): 1869 # if x.dtype == np.float64: 1870 # return jnp.asarray(x, dtype=jnp.float32) 1871 return jnp.asarray(x) 1872 return x 1873 1874 return jax.tree_util.tree_map(convert, data) 1875 1876 1877class JaxConverter(nn.Module): 1878 """Convert numpy arrays to jax arrays in a pytree.""" 1879 1880 def __call__(self, data): 1881 return convert_to_jax(data) 1882 1883def convert_to_numpy(data): 1884 """Convert jax arrays to numpy arrays in a pytree.""" 1885 1886 def convert(x): 1887 if isinstance(x, jax.Array): 1888 return np.array(x) 1889 return x 1890 1891 return jax.tree_util.tree_map(convert, data) 1892 1893 1894@dataclasses.dataclass(frozen=True) 1895class PreprocessingChain: 1896 """Chain of preprocessing layers.""" 1897 1898 layers: Tuple[Callable[..., Dict[str, Any]]] 1899 """Preprocessing layers.""" 1900 use_atom_padding: bool = False 1901 """Add an AtomPadding layer at the beginning of the chain.""" 1902 atom_padder: AtomPadding = AtomPadding() 1903 """AtomPadding layer.""" 1904 1905 def __post_init__(self): 1906 if not isinstance(self.layers, Sequence): 1907 raise ValueError( 1908 f"'layers' must be a sequence, got '{type(self.layers).__name__}'." 1909 ) 1910 if not self.layers: 1911 raise ValueError(f"Error: no Preprocessing layers were provided.") 1912 1913 def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]: 1914 do_check_input = state.get("check_input", True) 1915 if do_check_input: 1916 inputs = check_input(inputs) 1917 new_state = {**state} 1918 if self.use_atom_padding: 1919 s, inputs = self.atom_padder(state["padder_state"], inputs) 1920 new_state["padder_state"] = s 1921 layer_state = state["layers_state"] 1922 new_layer_state = [] 1923 for i,layer in enumerate(self.layers): 1924 s, inputs = layer(layer_state[i], inputs, return_state_update=False) 1925 new_layer_state.append(s) 1926 new_state["layers_state"] = tuple(new_layer_state) 1927 return FrozenDict(new_state), convert_to_jax(inputs) 1928 1929 def check_reallocate(self, state, inputs): 1930 new_state = [] 1931 state_up = [] 1932 layer_state = state["layers_state"] 1933 parent_overflow = False 1934 for i,layer in enumerate(self.layers): 1935 s, s_up, inputs, parent_overflow = layer.check_reallocate( 1936 layer_state[i], inputs, parent_overflow 1937 ) 1938 new_state.append(s) 1939 state_up.append(s_up) 1940 1941 if not parent_overflow: 1942 return state, {}, inputs, False 1943 return ( 1944 FrozenDict({**state, "layers_state": tuple(new_state)}), 1945 state_up, 1946 inputs, 1947 True, 1948 ) 1949 1950 def atom_padding(self, state, inputs): 1951 if self.use_atom_padding: 1952 padder_state,inputs = self.atom_padder(state["padder_state"], inputs) 1953 return FrozenDict({**state,"padder_state": padder_state}), inputs 1954 return state, inputs 1955 1956 @partial(jax.jit, static_argnums=(0, 1)) 1957 def process(self, state, inputs): 1958 layer_state = state["layers_state"] 1959 for i,layer in enumerate(self.layers): 1960 inputs = layer.process(layer_state[i], inputs) 1961 return inputs 1962 1963 @partial(jax.jit, static_argnums=(0)) 1964 def update_skin(self, inputs): 1965 for layer in self.layers: 1966 inputs = layer.update_skin(inputs) 1967 return inputs 1968 1969 def init(self): 1970 state = {"check_input": True} 1971 if self.use_atom_padding: 1972 state["padder_state"] = self.atom_padder.init() 1973 layer_state = [] 1974 for layer in self.layers: 1975 layer_state.append(layer.init()) 1976 state["layers_state"] = tuple(layer_state) 1977 return FrozenDict(state) 1978 1979 def init_with_output(self, inputs): 1980 state = self.init() 1981 return self(state, inputs) 1982 1983 def get_processors(self): 1984 processors = [] 1985 for layer in self.layers: 1986 if hasattr(layer, "get_processor"): 1987 processors.append(layer.get_processor()) 1988 return processors 1989 1990 def get_graphs_properties(self): 1991 properties = {} 1992 for layer in self.layers: 1993 if hasattr(layer, "get_graph_properties"): 1994 properties = deep_update(properties, layer.get_graph_properties()) 1995 return properties 1996 1997 1998# PREPROCESSING = { 1999# "GRAPH": GraphGenerator, 2000# # "GRAPH_FIXED": GraphGeneratorFixed, 2001# "GRAPH_FILTER": GraphFilter, 2002# "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension, 2003# # "GRAPH_DENSE_EXTENSION": GraphDenseExtension, 2004# "SPECIES_INDEXER": SpeciesIndexer, 2005# }
22@dataclasses.dataclass(frozen=True) 23class GraphGenerator: 24 """Generate a graph from a set of coordinates 25 26 FPID: GRAPH 27 28 For now, we generate all pairs of atoms and filter based on cutoff. 29 If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist. 30 """ 31 32 cutoff: float 33 """Cutoff distance for the graph.""" 34 graph_key: str = "graph" 35 """Key of the graph in the outputs.""" 36 switch_params: dict = dataclasses.field(default_factory=dict, hash=False) 37 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 38 kmax: int = 30 39 """Maximum number of k-points to consider.""" 40 kthr: float = 1e-6 41 """Threshold for k-point filtering.""" 42 k_space: bool = False 43 """Whether to generate k-space information for the graph.""" 44 mult_size: float = 1.05 45 """Multiplicative factor for resizing the nblist.""" 46 # covalent_cutoff: bool = False 47 48 FPID: ClassVar[str] = "GRAPH" 49 50 def init(self): 51 return FrozenDict( 52 { 53 "max_nat": 1, 54 "npairs": 1, 55 "nblist_mult_size": self.mult_size, 56 } 57 ) 58 59 def get_processor(self) -> Tuple[nn.Module, Dict]: 60 return GraphProcessor, { 61 "cutoff": self.cutoff, 62 "graph_key": self.graph_key, 63 "switch_params": self.switch_params, 64 "name": f"{self.graph_key}_Processor", 65 } 66 67 def get_graph_properties(self): 68 return { 69 self.graph_key: { 70 "cutoff": self.cutoff, 71 "directed": True, 72 } 73 } 74 75 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 76 """build a nblist on cpu with numpy and dynamic shapes + store max shapes""" 77 if self.graph_key in inputs: 78 graph = inputs[self.graph_key] 79 if "keep_graph" in graph: 80 return state, inputs 81 82 coords = np.array(inputs["coordinates"], dtype=np.float32) 83 natoms = np.array(inputs["natoms"], dtype=np.int32) 84 batch_index = np.array(inputs["batch_index"], dtype=np.int32) 85 86 new_state = {**state} 87 state_up = {} 88 89 mult_size = state.get("nblist_mult_size", self.mult_size) 90 assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0" 91 92 if natoms.shape[0] == 1: 93 max_nat = coords.shape[0] 94 true_max_nat = max_nat 95 else: 96 max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0])) 97 true_max_nat = int(np.max(natoms)) 98 if true_max_nat > max_nat: 99 add_atoms = state.get("add_atoms", 0) 100 new_maxnat = true_max_nat + add_atoms 101 state_up["max_nat"] = (new_maxnat, max_nat) 102 new_state["max_nat"] = new_maxnat 103 104 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 105 106 ### compute indices of all pairs 107 apply_pbc = "cells" in inputs 108 minimage = "minimum_image" in inputs.get("flags", {}) 109 include_self_image = apply_pbc and not minimage 110 111 shift = 0 if include_self_image else 1 112 p1, p2 = np.triu_indices(true_max_nat, shift) 113 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 114 pbc_shifts = None 115 if natoms.shape[0] > 1: 116 ## batching => mask irrelevant pairs 117 mask_p12 = ( 118 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 119 ).flatten() 120 shift = np.concatenate( 121 (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32)) 122 ) 123 p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 124 p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 125 126 if not apply_pbc: 127 ### NO PBC 128 vec = coords[p2] - coords[p1] 129 else: 130 cells = np.array(inputs["cells"], dtype=np.float32) 131 reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32) 132 if minimage: 133 ## MINIMUM IMAGE CONVENTION 134 vec = coords[p2] - coords[p1] 135 if cells.shape[0] == 1: 136 vecpbc = np.dot(vec, reciprocal_cells[0]) 137 pbc_shifts = -np.round(vecpbc).astype(np.int32) 138 vec = vec + np.dot(pbc_shifts, cells[0]) 139 else: 140 batch_index_vec = batch_index[p1] 141 vecpbc = np.einsum( 142 "aj,aji->ai", vec, reciprocal_cells[batch_index_vec] 143 ) 144 pbc_shifts = -np.round(vecpbc).astype(np.int32) 145 vec = vec + np.einsum( 146 "aj,aji->ai", pbc_shifts, cells[batch_index_vec] 147 ) 148 else: 149 ### GENERAL PBC 150 ## put all atoms in central box 151 if cells.shape[0] == 1: 152 coords_pbc = np.dot(coords, reciprocal_cells[0]) 153 at_shifts = -np.floor(coords_pbc).astype(np.int32) 154 coords_pbc = coords + np.dot(at_shifts, cells[0]) 155 else: 156 coords_pbc = np.einsum( 157 "aj,aji->ai", coords, reciprocal_cells[batch_index] 158 ) 159 at_shifts = -np.floor(coords_pbc).astype(np.int32) 160 coords_pbc = coords + np.einsum( 161 "aj,aji->ai", at_shifts, cells[batch_index] 162 ) 163 vec = coords_pbc[p2] - coords_pbc[p1] 164 165 ## compute maximum number of repeats 166 inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5 167 cdinv = cutoff_skin * inv_distances 168 num_repeats_all = np.ceil(cdinv).astype(np.int32) 169 if "true_sys" in inputs: 170 num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0) 171 # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all) 172 num_repeats = np.max(num_repeats_all, axis=0) 173 num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0))) 174 if np.any(num_repeats > num_repeats_prev): 175 num_repeats_new = np.maximum(num_repeats, num_repeats_prev) 176 state_up["num_repeats_pbc"] = ( 177 tuple(num_repeats_new), 178 tuple(num_repeats_prev), 179 ) 180 new_state["num_repeats_pbc"] = tuple(num_repeats_new) 181 ## build all possible shifts 182 cell_shift_pbc = np.array( 183 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 184 dtype=np.int32, 185 ).T.reshape(-1, 3) 186 ## shift applied to vectors 187 if cells.shape[0] == 1: 188 dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :] 189 vec = (vec[:, None, :] + dvec).reshape(-1, 3) 190 pbc_shifts = np.broadcast_to( 191 cell_shift_pbc[None, :, :], 192 (p1.shape[0], cell_shift_pbc.shape[0], 3), 193 ).reshape(-1, 3) 194 p1 = np.broadcast_to( 195 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 196 ).flatten() 197 p2 = np.broadcast_to( 198 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 199 ).flatten() 200 if natoms.shape[0] > 1: 201 mask_p12 = np.broadcast_to( 202 mask_p12[:, None], 203 (mask_p12.shape[0], cell_shift_pbc.shape[0]), 204 ).flatten() 205 else: 206 dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells) 207 208 ## get pbc shifts specific to each box 209 cell_shift_pbc = np.broadcast_to( 210 cell_shift_pbc[None, :, :], 211 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 212 ) 213 mask = np.all( 214 np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 215 ).flatten() 216 idx = np.nonzero(mask)[0] 217 nshifts = idx.shape[0] 218 nshifts_prev = state.get("nshifts_pbc", 0) 219 if nshifts > nshifts_prev or add_margin: 220 nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1 221 state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev) 222 new_state["nshifts_pbc"] = nshifts_new 223 224 dvec_filter = dvec.reshape(-1, 3)[idx, :] 225 cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :] 226 227 ## get batch shift in the dvec_filter array 228 nrep = np.prod(2 * num_repeats_all + 1, axis=1) 229 bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1])) 230 231 ## compute vectors 232 batch_index_vec = batch_index[p1] 233 nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0) 234 vec = vec.repeat(nrep_vec, axis=0) 235 nvec_pbc = nrep_vec.sum() #vec.shape[0] 236 nvec_pbc_prev = state.get("nvec_pbc", 0) 237 if nvec_pbc > nvec_pbc_prev or add_margin: 238 nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1 239 state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev) 240 new_state["nvec_pbc"] = nvec_pbc_new 241 242 # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev) 243 ## get shift index 244 dshift = np.concatenate( 245 (np.array([0]), np.cumsum(nrep_vec)[:-1]) 246 ).repeat(nrep_vec) 247 # ishift = np.arange(dshift.shape[0])-dshift 248 # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec) 249 icellshift = ( 250 np.arange(dshift.shape[0]) 251 - dshift 252 + bshift[batch_index_vec].repeat(nrep_vec) 253 ) 254 # shift vectors 255 vec = vec + dvec_filter[icellshift] 256 pbc_shifts = cell_shift_pbc_filter[icellshift] 257 258 p1 = np.repeat(p1, nrep_vec) 259 p2 = np.repeat(p2, nrep_vec) 260 if natoms.shape[0] > 1: 261 mask_p12 = np.repeat(mask_p12, nrep_vec) 262 263 ## compute distances 264 d12 = (vec**2).sum(axis=-1) 265 if natoms.shape[0] > 1: 266 d12 = np.where(mask_p12, d12, cutoff_skin**2) 267 268 ## filter pairs 269 max_pairs = state.get("npairs", 1) 270 mask = d12 < cutoff_skin**2 271 if include_self_image: 272 mask_self = np.logical_or(p1 != p2, d12 > 1.e-3) 273 mask = np.logical_and(mask, mask_self) 274 idx = np.nonzero(mask)[0] 275 npairs = idx.shape[0] 276 if npairs > max_pairs or add_margin: 277 prev_max_pairs = max_pairs 278 max_pairs = int(mult_size * max(npairs, max_pairs)) + 1 279 state_up["npairs"] = (max_pairs, prev_max_pairs) 280 new_state["npairs"] = max_pairs 281 282 nat = coords.shape[0] 283 edge_src = np.full(max_pairs, nat, dtype=np.int32) 284 edge_dst = np.full(max_pairs, nat, dtype=np.int32) 285 d12_ = np.full(max_pairs, cutoff_skin**2) 286 edge_src[:npairs] = p1[idx] 287 edge_dst[:npairs] = p2[idx] 288 d12_[:npairs] = d12[idx] 289 d12 = d12_ 290 291 if apply_pbc: 292 pbc_shifts_ = np.zeros((max_pairs, 3), dtype=np.int32) 293 pbc_shifts_[:npairs] = pbc_shifts[idx] 294 pbc_shifts = pbc_shifts_ 295 if not minimage: 296 pbc_shifts[:npairs] = ( 297 pbc_shifts[:npairs] 298 + at_shifts[edge_dst[:npairs]] 299 - at_shifts[edge_src[:npairs]] 300 ) 301 302 303 ## symmetrize 304 if include_self_image: 305 mask_noself = edge_src != edge_dst 306 idx_noself = np.nonzero(mask_noself)[0] 307 npairs_noself = idx_noself.shape[0] 308 max_noself = state.get("npairs_noself", 1) 309 if npairs_noself > max_noself or add_margin: 310 prev_max_noself = max_noself 311 max_noself = int(mult_size * max(npairs_noself, max_noself)) + 1 312 state_up["npairs_noself"] = (max_noself, prev_max_noself) 313 new_state["npairs_noself"] = max_noself 314 315 edge_src_noself = np.full(max_noself, nat, dtype=np.int32) 316 edge_dst_noself = np.full(max_noself, nat, dtype=np.int32) 317 d12_noself = np.full(max_noself, cutoff_skin**2) 318 pbc_shifts_noself = np.zeros((max_noself, 3), dtype=np.int32) 319 320 edge_dst_noself[:npairs_noself] = edge_dst[idx_noself] 321 edge_src_noself[:npairs_noself] = edge_src[idx_noself] 322 d12_noself[:npairs_noself] = d12[idx_noself] 323 pbc_shifts_noself[:npairs_noself] = pbc_shifts[idx_noself] 324 edge_src = np.concatenate((edge_src, edge_dst_noself)) 325 edge_dst = np.concatenate((edge_dst, edge_src_noself)) 326 d12 = np.concatenate((d12, d12_noself)) 327 pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts_noself)) 328 else: 329 edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate( 330 (edge_dst, edge_src) 331 ) 332 d12 = np.concatenate((d12, d12)) 333 if apply_pbc: 334 pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts)) 335 336 if "nblist_skin" in state: 337 edge_src_skin = edge_src 338 edge_dst_skin = edge_dst 339 if apply_pbc: 340 pbc_shifts_skin = pbc_shifts 341 max_pairs_skin = state.get("npairs_skin", 1) 342 mask = d12 < self.cutoff**2 343 idx = np.nonzero(mask)[0] 344 npairs_skin = idx.shape[0] 345 if npairs_skin > max_pairs_skin or add_margin: 346 prev_max_pairs_skin = max_pairs_skin 347 max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1 348 state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin) 349 new_state["npairs_skin"] = max_pairs_skin 350 edge_src = np.full(max_pairs_skin, nat, dtype=np.int32) 351 edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32) 352 d12_ = np.full(max_pairs_skin, self.cutoff**2) 353 edge_src[:npairs_skin] = edge_src_skin[idx] 354 edge_dst[:npairs_skin] = edge_dst_skin[idx] 355 d12_[:npairs_skin] = d12[idx] 356 d12 = d12_ 357 if apply_pbc: 358 pbc_shifts = np.zeros((max_pairs_skin, 3), dtype=np.int32) 359 pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx] 360 361 graph = inputs.get(self.graph_key, {}) 362 graph_out = { 363 **graph, 364 "edge_src": edge_src, 365 "edge_dst": edge_dst, 366 "d12": d12, 367 "overflow": False, 368 "pbc_shifts": pbc_shifts, 369 } 370 if "nblist_skin" in state: 371 graph_out["edge_src_skin"] = edge_src_skin 372 graph_out["edge_dst_skin"] = edge_dst_skin 373 if apply_pbc: 374 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 375 376 if self.k_space and apply_pbc: 377 if "k_points" not in graph: 378 ks, _, _, bewald = get_reciprocal_space_parameters( 379 reciprocal_cells, self.cutoff, self.kmax, self.kthr 380 ) 381 graph_out["k_points"] = ks 382 graph_out["b_ewald"] = bewald 383 384 output = {**inputs, self.graph_key: graph_out} 385 386 if return_state_update: 387 return FrozenDict(new_state), output, state_up 388 return FrozenDict(new_state), output 389 390 def check_reallocate(self, state, inputs, parent_overflow=False): 391 """check for overflow and reallocate nblist if necessary""" 392 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 393 if not overflow: 394 return state, {}, inputs, False 395 396 add_margin = inputs[self.graph_key].get("overflow", False) 397 state, inputs, state_up = self( 398 state, inputs, return_state_update=True, add_margin=add_margin 399 ) 400 return state, state_up, inputs, True 401 402 @partial(jax.jit, static_argnums=(0, 1)) 403 def process(self, state, inputs): 404 """build a nblist on accelerator with jax and precomputed shapes""" 405 if self.graph_key in inputs: 406 graph = inputs[self.graph_key] 407 if "keep_graph" in graph: 408 return inputs 409 coords = inputs["coordinates"] 410 natoms = inputs["natoms"] 411 batch_index = inputs["batch_index"] 412 413 if natoms.shape[0] == 1: 414 max_nat = coords.shape[0] 415 else: 416 max_nat = state.get( 417 "max_nat", int(round(coords.shape[0] / natoms.shape[0])) 418 ) 419 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 420 421 ### compute indices of all pairs 422 apply_pbc = "cells" in inputs 423 minimage = "minimum_image" in inputs.get("flags", {}) 424 include_self_image = apply_pbc and not minimage 425 426 shift = 0 if include_self_image else 1 427 p1, p2 = np.triu_indices(max_nat, shift) 428 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 429 pbc_shifts = None 430 if natoms.shape[0] > 1: 431 ## batching => mask irrelevant pairs 432 mask_p12 = ( 433 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 434 ).flatten() 435 shift = jnp.concatenate( 436 (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1])) 437 ) 438 p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 439 p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 440 441 ## compute vectors 442 overflow_repeats = jnp.asarray(False, dtype=bool) 443 if "cells" not in inputs: 444 vec = coords[p2] - coords[p1] 445 else: 446 cells = inputs["cells"] 447 reciprocal_cells = inputs["reciprocal_cells"] 448 # minimage = state.get("minimum_image", True) 449 minimage = "minimum_image" in inputs.get("flags", {}) 450 451 def compute_pbc(vec, reciprocal_cell, cell, mode="round"): 452 vecpbc = jnp.dot(vec, reciprocal_cell) 453 if mode == "round": 454 pbc_shifts = -jnp.round(vecpbc).astype(jnp.int32) 455 elif mode == "floor": 456 pbc_shifts = -jnp.floor(vecpbc).astype(jnp.int32) 457 else: 458 raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.") 459 return vec + jnp.dot(pbc_shifts, cell), pbc_shifts 460 461 if minimage: 462 ## minimum image convention 463 vec = coords[p2] - coords[p1] 464 465 if cells.shape[0] == 1: 466 vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0]) 467 else: 468 batch_index_vec = batch_index[p1] 469 vec, pbc_shifts = jax.vmap(compute_pbc)( 470 vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec] 471 ) 472 else: 473 ### general PBC only for single cell yet 474 # if cells.shape[0] > 1: 475 # raise NotImplementedError( 476 # "General PBC not implemented for batches on accelerator." 477 # ) 478 # cell = cells[0] 479 # reciprocal_cell = reciprocal_cells[0] 480 481 ## put all atoms in central box 482 if cells.shape[0] == 1: 483 coords_pbc, at_shifts = compute_pbc( 484 coords, reciprocal_cells[0], cells[0], mode="floor" 485 ) 486 else: 487 coords_pbc, at_shifts = jax.vmap( 488 partial(compute_pbc, mode="floor") 489 )(coords, reciprocal_cells[batch_index], cells[batch_index]) 490 vec = coords_pbc[p2] - coords_pbc[p1] 491 num_repeats = state.get("num_repeats_pbc", (0, 0, 0)) 492 # if num_repeats is None: 493 # raise ValueError( 494 # "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first." 495 # ) 496 # check if num_repeats is larger than previous 497 inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1) 498 cdinv = cutoff_skin * inv_distances 499 num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32) 500 if "true_sys" in inputs: 501 num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0) 502 num_repeats_new = jnp.max(num_repeats_all, axis=0) 503 overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats)) 504 505 cell_shift_pbc = jnp.asarray( 506 np.array( 507 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 508 dtype=np.int32, 509 ).T.reshape(-1, 3) 510 ) 511 512 if cells.shape[0] == 1: 513 vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3) 514 pbc_shifts = jnp.broadcast_to( 515 cell_shift_pbc[None, :, :], 516 (p1.shape[0], cell_shift_pbc.shape[0], 3), 517 ).reshape(-1, 3) 518 p1 = jnp.broadcast_to( 519 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 520 ).flatten() 521 p2 = jnp.broadcast_to( 522 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 523 ).flatten() 524 if natoms.shape[0] > 1: 525 mask_p12 = jnp.broadcast_to( 526 mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0]) 527 ).flatten() 528 else: 529 dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3) 530 531 ## get pbc shifts specific to each box 532 cell_shift_pbc = jnp.broadcast_to( 533 cell_shift_pbc[None, :, :], 534 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 535 ) 536 mask = jnp.all( 537 jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 538 ).flatten() 539 max_shifts = state.get("nshifts_pbc", 1) 540 541 cell_shift_pbc = cell_shift_pbc.reshape(-1,3) 542 shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2] 543 dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2] 544 (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d( 545 mask, 546 max_shifts, 547 (dvecx, 0.), 548 (dvecy, 0.), 549 (dvecz, 0.), 550 (shiftx, 0), 551 (shifty, 0), 552 (shiftz, 0), 553 ) 554 dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1) 555 cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1, dtype=jnp.int32) 556 overflow_repeats = overflow_repeats | (nshifts > max_shifts) 557 558 ## get batch shift in the dvec_filter array 559 nrep = jnp.prod(2 * num_repeats_all + 1, axis=1) 560 bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1])) 561 562 ## repeat vectors 563 nvec_max = state.get("nvec_pbc", 1) 564 batch_index_vec = batch_index[p1] 565 nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0) 566 nvec = nrep_vec.sum() 567 overflow_repeats = overflow_repeats | (nvec > nvec_max) 568 vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max) 569 # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts)) 570 571 ## get shift index 572 dshift = jnp.concatenate( 573 (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1]) 574 ) 575 if nrep_vec.size == 0: 576 dshift = jnp.array([],dtype=jnp.int32) 577 dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max) 578 bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max) 579 icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift 580 vec = vec + dvec[icellshift] 581 pbc_shifts = cell_shift_pbc[icellshift] 582 p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max) 583 p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max) 584 mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max) 585 586 mask_valid = jnp.arange(nvec_max) < nvec 587 mask_p12 = jnp.where(mask_valid, mask_p12, False) 588 589 590 ## compute distances 591 d12 = (vec**2).sum(axis=-1) 592 if natoms.shape[0] > 1: 593 d12 = jnp.where(mask_p12, d12, cutoff_skin**2) 594 595 ## filter pairs 596 max_pairs = state.get("npairs", 1) 597 mask = d12 < cutoff_skin**2 598 if include_self_image: 599 mask_self = jnp.logical_or(p1 != p2, d12 > 1.e-3) 600 mask = jnp.logical_and(mask, mask_self) 601 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 602 mask, 603 max_pairs, 604 (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]), 605 (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]), 606 (d12, cutoff_skin**2), 607 ) 608 if "cells" in inputs: 609 pbc_shifts = ( 610 jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype) 611 .at[scatter_idx] 612 .set(pbc_shifts, mode="drop") 613 ) 614 if not minimage: 615 pbc_shifts = ( 616 pbc_shifts 617 + at_shifts.at[edge_dst].get(fill_value=0) 618 - at_shifts.at[edge_src].get(fill_value=0) 619 ) 620 621 ## check for overflow 622 if natoms.shape[0] == 1: 623 true_max_nat = coords.shape[0] 624 else: 625 true_max_nat = jnp.max(natoms) 626 overflow_count = npairs > max_pairs 627 overflow_at = true_max_nat > max_nat 628 overflow = overflow_count | overflow_at | overflow_repeats 629 630 ## symmetrize 631 if include_self_image: 632 mask_noself = edge_src != edge_dst 633 max_noself = state.get("npairs_noself", 1) 634 (edge_src_noself, edge_dst_noself, d12_noself, pbc_shifts_noself), scatter_idx, npairs_noself = mask_filter_1d( 635 mask_noself, 636 max_noself, 637 (edge_src, coords.shape[0]), 638 (edge_dst, coords.shape[0]), 639 (d12, cutoff_skin**2), 640 (pbc_shifts, jnp.zeros((3,), dtype=pbc_shifts.dtype)), 641 ) 642 overflow = overflow | (npairs_noself > max_noself) 643 edge_src = jnp.concatenate((edge_src, edge_dst_noself)) 644 edge_dst = jnp.concatenate((edge_dst, edge_src_noself)) 645 d12 = jnp.concatenate((d12, d12_noself)) 646 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts_noself)) 647 else: 648 edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate( 649 (edge_dst, edge_src) 650 ) 651 d12 = jnp.concatenate((d12, d12)) 652 if "cells" in inputs: 653 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts)) 654 655 if "nblist_skin" in state: 656 # edge_mask_skin = edge_mask 657 edge_src_skin = edge_src 658 edge_dst_skin = edge_dst 659 if "cells" in inputs: 660 pbc_shifts_skin = pbc_shifts 661 max_pairs_skin = state.get("npairs_skin", 1) 662 mask = d12 < self.cutoff**2 663 (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d( 664 mask, 665 max_pairs_skin, 666 (edge_src, coords.shape[0]), 667 (edge_dst, coords.shape[0]), 668 (d12, self.cutoff**2), 669 ) 670 if "cells" in inputs: 671 pbc_shifts = ( 672 jnp.zeros((max_pairs_skin, 3), dtype=pbc_shifts.dtype) 673 .at[scatter_idx] 674 .set(pbc_shifts, mode="drop") 675 ) 676 overflow = overflow | (npairs_skin > max_pairs_skin) 677 678 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 679 graph_out = { 680 **graph, 681 "edge_src": edge_src, 682 "edge_dst": edge_dst, 683 "d12": d12, 684 "overflow": overflow, 685 "pbc_shifts": pbc_shifts, 686 } 687 if "nblist_skin" in state: 688 graph_out["edge_src_skin"] = edge_src_skin 689 graph_out["edge_dst_skin"] = edge_dst_skin 690 if "cells" in inputs: 691 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 692 693 if self.k_space and "cells" in inputs: 694 if "k_points" not in graph: 695 raise NotImplementedError( 696 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 697 ) 698 return {**inputs, self.graph_key: graph_out} 699 700 @partial(jax.jit, static_argnums=(0,)) 701 def update_skin(self, inputs): 702 """update the nblist without recomputing the full nblist""" 703 graph = inputs[self.graph_key] 704 705 edge_src_skin = graph["edge_src_skin"] 706 edge_dst_skin = graph["edge_dst_skin"] 707 coords = inputs["coordinates"] 708 vec = coords.at[edge_dst_skin].get( 709 mode="fill", fill_value=self.cutoff 710 ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0) 711 712 if "cells" in inputs: 713 pbc_shifts_skin = graph["pbc_shifts_skin"] 714 cells = inputs["cells"] 715 if cells.shape[0] == 1: 716 vec = vec + jnp.dot(pbc_shifts_skin, cells[0]) 717 else: 718 batch_index_vec = inputs["batch_index"][edge_src_skin] 719 vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec]) 720 721 nat = coords.shape[0] 722 d12 = jnp.sum(vec**2, axis=-1) 723 mask = d12 < self.cutoff**2 724 max_pairs = graph["edge_src"].shape[0] 725 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 726 mask, 727 max_pairs, 728 (edge_src_skin, nat), 729 (edge_dst_skin, nat), 730 (d12, self.cutoff**2), 731 ) 732 if "cells" in inputs: 733 pbc_shifts = ( 734 jnp.zeros((max_pairs, 3), dtype=pbc_shifts_skin.dtype) 735 .at[scatter_idx] 736 .set(pbc_shifts_skin) 737 ) 738 739 overflow = graph.get("overflow", False) | (npairs > max_pairs) 740 graph_out = { 741 **graph, 742 "edge_src": edge_src, 743 "edge_dst": edge_dst, 744 "d12": d12, 745 "overflow": overflow, 746 } 747 if "cells" in inputs: 748 graph_out["pbc_shifts"] = pbc_shifts 749 750 if self.k_space and "cells" in inputs: 751 if "k_points" not in graph: 752 raise NotImplementedError( 753 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 754 ) 755 756 return {**inputs, self.graph_key: graph_out}
Generate a graph from a set of coordinates
FPID: GRAPH
For now, we generate all pairs of atoms and filter based on cutoff.
If a nblist_skin is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the update_skin method to update the original graph without recomputing the full nblist.
Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.
390 def check_reallocate(self, state, inputs, parent_overflow=False): 391 """check for overflow and reallocate nblist if necessary""" 392 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 393 if not overflow: 394 return state, {}, inputs, False 395 396 add_margin = inputs[self.graph_key].get("overflow", False) 397 state, inputs, state_up = self( 398 state, inputs, return_state_update=True, add_margin=add_margin 399 ) 400 return state, state_up, inputs, True
check for overflow and reallocate nblist if necessary
402 @partial(jax.jit, static_argnums=(0, 1)) 403 def process(self, state, inputs): 404 """build a nblist on accelerator with jax and precomputed shapes""" 405 if self.graph_key in inputs: 406 graph = inputs[self.graph_key] 407 if "keep_graph" in graph: 408 return inputs 409 coords = inputs["coordinates"] 410 natoms = inputs["natoms"] 411 batch_index = inputs["batch_index"] 412 413 if natoms.shape[0] == 1: 414 max_nat = coords.shape[0] 415 else: 416 max_nat = state.get( 417 "max_nat", int(round(coords.shape[0] / natoms.shape[0])) 418 ) 419 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 420 421 ### compute indices of all pairs 422 apply_pbc = "cells" in inputs 423 minimage = "minimum_image" in inputs.get("flags", {}) 424 include_self_image = apply_pbc and not minimage 425 426 shift = 0 if include_self_image else 1 427 p1, p2 = np.triu_indices(max_nat, shift) 428 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 429 pbc_shifts = None 430 if natoms.shape[0] > 1: 431 ## batching => mask irrelevant pairs 432 mask_p12 = ( 433 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 434 ).flatten() 435 shift = jnp.concatenate( 436 (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1])) 437 ) 438 p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 439 p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 440 441 ## compute vectors 442 overflow_repeats = jnp.asarray(False, dtype=bool) 443 if "cells" not in inputs: 444 vec = coords[p2] - coords[p1] 445 else: 446 cells = inputs["cells"] 447 reciprocal_cells = inputs["reciprocal_cells"] 448 # minimage = state.get("minimum_image", True) 449 minimage = "minimum_image" in inputs.get("flags", {}) 450 451 def compute_pbc(vec, reciprocal_cell, cell, mode="round"): 452 vecpbc = jnp.dot(vec, reciprocal_cell) 453 if mode == "round": 454 pbc_shifts = -jnp.round(vecpbc).astype(jnp.int32) 455 elif mode == "floor": 456 pbc_shifts = -jnp.floor(vecpbc).astype(jnp.int32) 457 else: 458 raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.") 459 return vec + jnp.dot(pbc_shifts, cell), pbc_shifts 460 461 if minimage: 462 ## minimum image convention 463 vec = coords[p2] - coords[p1] 464 465 if cells.shape[0] == 1: 466 vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0]) 467 else: 468 batch_index_vec = batch_index[p1] 469 vec, pbc_shifts = jax.vmap(compute_pbc)( 470 vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec] 471 ) 472 else: 473 ### general PBC only for single cell yet 474 # if cells.shape[0] > 1: 475 # raise NotImplementedError( 476 # "General PBC not implemented for batches on accelerator." 477 # ) 478 # cell = cells[0] 479 # reciprocal_cell = reciprocal_cells[0] 480 481 ## put all atoms in central box 482 if cells.shape[0] == 1: 483 coords_pbc, at_shifts = compute_pbc( 484 coords, reciprocal_cells[0], cells[0], mode="floor" 485 ) 486 else: 487 coords_pbc, at_shifts = jax.vmap( 488 partial(compute_pbc, mode="floor") 489 )(coords, reciprocal_cells[batch_index], cells[batch_index]) 490 vec = coords_pbc[p2] - coords_pbc[p1] 491 num_repeats = state.get("num_repeats_pbc", (0, 0, 0)) 492 # if num_repeats is None: 493 # raise ValueError( 494 # "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first." 495 # ) 496 # check if num_repeats is larger than previous 497 inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1) 498 cdinv = cutoff_skin * inv_distances 499 num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32) 500 if "true_sys" in inputs: 501 num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0) 502 num_repeats_new = jnp.max(num_repeats_all, axis=0) 503 overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats)) 504 505 cell_shift_pbc = jnp.asarray( 506 np.array( 507 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 508 dtype=np.int32, 509 ).T.reshape(-1, 3) 510 ) 511 512 if cells.shape[0] == 1: 513 vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3) 514 pbc_shifts = jnp.broadcast_to( 515 cell_shift_pbc[None, :, :], 516 (p1.shape[0], cell_shift_pbc.shape[0], 3), 517 ).reshape(-1, 3) 518 p1 = jnp.broadcast_to( 519 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 520 ).flatten() 521 p2 = jnp.broadcast_to( 522 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 523 ).flatten() 524 if natoms.shape[0] > 1: 525 mask_p12 = jnp.broadcast_to( 526 mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0]) 527 ).flatten() 528 else: 529 dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3) 530 531 ## get pbc shifts specific to each box 532 cell_shift_pbc = jnp.broadcast_to( 533 cell_shift_pbc[None, :, :], 534 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 535 ) 536 mask = jnp.all( 537 jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 538 ).flatten() 539 max_shifts = state.get("nshifts_pbc", 1) 540 541 cell_shift_pbc = cell_shift_pbc.reshape(-1,3) 542 shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2] 543 dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2] 544 (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d( 545 mask, 546 max_shifts, 547 (dvecx, 0.), 548 (dvecy, 0.), 549 (dvecz, 0.), 550 (shiftx, 0), 551 (shifty, 0), 552 (shiftz, 0), 553 ) 554 dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1) 555 cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1, dtype=jnp.int32) 556 overflow_repeats = overflow_repeats | (nshifts > max_shifts) 557 558 ## get batch shift in the dvec_filter array 559 nrep = jnp.prod(2 * num_repeats_all + 1, axis=1) 560 bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1])) 561 562 ## repeat vectors 563 nvec_max = state.get("nvec_pbc", 1) 564 batch_index_vec = batch_index[p1] 565 nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0) 566 nvec = nrep_vec.sum() 567 overflow_repeats = overflow_repeats | (nvec > nvec_max) 568 vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max) 569 # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts)) 570 571 ## get shift index 572 dshift = jnp.concatenate( 573 (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1]) 574 ) 575 if nrep_vec.size == 0: 576 dshift = jnp.array([],dtype=jnp.int32) 577 dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max) 578 bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max) 579 icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift 580 vec = vec + dvec[icellshift] 581 pbc_shifts = cell_shift_pbc[icellshift] 582 p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max) 583 p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max) 584 mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max) 585 586 mask_valid = jnp.arange(nvec_max) < nvec 587 mask_p12 = jnp.where(mask_valid, mask_p12, False) 588 589 590 ## compute distances 591 d12 = (vec**2).sum(axis=-1) 592 if natoms.shape[0] > 1: 593 d12 = jnp.where(mask_p12, d12, cutoff_skin**2) 594 595 ## filter pairs 596 max_pairs = state.get("npairs", 1) 597 mask = d12 < cutoff_skin**2 598 if include_self_image: 599 mask_self = jnp.logical_or(p1 != p2, d12 > 1.e-3) 600 mask = jnp.logical_and(mask, mask_self) 601 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 602 mask, 603 max_pairs, 604 (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]), 605 (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]), 606 (d12, cutoff_skin**2), 607 ) 608 if "cells" in inputs: 609 pbc_shifts = ( 610 jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype) 611 .at[scatter_idx] 612 .set(pbc_shifts, mode="drop") 613 ) 614 if not minimage: 615 pbc_shifts = ( 616 pbc_shifts 617 + at_shifts.at[edge_dst].get(fill_value=0) 618 - at_shifts.at[edge_src].get(fill_value=0) 619 ) 620 621 ## check for overflow 622 if natoms.shape[0] == 1: 623 true_max_nat = coords.shape[0] 624 else: 625 true_max_nat = jnp.max(natoms) 626 overflow_count = npairs > max_pairs 627 overflow_at = true_max_nat > max_nat 628 overflow = overflow_count | overflow_at | overflow_repeats 629 630 ## symmetrize 631 if include_self_image: 632 mask_noself = edge_src != edge_dst 633 max_noself = state.get("npairs_noself", 1) 634 (edge_src_noself, edge_dst_noself, d12_noself, pbc_shifts_noself), scatter_idx, npairs_noself = mask_filter_1d( 635 mask_noself, 636 max_noself, 637 (edge_src, coords.shape[0]), 638 (edge_dst, coords.shape[0]), 639 (d12, cutoff_skin**2), 640 (pbc_shifts, jnp.zeros((3,), dtype=pbc_shifts.dtype)), 641 ) 642 overflow = overflow | (npairs_noself > max_noself) 643 edge_src = jnp.concatenate((edge_src, edge_dst_noself)) 644 edge_dst = jnp.concatenate((edge_dst, edge_src_noself)) 645 d12 = jnp.concatenate((d12, d12_noself)) 646 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts_noself)) 647 else: 648 edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate( 649 (edge_dst, edge_src) 650 ) 651 d12 = jnp.concatenate((d12, d12)) 652 if "cells" in inputs: 653 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts)) 654 655 if "nblist_skin" in state: 656 # edge_mask_skin = edge_mask 657 edge_src_skin = edge_src 658 edge_dst_skin = edge_dst 659 if "cells" in inputs: 660 pbc_shifts_skin = pbc_shifts 661 max_pairs_skin = state.get("npairs_skin", 1) 662 mask = d12 < self.cutoff**2 663 (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d( 664 mask, 665 max_pairs_skin, 666 (edge_src, coords.shape[0]), 667 (edge_dst, coords.shape[0]), 668 (d12, self.cutoff**2), 669 ) 670 if "cells" in inputs: 671 pbc_shifts = ( 672 jnp.zeros((max_pairs_skin, 3), dtype=pbc_shifts.dtype) 673 .at[scatter_idx] 674 .set(pbc_shifts, mode="drop") 675 ) 676 overflow = overflow | (npairs_skin > max_pairs_skin) 677 678 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 679 graph_out = { 680 **graph, 681 "edge_src": edge_src, 682 "edge_dst": edge_dst, 683 "d12": d12, 684 "overflow": overflow, 685 "pbc_shifts": pbc_shifts, 686 } 687 if "nblist_skin" in state: 688 graph_out["edge_src_skin"] = edge_src_skin 689 graph_out["edge_dst_skin"] = edge_dst_skin 690 if "cells" in inputs: 691 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 692 693 if self.k_space and "cells" in inputs: 694 if "k_points" not in graph: 695 raise NotImplementedError( 696 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 697 ) 698 return {**inputs, self.graph_key: graph_out}
build a nblist on accelerator with jax and precomputed shapes
700 @partial(jax.jit, static_argnums=(0,)) 701 def update_skin(self, inputs): 702 """update the nblist without recomputing the full nblist""" 703 graph = inputs[self.graph_key] 704 705 edge_src_skin = graph["edge_src_skin"] 706 edge_dst_skin = graph["edge_dst_skin"] 707 coords = inputs["coordinates"] 708 vec = coords.at[edge_dst_skin].get( 709 mode="fill", fill_value=self.cutoff 710 ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0) 711 712 if "cells" in inputs: 713 pbc_shifts_skin = graph["pbc_shifts_skin"] 714 cells = inputs["cells"] 715 if cells.shape[0] == 1: 716 vec = vec + jnp.dot(pbc_shifts_skin, cells[0]) 717 else: 718 batch_index_vec = inputs["batch_index"][edge_src_skin] 719 vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec]) 720 721 nat = coords.shape[0] 722 d12 = jnp.sum(vec**2, axis=-1) 723 mask = d12 < self.cutoff**2 724 max_pairs = graph["edge_src"].shape[0] 725 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 726 mask, 727 max_pairs, 728 (edge_src_skin, nat), 729 (edge_dst_skin, nat), 730 (d12, self.cutoff**2), 731 ) 732 if "cells" in inputs: 733 pbc_shifts = ( 734 jnp.zeros((max_pairs, 3), dtype=pbc_shifts_skin.dtype) 735 .at[scatter_idx] 736 .set(pbc_shifts_skin) 737 ) 738 739 overflow = graph.get("overflow", False) | (npairs > max_pairs) 740 graph_out = { 741 **graph, 742 "edge_src": edge_src, 743 "edge_dst": edge_dst, 744 "d12": d12, 745 "overflow": overflow, 746 } 747 if "cells" in inputs: 748 graph_out["pbc_shifts"] = pbc_shifts 749 750 if self.k_space and "cells" in inputs: 751 if "k_points" not in graph: 752 raise NotImplementedError( 753 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 754 ) 755 756 return {**inputs, self.graph_key: graph_out}
update the nblist without recomputing the full nblist
759class GraphProcessor(nn.Module): 760 """Process a pre-generated graph 761 762 The pre-generated graph should contain the following keys: 763 - edge_src: source indices of the edges 764 - edge_dst: destination indices of the edges 765 - pbcs_shifts: pbc shifts for the edges (only if `cells` are present in the inputs) 766 767 This module is automatically added to a FENNIX model when a GraphGenerator is used. 768 769 """ 770 771 cutoff: float 772 """Cutoff distance for the graph.""" 773 graph_key: str = "graph" 774 """Key of the graph in the outputs.""" 775 switch_params: dict = dataclasses.field(default_factory=dict) 776 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 777 778 @nn.compact 779 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 780 graph = inputs[self.graph_key] 781 coords = inputs["coordinates"] 782 edge_src, edge_dst = graph["edge_src"], graph["edge_dst"] 783 # edge_mask = edge_src < coords.shape[0] 784 vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[ 785 edge_src 786 ].get(mode="fill", fill_value=0.0) 787 if "cells" in inputs: 788 cells = inputs["cells"] 789 if cells.shape[0] == 1: 790 vec = vec + jnp.dot(graph["pbc_shifts"], cells[0]) 791 else: 792 batch_index_vec = inputs["batch_index"][edge_src] 793 vec = vec + jax.vmap(jnp.dot)( 794 graph["pbc_shifts"], cells[batch_index_vec] 795 ) 796 797 d2 = jnp.sum(vec**2, axis=-1) 798 distances = safe_sqrt(d2) 799 edge_mask = distances < self.cutoff 800 801 switch = SwitchFunction( 802 **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None} 803 )((distances, edge_mask)) 804 805 graph_out = { 806 **graph, 807 "vec": vec, 808 "distances": distances, 809 "switch": switch, 810 "edge_mask": edge_mask, 811 } 812 813 if "alch_group" in inputs: 814 alch_group = inputs["alch_group"] 815 lambda_e = inputs["alch_elambda"] 816 lambda_v = inputs["alch_vlambda"] 817 mask = alch_group[edge_src] == alch_group[edge_dst] 818 graph_out["switch_raw"] = switch 819 graph_out["switch"] = jnp.where( 820 mask, 821 switch, 822 0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch , 823 ) 824 graph_out["distances_raw"] = distances 825 if "alch_softcore_e" in inputs: 826 alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2 827 else: 828 alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2 829 830 graph_out["distances"] = jnp.where( 831 mask, 832 distances, 833 safe_sqrt(alch_alpha + d2 * (1. - alch_alpha/self.cutoff**2)) 834 ) 835 836 837 return {**inputs, self.graph_key: graph_out}
Process a pre-generated graph
The pre-generated graph should contain the following keys:
- edge_src: source indices of the edges
- edge_dst: destination indices of the edges
- pbcs_shifts: pbc shifts for the edges (only if
cellsare present in the inputs)
This module is automatically added to a FENNIX model when a GraphGenerator is used.
Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
840@dataclasses.dataclass(frozen=True) 841class GraphFilter: 842 """Filter a graph based on a cutoff distance 843 844 FPID: GRAPH_FILTER 845 """ 846 847 cutoff: float 848 """Cutoff distance for the filtering.""" 849 parent_graph: str 850 """Key of the parent graph in the inputs.""" 851 graph_key: str 852 """Key of the filtered graph in the outputs.""" 853 remove_hydrogens: int = False 854 """Remove edges where the source is a hydrogen atom.""" 855 switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict) 856 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 857 k_space: bool = False 858 """Generate k-space information for the graph.""" 859 kmax: int = 30 860 """Maximum number of k-points to consider.""" 861 kthr: float = 1e-6 862 """Threshold for k-point filtering.""" 863 mult_size: float = 1.05 864 """Multiplicative factor for resizing the nblist.""" 865 866 FPID: ClassVar[str] = "GRAPH_FILTER" 867 868 def init(self): 869 return FrozenDict( 870 { 871 "npairs": 1, 872 "nblist_mult_size": self.mult_size, 873 } 874 ) 875 876 def get_processor(self) -> Tuple[nn.Module, Dict]: 877 return GraphFilterProcessor, { 878 "cutoff": self.cutoff, 879 "graph_key": self.graph_key, 880 "parent_graph": self.parent_graph, 881 "name": f"{self.graph_key}_Filter_{self.parent_graph}", 882 "switch_params": self.switch_params, 883 } 884 885 def get_graph_properties(self): 886 return { 887 self.graph_key: { 888 "cutoff": self.cutoff, 889 "directed": True, 890 "parent_graph": self.parent_graph, 891 } 892 } 893 894 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 895 """filter a nblist on cpu with numpy and dynamic shapes + store max shapes""" 896 graph_in = inputs[self.parent_graph] 897 nat = inputs["species"].shape[0] 898 899 new_state = {**state} 900 state_up = {} 901 mult_size = state.get("nblist_mult_size", self.mult_size) 902 assert mult_size >= 1., "nblist_mult_size should be >= 1." 903 904 edge_src = np.array(graph_in["edge_src"], dtype=np.int32) 905 d12 = np.array(graph_in["d12"], dtype=np.float32) 906 if self.remove_hydrogens: 907 species = inputs["species"] 908 src_idx = (edge_src < nat).nonzero()[0] 909 mask = np.zeros(edge_src.shape[0], dtype=bool) 910 mask[src_idx] = (species > 1)[edge_src[src_idx]] 911 d12 = np.where(mask, d12, self.cutoff**2) 912 mask = d12 < self.cutoff**2 913 914 max_pairs = state.get("npairs", 1) 915 idx = np.nonzero(mask)[0] 916 npairs = idx.shape[0] 917 if npairs > max_pairs or add_margin: 918 prev_max_pairs = max_pairs 919 max_pairs = int(mult_size * max(npairs, max_pairs)) + 1 920 state_up["npairs"] = (max_pairs, prev_max_pairs) 921 new_state["npairs"] = max_pairs 922 923 filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32) 924 edge_src = np.full(max_pairs, nat, dtype=np.int32) 925 edge_dst = np.full(max_pairs, nat, dtype=np.int32) 926 d12_ = np.full(max_pairs, self.cutoff**2) 927 filter_indices[:npairs] = idx 928 edge_src[:npairs] = graph_in["edge_src"][idx] 929 edge_dst[:npairs] = graph_in["edge_dst"][idx] 930 d12_[:npairs] = d12[idx] 931 d12 = d12_ 932 933 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 934 graph_out = { 935 **graph, 936 "edge_src": edge_src, 937 "edge_dst": edge_dst, 938 "filter_indices": filter_indices, 939 "d12": d12, 940 "overflow": False, 941 } 942 if "cells" in inputs: 943 pbc_shifts = np.zeros((max_pairs, 3), dtype=np.int32) 944 pbc_shifts[:npairs] = graph_in["pbc_shifts"][idx] 945 graph_out["pbc_shifts"] = pbc_shifts 946 947 if self.k_space: 948 if "k_points" not in graph: 949 ks, _, _, bewald = get_reciprocal_space_parameters( 950 inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr 951 ) 952 graph_out["k_points"] = ks 953 graph_out["b_ewald"] = bewald 954 955 output = {**inputs, self.graph_key: graph_out} 956 if return_state_update: 957 return FrozenDict(new_state), output, state_up 958 return FrozenDict(new_state), output 959 960 def check_reallocate(self, state, inputs, parent_overflow=False): 961 """check for overflow and reallocate nblist if necessary""" 962 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 963 if not overflow: 964 return state, {}, inputs, False 965 966 add_margin = inputs[self.graph_key].get("overflow", False) 967 state, inputs, state_up = self( 968 state, inputs, return_state_update=True, add_margin=add_margin 969 ) 970 return state, state_up, inputs, True 971 972 @partial(jax.jit, static_argnums=(0, 1)) 973 def process(self, state, inputs): 974 """filter a nblist on accelerator with jax and precomputed shapes""" 975 graph_in = inputs[self.parent_graph] 976 if state is None: 977 # skin update mode 978 graph = inputs[self.graph_key] 979 max_pairs = graph["edge_src"].shape[0] 980 else: 981 max_pairs = state.get("npairs", 1) 982 983 max_pairs_in = graph_in["edge_src"].shape[0] 984 nat = inputs["species"].shape[0] 985 986 edge_src = graph_in["edge_src"] 987 d12 = graph_in["d12"] 988 if self.remove_hydrogens: 989 species = inputs["species"] 990 mask = (species > 1)[edge_src] 991 d12 = jnp.where(mask, d12, self.cutoff**2) 992 mask = d12 < self.cutoff**2 993 994 (edge_src, edge_dst, d12, filter_indices), scatter_idx, npairs = mask_filter_1d( 995 mask, 996 max_pairs, 997 (edge_src, nat), 998 (graph_in["edge_dst"], nat), 999 (d12, self.cutoff**2), 1000 (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in), 1001 ) 1002 1003 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 1004 overflow = graph.get("overflow", False) | (npairs > max_pairs) 1005 graph_out = { 1006 **graph, 1007 "edge_src": edge_src, 1008 "edge_dst": edge_dst, 1009 "filter_indices": filter_indices, 1010 "d12": d12, 1011 "overflow": overflow, 1012 } 1013 1014 if "cells" in inputs: 1015 pbc_shifts = graph_in["pbc_shifts"] 1016 pbc_shifts = ( 1017 jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype) 1018 .at[scatter_idx].set(pbc_shifts, mode="drop") 1019 ) 1020 graph_out["pbc_shifts"] = pbc_shifts 1021 if self.k_space: 1022 if "k_points" not in graph: 1023 raise NotImplementedError( 1024 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 1025 ) 1026 1027 return {**inputs, self.graph_key: graph_out} 1028 1029 @partial(jax.jit, static_argnums=(0,)) 1030 def update_skin(self, inputs): 1031 return self.process(None, inputs)
Filter a graph based on a cutoff distance
FPID: GRAPH_FILTER
Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.
960 def check_reallocate(self, state, inputs, parent_overflow=False): 961 """check for overflow and reallocate nblist if necessary""" 962 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 963 if not overflow: 964 return state, {}, inputs, False 965 966 add_margin = inputs[self.graph_key].get("overflow", False) 967 state, inputs, state_up = self( 968 state, inputs, return_state_update=True, add_margin=add_margin 969 ) 970 return state, state_up, inputs, True
check for overflow and reallocate nblist if necessary
972 @partial(jax.jit, static_argnums=(0, 1)) 973 def process(self, state, inputs): 974 """filter a nblist on accelerator with jax and precomputed shapes""" 975 graph_in = inputs[self.parent_graph] 976 if state is None: 977 # skin update mode 978 graph = inputs[self.graph_key] 979 max_pairs = graph["edge_src"].shape[0] 980 else: 981 max_pairs = state.get("npairs", 1) 982 983 max_pairs_in = graph_in["edge_src"].shape[0] 984 nat = inputs["species"].shape[0] 985 986 edge_src = graph_in["edge_src"] 987 d12 = graph_in["d12"] 988 if self.remove_hydrogens: 989 species = inputs["species"] 990 mask = (species > 1)[edge_src] 991 d12 = jnp.where(mask, d12, self.cutoff**2) 992 mask = d12 < self.cutoff**2 993 994 (edge_src, edge_dst, d12, filter_indices), scatter_idx, npairs = mask_filter_1d( 995 mask, 996 max_pairs, 997 (edge_src, nat), 998 (graph_in["edge_dst"], nat), 999 (d12, self.cutoff**2), 1000 (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in), 1001 ) 1002 1003 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 1004 overflow = graph.get("overflow", False) | (npairs > max_pairs) 1005 graph_out = { 1006 **graph, 1007 "edge_src": edge_src, 1008 "edge_dst": edge_dst, 1009 "filter_indices": filter_indices, 1010 "d12": d12, 1011 "overflow": overflow, 1012 } 1013 1014 if "cells" in inputs: 1015 pbc_shifts = graph_in["pbc_shifts"] 1016 pbc_shifts = ( 1017 jnp.zeros((max_pairs, 3), dtype=pbc_shifts.dtype) 1018 .at[scatter_idx].set(pbc_shifts, mode="drop") 1019 ) 1020 graph_out["pbc_shifts"] = pbc_shifts 1021 if self.k_space: 1022 if "k_points" not in graph: 1023 raise NotImplementedError( 1024 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 1025 ) 1026 1027 return {**inputs, self.graph_key: graph_out}
filter a nblist on accelerator with jax and precomputed shapes
1034class GraphFilterProcessor(nn.Module): 1035 """Filter processing for a pre-generated graph 1036 1037 This module is automatically added to a FENNIX model when a GraphFilter is used. 1038 """ 1039 1040 cutoff: float 1041 """Cutoff distance for the filtering.""" 1042 graph_key: str 1043 """Key of the filtered graph in the inputs.""" 1044 parent_graph: str 1045 """Key of the parent graph in the inputs.""" 1046 switch_params: dict = dataclasses.field(default_factory=dict) 1047 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 1048 1049 @nn.compact 1050 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 1051 graph_in = inputs[self.parent_graph] 1052 graph = inputs[self.graph_key] 1053 1054 d_key = "distances_raw" if "distances_raw" in graph else "distances" 1055 1056 if graph_in["vec"].shape[0] == 0: 1057 vec = graph_in["vec"] 1058 distances = graph_in[d_key] 1059 filter_indices = jnp.asarray([], dtype=jnp.int32) 1060 else: 1061 filter_indices = graph["filter_indices"] 1062 vec = ( 1063 graph_in["vec"] 1064 .at[filter_indices] 1065 .get(mode="fill", fill_value=self.cutoff) 1066 ) 1067 distances = ( 1068 graph_in[d_key] 1069 .at[filter_indices] 1070 .get(mode="fill", fill_value=self.cutoff) 1071 ) 1072 1073 edge_mask = distances < self.cutoff 1074 switch = SwitchFunction( 1075 **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None} 1076 )((distances, edge_mask)) 1077 1078 graph_out = { 1079 **graph, 1080 "vec": vec, 1081 "distances": distances, 1082 "switch": switch, 1083 "filter_indices": filter_indices, 1084 "edge_mask": edge_mask, 1085 } 1086 1087 if "alch_group" in inputs: 1088 edge_src=graph["edge_src"] 1089 edge_dst=graph["edge_dst"] 1090 alch_group = inputs["alch_group"] 1091 lambda_e = inputs["alch_elambda"] 1092 lambda_v = inputs["alch_vlambda"] 1093 mask = alch_group[edge_src] == alch_group[edge_dst] 1094 graph_out["switch_raw"] = switch 1095 graph_out["switch"] = jnp.where( 1096 mask, 1097 switch, 1098 0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch , 1099 ) 1100 1101 graph_out["distances_raw"] = distances 1102 if "alch_softcore_e" in inputs: 1103 alch_alpha = (1-lambda_e)*inputs["alch_softcore_e"]**2 1104 else: 1105 alch_alpha = (1-lambda_v)*inputs.get("alch_softcore_v",0.5)**2 1106 1107 graph_out["distances"] = jnp.where( 1108 mask, 1109 distances, 1110 safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha/self.cutoff**2)) 1111 ) 1112 1113 1114 return {**inputs, self.graph_key: graph_out}
Filter processing for a pre-generated graph
This module is automatically added to a FENNIX model when a GraphFilter is used.
Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
1117@dataclasses.dataclass(frozen=True) 1118class GraphAngularExtension: 1119 """Add angles list to a graph 1120 1121 FPID: GRAPH_ANGULAR_EXTENSION 1122 """ 1123 1124 mult_size: float = 1.05 1125 """Multiplicative factor for resizing the nblist.""" 1126 add_neigh: int = 5 1127 """Additional neighbors to add to the nblist when resizing.""" 1128 graph_key: str = "graph" 1129 """Key of the graph in the inputs.""" 1130 1131 FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION" 1132 1133 def init(self): 1134 return FrozenDict( 1135 { 1136 "nangles": 0, 1137 "nblist_mult_size": self.mult_size, 1138 "max_neigh": self.add_neigh, 1139 "add_neigh": self.add_neigh, 1140 } 1141 ) 1142 1143 def get_processor(self) -> Tuple[nn.Module, Dict]: 1144 return GraphAngleProcessor, { 1145 "graph_key": self.graph_key, 1146 "name": f"{self.graph_key}_AngleProcessor", 1147 } 1148 1149 def get_graph_properties(self): 1150 return { 1151 self.graph_key: { 1152 "has_angles": True, 1153 } 1154 } 1155 1156 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1157 """build angle nblist on cpu with numpy and dynamic shapes + store max shapes""" 1158 graph = inputs[self.graph_key] 1159 edge_src = np.array(graph["edge_src"], dtype=np.int32) 1160 1161 new_state = {**state} 1162 state_up = {} 1163 mult_size = state.get("nblist_mult_size", self.mult_size) 1164 assert mult_size >= 1., "nblist_mult_size should be >= 1." 1165 1166 ### count number of neighbors 1167 nat = inputs["species"].shape[0] 1168 count = np.zeros(nat + 1, dtype=np.int32) 1169 np.add.at(count, edge_src, 1) 1170 max_count = int(np.max(count[:-1])) 1171 1172 ### get sizes 1173 max_neigh = state.get("max_neigh", self.add_neigh) 1174 nedge = edge_src.shape[0] 1175 if max_count > max_neigh or add_margin: 1176 prev_max_neigh = max_neigh 1177 max_neigh = max(max_count, max_neigh) + state.get( 1178 "add_neigh", self.add_neigh 1179 ) 1180 state_up["max_neigh"] = (max_neigh, prev_max_neigh) 1181 new_state["max_neigh"] = max_neigh 1182 1183 max_neigh_arr = np.empty(max_neigh, dtype=bool) 1184 1185 nedge = edge_src.shape[0] 1186 1187 ### sort edge_src 1188 idx_sort = np.argsort(edge_src) 1189 edge_src_sorted = edge_src[idx_sort] 1190 1191 ### map sparse to dense nblist 1192 offset = np.tile(np.arange(max_count), nat) 1193 if max_count * nat >= nedge: 1194 offset = np.tile(np.arange(max_count), nat)[:nedge] 1195 else: 1196 offset = np.zeros(nedge, dtype=np.int32) 1197 offset[: max_count * nat] = np.tile(np.arange(max_count), nat) 1198 1199 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1200 mask = edge_src_sorted < nat 1201 indices = edge_src_sorted * max_count + offset 1202 indices = indices[mask] 1203 idx_sort = idx_sort[mask] 1204 edge_idx = np.full(nat * max_count, nedge, dtype=np.int32) 1205 edge_idx[indices] = idx_sort 1206 edge_idx = edge_idx.reshape(nat, max_count) 1207 1208 ### find all triplet for each atom center 1209 local_src, local_dst = np.triu_indices(max_count, 1) 1210 angle_src = edge_idx[:, local_src].flatten() 1211 angle_dst = edge_idx[:, local_dst].flatten() 1212 1213 ### mask for valid angles 1214 mask1 = angle_src < nedge 1215 mask2 = angle_dst < nedge 1216 angle_mask = mask1 & mask2 1217 1218 max_angles = state.get("nangles", 0) 1219 idx = np.nonzero(angle_mask)[0] 1220 nangles = idx.shape[0] 1221 if nangles > max_angles or add_margin: 1222 max_angles_prev = max_angles 1223 max_angles = int(mult_size * max(nangles, max_angles)) + 1 1224 state_up["nangles"] = (max_angles, max_angles_prev) 1225 new_state["nangles"] = max_angles 1226 1227 ## filter angles to sparse representation 1228 angle_src_ = np.full(max_angles, nedge, dtype=np.int32) 1229 angle_dst_ = np.full(max_angles, nedge, dtype=np.int32) 1230 angle_src_[:nangles] = angle_src[idx] 1231 angle_dst_[:nangles] = angle_dst[idx] 1232 1233 central_atom = np.full(max_angles, nat, dtype=np.int32) 1234 central_atom[:nangles] = edge_src[angle_src_[:nangles]] 1235 1236 ## update graph 1237 output = { 1238 **inputs, 1239 self.graph_key: { 1240 **graph, 1241 "angle_src": angle_src_, 1242 "angle_dst": angle_dst_, 1243 "central_atom": central_atom, 1244 "angle_overflow": False, 1245 "max_neigh": max_neigh, 1246 "__max_neigh_array": max_neigh_arr, 1247 }, 1248 } 1249 1250 if return_state_update: 1251 return FrozenDict(new_state), output, state_up 1252 return FrozenDict(new_state), output 1253 1254 def check_reallocate(self, state, inputs, parent_overflow=False): 1255 """check for overflow and reallocate nblist if necessary""" 1256 overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"] 1257 if not overflow: 1258 return state, {}, inputs, False 1259 1260 add_margin = inputs[self.graph_key]["angle_overflow"] 1261 state, inputs, state_up = self( 1262 state, inputs, return_state_update=True, add_margin=add_margin 1263 ) 1264 return state, state_up, inputs, True 1265 1266 @partial(jax.jit, static_argnums=(0, 1)) 1267 def process(self, state, inputs): 1268 """build angle nblist on accelerator with jax and precomputed shapes""" 1269 graph = inputs[self.graph_key] 1270 edge_src = graph["edge_src"] 1271 1272 ### count number of neighbors 1273 nat = inputs["species"].shape[0] 1274 count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop") 1275 max_count = jnp.max(count) 1276 1277 ### get sizes 1278 if state is None: 1279 max_neigh_arr = graph["__max_neigh_array"] 1280 max_neigh = max_neigh_arr.shape[0] 1281 prev_nangles = graph["angle_src"].shape[0] 1282 else: 1283 max_neigh = state.get("max_neigh", self.add_neigh) 1284 max_neigh_arr = jnp.empty(max_neigh, dtype=bool) 1285 prev_nangles = state.get("nangles", 0) 1286 1287 nedge = edge_src.shape[0] 1288 1289 ### sort edge_src 1290 idx_sort = jnp.argsort(edge_src).astype(jnp.int32) 1291 edge_src_sorted = edge_src[idx_sort] 1292 1293 ### map sparse to dense nblist 1294 if max_neigh * nat < nedge: 1295 raise ValueError("Found max_neigh*nat < nedge. This should not happen.") 1296 offset = jnp.asarray( 1297 np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32 1298 ) 1299 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1300 indices = edge_src_sorted * max_neigh + offset 1301 edge_idx = ( 1302 jnp.full(nat * max_neigh, nedge, dtype=jnp.int32) 1303 .at[indices] 1304 .set(idx_sort, mode="drop") 1305 .reshape(nat, max_neigh) 1306 ) 1307 1308 ### find all triplet for each atom center 1309 local_src, local_dst = np.triu_indices(max_neigh, 1) 1310 angle_src = edge_idx[:, local_src].flatten() 1311 angle_dst = edge_idx[:, local_dst].flatten() 1312 1313 ### mask for valid angles 1314 mask1 = angle_src < nedge 1315 mask2 = angle_dst < nedge 1316 angle_mask = mask1 & mask2 1317 1318 ## filter angles to sparse representation 1319 (angle_src, angle_dst), _, nangles = mask_filter_1d( 1320 angle_mask, 1321 prev_nangles, 1322 (angle_src, nedge), 1323 (angle_dst, nedge), 1324 ) 1325 ## find central atom 1326 central_atom = edge_src[angle_src] 1327 1328 ## check for overflow 1329 angle_overflow = nangles > prev_nangles 1330 neigh_overflow = max_count > max_neigh 1331 overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow 1332 1333 ## update graph 1334 output = { 1335 **inputs, 1336 self.graph_key: { 1337 **graph, 1338 "angle_src": angle_src, 1339 "angle_dst": angle_dst, 1340 "central_atom": central_atom, 1341 "angle_overflow": overflow, 1342 # "max_neigh": max_neigh, 1343 "__max_neigh_array": max_neigh_arr, 1344 }, 1345 } 1346 1347 return output 1348 1349 @partial(jax.jit, static_argnums=(0,)) 1350 def update_skin(self, inputs): 1351 return self.process(None, inputs)
Add angles list to a graph
FPID: GRAPH_ANGULAR_EXTENSION
1254 def check_reallocate(self, state, inputs, parent_overflow=False): 1255 """check for overflow and reallocate nblist if necessary""" 1256 overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"] 1257 if not overflow: 1258 return state, {}, inputs, False 1259 1260 add_margin = inputs[self.graph_key]["angle_overflow"] 1261 state, inputs, state_up = self( 1262 state, inputs, return_state_update=True, add_margin=add_margin 1263 ) 1264 return state, state_up, inputs, True
check for overflow and reallocate nblist if necessary
1266 @partial(jax.jit, static_argnums=(0, 1)) 1267 def process(self, state, inputs): 1268 """build angle nblist on accelerator with jax and precomputed shapes""" 1269 graph = inputs[self.graph_key] 1270 edge_src = graph["edge_src"] 1271 1272 ### count number of neighbors 1273 nat = inputs["species"].shape[0] 1274 count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop") 1275 max_count = jnp.max(count) 1276 1277 ### get sizes 1278 if state is None: 1279 max_neigh_arr = graph["__max_neigh_array"] 1280 max_neigh = max_neigh_arr.shape[0] 1281 prev_nangles = graph["angle_src"].shape[0] 1282 else: 1283 max_neigh = state.get("max_neigh", self.add_neigh) 1284 max_neigh_arr = jnp.empty(max_neigh, dtype=bool) 1285 prev_nangles = state.get("nangles", 0) 1286 1287 nedge = edge_src.shape[0] 1288 1289 ### sort edge_src 1290 idx_sort = jnp.argsort(edge_src).astype(jnp.int32) 1291 edge_src_sorted = edge_src[idx_sort] 1292 1293 ### map sparse to dense nblist 1294 if max_neigh * nat < nedge: 1295 raise ValueError("Found max_neigh*nat < nedge. This should not happen.") 1296 offset = jnp.asarray( 1297 np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32 1298 ) 1299 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1300 indices = edge_src_sorted * max_neigh + offset 1301 edge_idx = ( 1302 jnp.full(nat * max_neigh, nedge, dtype=jnp.int32) 1303 .at[indices] 1304 .set(idx_sort, mode="drop") 1305 .reshape(nat, max_neigh) 1306 ) 1307 1308 ### find all triplet for each atom center 1309 local_src, local_dst = np.triu_indices(max_neigh, 1) 1310 angle_src = edge_idx[:, local_src].flatten() 1311 angle_dst = edge_idx[:, local_dst].flatten() 1312 1313 ### mask for valid angles 1314 mask1 = angle_src < nedge 1315 mask2 = angle_dst < nedge 1316 angle_mask = mask1 & mask2 1317 1318 ## filter angles to sparse representation 1319 (angle_src, angle_dst), _, nangles = mask_filter_1d( 1320 angle_mask, 1321 prev_nangles, 1322 (angle_src, nedge), 1323 (angle_dst, nedge), 1324 ) 1325 ## find central atom 1326 central_atom = edge_src[angle_src] 1327 1328 ## check for overflow 1329 angle_overflow = nangles > prev_nangles 1330 neigh_overflow = max_count > max_neigh 1331 overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow 1332 1333 ## update graph 1334 output = { 1335 **inputs, 1336 self.graph_key: { 1337 **graph, 1338 "angle_src": angle_src, 1339 "angle_dst": angle_dst, 1340 "central_atom": central_atom, 1341 "angle_overflow": overflow, 1342 # "max_neigh": max_neigh, 1343 "__max_neigh_array": max_neigh_arr, 1344 }, 1345 } 1346 1347 return output
build angle nblist on accelerator with jax and precomputed shapes
1354class GraphAngleProcessor(nn.Module): 1355 """Process a pre-generated graph to compute angles 1356 1357 This module is automatically added to a FENNIX model when a GraphAngularExtension is used. 1358 1359 """ 1360 1361 graph_key: str 1362 """Key of the graph in the inputs.""" 1363 1364 @nn.compact 1365 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 1366 graph = inputs[self.graph_key] 1367 distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"] 1368 vec = graph["vec"] 1369 angle_src = graph["angle_src"] 1370 angle_dst = graph["angle_dst"] 1371 1372 dir = vec / jnp.clip(distances[:, None], min=1.0e-5) 1373 cos_angles = ( 1374 dir.at[angle_src].get(mode="fill", fill_value=0.5) 1375 * dir.at[angle_dst].get(mode="fill", fill_value=0.5) 1376 ).sum(axis=-1) 1377 1378 angles = jnp.arccos(0.95 * cos_angles) 1379 1380 return { 1381 **inputs, 1382 self.graph_key: { 1383 **graph, 1384 # "cos_angles": cos_angles, 1385 "angles": angles, 1386 # "angle_mask": angle_mask, 1387 }, 1388 }
Process a pre-generated graph to compute angles
This module is automatically added to a FENNIX model when a GraphAngularExtension is used.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
1391@dataclasses.dataclass(frozen=True) 1392class SpeciesIndexer: 1393 """Build an index that splits atomic arrays by species. 1394 1395 FPID: SPECIES_INDEXER 1396 1397 If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. 1398 If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values. 1399 1400 """ 1401 1402 output_key: str = "species_index" 1403 """Key for the output dictionary.""" 1404 species_order: Optional[str] = None 1405 """Comma separated list of species in the order they should be indexed.""" 1406 add_atoms: int = 0 1407 """Additional atoms to add to the sizes.""" 1408 add_atoms_margin: int = 10 1409 """Additional atoms to add to the sizes when adding margin.""" 1410 1411 FPID: ClassVar[str] = "SPECIES_INDEXER" 1412 1413 def init(self): 1414 return FrozenDict( 1415 { 1416 "sizes": {}, 1417 } 1418 ) 1419 1420 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1421 species = np.array(inputs["species"], dtype=np.int32) 1422 nat = species.shape[0] 1423 set_species, counts = np.unique(species, return_counts=True) 1424 1425 new_state = {**state} 1426 state_up = {} 1427 1428 sizes = state.get("sizes", FrozenDict({})) 1429 new_sizes = {**sizes} 1430 up_sizes = False 1431 counts_dict = {} 1432 for s, c in zip(set_species, counts): 1433 if s <= 0: 1434 continue 1435 counts_dict[s] = c 1436 if c > sizes.get(s, 0): 1437 up_sizes = True 1438 add_atoms = state.get("add_atoms", self.add_atoms) 1439 if add_margin: 1440 add_atoms += state.get("add_atoms_margin", self.add_atoms_margin) 1441 new_sizes[s] = c + add_atoms 1442 1443 new_sizes = FrozenDict(new_sizes) 1444 if up_sizes: 1445 state_up["sizes"] = (new_sizes, sizes) 1446 new_state["sizes"] = new_sizes 1447 1448 if self.species_order is not None: 1449 species_order = [el.strip() for el in self.species_order.split(",")] 1450 max_size_prev = state.get("max_size", 0) 1451 max_size = max(new_sizes.values()) 1452 if max_size > max_size_prev: 1453 state_up["max_size"] = (max_size, max_size_prev) 1454 new_state["max_size"] = max_size 1455 max_size_prev = max_size 1456 1457 species_index = np.full((len(species_order), max_size), nat, dtype=np.int32) 1458 for i, el in enumerate(species_order): 1459 s = PERIODIC_TABLE_REV_IDX[el] 1460 if s in counts_dict.keys(): 1461 species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1462 else: 1463 species_index = { 1464 PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32) 1465 for s, c in new_sizes.items() 1466 } 1467 for s, c in zip(set_species, counts): 1468 if s <= 0: 1469 continue 1470 species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0] 1471 1472 output = { 1473 **inputs, 1474 self.output_key: species_index, 1475 self.output_key + "_overflow": False, 1476 } 1477 1478 if return_state_update: 1479 return FrozenDict(new_state), output, state_up 1480 return FrozenDict(new_state), output 1481 1482 def check_reallocate(self, state, inputs, parent_overflow=False): 1483 """check for overflow and reallocate nblist if necessary""" 1484 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1485 if not overflow: 1486 return state, {}, inputs, False 1487 1488 add_margin = inputs[self.output_key + "_overflow"] 1489 state, inputs, state_up = self( 1490 state, inputs, return_state_update=True, add_margin=add_margin 1491 ) 1492 return state, state_up, inputs, True 1493 # return state, {}, inputs, parent_overflow 1494 1495 @partial(jax.jit, static_argnums=(0, 1)) 1496 def process(self, state, inputs): 1497 # assert ( 1498 # self.output_key in inputs 1499 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1500 1501 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1502 if self.output_key in inputs and not recompute_species_index: 1503 return inputs 1504 1505 if state is None: 1506 raise ValueError("Species Indexer state must be provided on accelerator.") 1507 1508 species = inputs["species"] 1509 nat = species.shape[0] 1510 1511 sizes = state["sizes"] 1512 1513 if self.species_order is not None: 1514 species_order = [el.strip() for el in self.species_order.split(",")] 1515 max_size = state["max_size"] 1516 1517 species_index = jnp.full( 1518 (len(species_order), max_size), nat, dtype=jnp.int32 1519 ) 1520 for i, el in enumerate(species_order): 1521 s = PERIODIC_TABLE_REV_IDX[el] 1522 if s in sizes.keys(): 1523 c = sizes[s] 1524 species_index = species_index.at[i, :].set( 1525 jnp.nonzero(species == s, size=max_size, fill_value=nat)[0] 1526 ) 1527 # if s in counts_dict.keys(): 1528 # species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1529 else: 1530 # species_index = { 1531 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1532 # for s, c in sizes.items() 1533 # } 1534 species_index = {} 1535 overflow = False 1536 natcount = 0 1537 for s, c in sizes.items(): 1538 mask = species == s 1539 new_size = jnp.sum(mask) 1540 natcount = natcount + new_size 1541 overflow = overflow | (new_size > c) # check if sizes are correct 1542 species_index[PERIODIC_TABLE[s]] = jnp.nonzero( 1543 species == s, size=c, fill_value=nat 1544 )[0] 1545 1546 mask = species <= 0 1547 new_size = jnp.sum(mask) 1548 natcount = natcount + new_size 1549 overflow = overflow | ( 1550 natcount < species.shape[0] 1551 ) # check if any species missing 1552 1553 return { 1554 **inputs, 1555 self.output_key: species_index, 1556 self.output_key + "_overflow": overflow, 1557 } 1558 1559 @partial(jax.jit, static_argnums=(0,)) 1560 def update_skin(self, inputs): 1561 return self.process(None, inputs)
Build an index that splits atomic arrays by species.
FPID: SPECIES_INDEXER
If species_order is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
If species_order is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
Comma separated list of species in the order they should be indexed.
1482 def check_reallocate(self, state, inputs, parent_overflow=False): 1483 """check for overflow and reallocate nblist if necessary""" 1484 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1485 if not overflow: 1486 return state, {}, inputs, False 1487 1488 add_margin = inputs[self.output_key + "_overflow"] 1489 state, inputs, state_up = self( 1490 state, inputs, return_state_update=True, add_margin=add_margin 1491 ) 1492 return state, state_up, inputs, True 1493 # return state, {}, inputs, parent_overflow
check for overflow and reallocate nblist if necessary
1495 @partial(jax.jit, static_argnums=(0, 1)) 1496 def process(self, state, inputs): 1497 # assert ( 1498 # self.output_key in inputs 1499 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1500 1501 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1502 if self.output_key in inputs and not recompute_species_index: 1503 return inputs 1504 1505 if state is None: 1506 raise ValueError("Species Indexer state must be provided on accelerator.") 1507 1508 species = inputs["species"] 1509 nat = species.shape[0] 1510 1511 sizes = state["sizes"] 1512 1513 if self.species_order is not None: 1514 species_order = [el.strip() for el in self.species_order.split(",")] 1515 max_size = state["max_size"] 1516 1517 species_index = jnp.full( 1518 (len(species_order), max_size), nat, dtype=jnp.int32 1519 ) 1520 for i, el in enumerate(species_order): 1521 s = PERIODIC_TABLE_REV_IDX[el] 1522 if s in sizes.keys(): 1523 c = sizes[s] 1524 species_index = species_index.at[i, :].set( 1525 jnp.nonzero(species == s, size=max_size, fill_value=nat)[0] 1526 ) 1527 # if s in counts_dict.keys(): 1528 # species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1529 else: 1530 # species_index = { 1531 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1532 # for s, c in sizes.items() 1533 # } 1534 species_index = {} 1535 overflow = False 1536 natcount = 0 1537 for s, c in sizes.items(): 1538 mask = species == s 1539 new_size = jnp.sum(mask) 1540 natcount = natcount + new_size 1541 overflow = overflow | (new_size > c) # check if sizes are correct 1542 species_index[PERIODIC_TABLE[s]] = jnp.nonzero( 1543 species == s, size=c, fill_value=nat 1544 )[0] 1545 1546 mask = species <= 0 1547 new_size = jnp.sum(mask) 1548 natcount = natcount + new_size 1549 overflow = overflow | ( 1550 natcount < species.shape[0] 1551 ) # check if any species missing 1552 1553 return { 1554 **inputs, 1555 self.output_key: species_index, 1556 self.output_key + "_overflow": overflow, 1557 }
1563@dataclasses.dataclass(frozen=True) 1564class BlockIndexer: 1565 """Build an index that splits atomic arrays by chemical blocks. 1566 1567 FPID: BLOCK_INDEXER 1568 1569 If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. 1570 If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values. 1571 1572 """ 1573 1574 output_key: str = "block_index" 1575 """Key for the output dictionary.""" 1576 add_atoms: int = 0 1577 """Additional atoms to add to the sizes.""" 1578 add_atoms_margin: int = 10 1579 """Additional atoms to add to the sizes when adding margin.""" 1580 split_CNOPSSe: bool = False 1581 1582 FPID: ClassVar[str] = "BLOCK_INDEXER" 1583 1584 def init(self): 1585 return FrozenDict( 1586 { 1587 "sizes": {}, 1588 } 1589 ) 1590 1591 def build_chemical_blocks(self): 1592 _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy() 1593 if self.split_CNOPSSe: 1594 _CHEMICAL_BLOCKS_NAMES[1] = "C" 1595 _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"]) 1596 _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy() 1597 if self.split_CNOPSSe: 1598 _CHEMICAL_BLOCKS[6] = 1 1599 _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES) 1600 _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1 1601 _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2 1602 _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3 1603 _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4 1604 return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS 1605 1606 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1607 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1608 1609 species = np.array(inputs["species"], dtype=np.int32) 1610 blocks = _CHEMICAL_BLOCKS[species] 1611 nat = species.shape[0] 1612 set_blocks, counts = np.unique(blocks, return_counts=True) 1613 1614 new_state = {**state} 1615 state_up = {} 1616 1617 sizes = state.get("sizes", FrozenDict({})) 1618 new_sizes = {**sizes} 1619 up_sizes = False 1620 for s, c in zip(set_blocks, counts): 1621 if s < 0: 1622 continue 1623 key = (s, _CHEMICAL_BLOCKS_NAMES[s]) 1624 if c > sizes.get(key, 0): 1625 up_sizes = True 1626 add_atoms = state.get("add_atoms", self.add_atoms) 1627 if add_margin: 1628 add_atoms += state.get("add_atoms_margin", self.add_atoms_margin) 1629 new_sizes[key] = c + add_atoms 1630 1631 new_sizes = FrozenDict(new_sizes) 1632 if up_sizes: 1633 state_up["sizes"] = (new_sizes, sizes) 1634 new_state["sizes"] = new_sizes 1635 1636 block_index = {n:None for n in _CHEMICAL_BLOCKS_NAMES} 1637 for (_,n), c in new_sizes.items(): 1638 block_index[n] = np.full(c, nat, dtype=np.int32) 1639 # block_index = { 1640 # n: np.full(c, nat, dtype=np.int32) 1641 # for (_,n), c in new_sizes.items() 1642 # } 1643 for s, c in zip(set_blocks, counts): 1644 if s < 0: 1645 continue 1646 block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0] 1647 1648 output = { 1649 **inputs, 1650 self.output_key: block_index, 1651 self.output_key + "_overflow": False, 1652 } 1653 1654 if return_state_update: 1655 return FrozenDict(new_state), output, state_up 1656 return FrozenDict(new_state), output 1657 1658 def check_reallocate(self, state, inputs, parent_overflow=False): 1659 """check for overflow and reallocate nblist if necessary""" 1660 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1661 if not overflow: 1662 return state, {}, inputs, False 1663 1664 add_margin = inputs[self.output_key + "_overflow"] 1665 state, inputs, state_up = self( 1666 state, inputs, return_state_update=True, add_margin=add_margin 1667 ) 1668 return state, state_up, inputs, True 1669 # return state, {}, inputs, parent_overflow 1670 1671 @partial(jax.jit, static_argnums=(0, 1)) 1672 def process(self, state, inputs): 1673 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1674 # assert ( 1675 # self.output_key in inputs 1676 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1677 1678 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1679 if self.output_key in inputs and not recompute_species_index: 1680 return inputs 1681 1682 if state is None: 1683 raise ValueError("Block Indexer state must be provided on accelerator.") 1684 1685 species = inputs["species"] 1686 blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species] 1687 nat = species.shape[0] 1688 1689 sizes = state["sizes"] 1690 1691 # species_index = { 1692 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1693 # for s, c in sizes.items() 1694 # } 1695 block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES} 1696 overflow = False 1697 natcount = 0 1698 for (s,name), c in sizes.items(): 1699 mask = blocks == s 1700 new_size = jnp.sum(mask) 1701 natcount = natcount + new_size 1702 overflow = overflow | (new_size > c) # check if sizes are correct 1703 block_index[name] = jnp.nonzero( 1704 mask, size=c, fill_value=nat 1705 )[0] 1706 1707 mask = blocks < 0 1708 new_size = jnp.sum(mask) 1709 natcount = natcount + new_size 1710 overflow = overflow | ( 1711 natcount < species.shape[0] 1712 ) # check if any species missing 1713 1714 return { 1715 **inputs, 1716 self.output_key: block_index, 1717 self.output_key + "_overflow": overflow, 1718 } 1719 1720 @partial(jax.jit, static_argnums=(0,)) 1721 def update_skin(self, inputs): 1722 return self.process(None, inputs)
Build an index that splits atomic arrays by chemical blocks.
FPID: BLOCK_INDEXER
If species_order is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
If species_order is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
1591 def build_chemical_blocks(self): 1592 _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy() 1593 if self.split_CNOPSSe: 1594 _CHEMICAL_BLOCKS_NAMES[1] = "C" 1595 _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"]) 1596 _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy() 1597 if self.split_CNOPSSe: 1598 _CHEMICAL_BLOCKS[6] = 1 1599 _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES) 1600 _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1 1601 _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2 1602 _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3 1603 _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4 1604 return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS
1658 def check_reallocate(self, state, inputs, parent_overflow=False): 1659 """check for overflow and reallocate nblist if necessary""" 1660 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1661 if not overflow: 1662 return state, {}, inputs, False 1663 1664 add_margin = inputs[self.output_key + "_overflow"] 1665 state, inputs, state_up = self( 1666 state, inputs, return_state_update=True, add_margin=add_margin 1667 ) 1668 return state, state_up, inputs, True 1669 # return state, {}, inputs, parent_overflow
check for overflow and reallocate nblist if necessary
1671 @partial(jax.jit, static_argnums=(0, 1)) 1672 def process(self, state, inputs): 1673 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1674 # assert ( 1675 # self.output_key in inputs 1676 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1677 1678 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1679 if self.output_key in inputs and not recompute_species_index: 1680 return inputs 1681 1682 if state is None: 1683 raise ValueError("Block Indexer state must be provided on accelerator.") 1684 1685 species = inputs["species"] 1686 blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species] 1687 nat = species.shape[0] 1688 1689 sizes = state["sizes"] 1690 1691 # species_index = { 1692 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1693 # for s, c in sizes.items() 1694 # } 1695 block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES} 1696 overflow = False 1697 natcount = 0 1698 for (s,name), c in sizes.items(): 1699 mask = blocks == s 1700 new_size = jnp.sum(mask) 1701 natcount = natcount + new_size 1702 overflow = overflow | (new_size > c) # check if sizes are correct 1703 block_index[name] = jnp.nonzero( 1704 mask, size=c, fill_value=nat 1705 )[0] 1706 1707 mask = blocks < 0 1708 new_size = jnp.sum(mask) 1709 natcount = natcount + new_size 1710 overflow = overflow | ( 1711 natcount < species.shape[0] 1712 ) # check if any species missing 1713 1714 return { 1715 **inputs, 1716 self.output_key: block_index, 1717 self.output_key + "_overflow": overflow, 1718 }
1725@dataclasses.dataclass(frozen=True) 1726class AtomPadding: 1727 """Pad atomic arrays to a fixed size.""" 1728 1729 mult_size: float = 1.2 1730 """Multiplicative factor for resizing the atomic arrays.""" 1731 add_sys: int = 0 1732 1733 def init(self): 1734 return {"prev_nat": 0, "prev_nsys": 0} 1735 1736 def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]: 1737 species = inputs["species"] 1738 nat = species.shape[0] 1739 1740 prev_nat = state.get("prev_nat", 0) 1741 prev_nat_ = prev_nat 1742 if nat > prev_nat_: 1743 prev_nat_ = int(self.mult_size * nat) + 1 1744 1745 nsys = len(inputs["natoms"]) 1746 prev_nsys = state.get("prev_nsys", 0) 1747 prev_nsys_ = prev_nsys 1748 if nsys > prev_nsys_: 1749 prev_nsys_ = nsys + self.add_sys 1750 1751 add_atoms = prev_nat_ - nat 1752 add_sys = prev_nsys_ - nsys + 1 1753 output = {**inputs} 1754 if add_atoms > 0: 1755 for k, v in inputs.items(): 1756 if isinstance(v, np.ndarray) or isinstance(v, jax.Array): 1757 if v.shape[0] == nat: 1758 output[k] = np.append( 1759 v, 1760 np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype), 1761 axis=0, 1762 ) 1763 elif v.shape[0] == nsys: 1764 if k == "cells": 1765 output[k] = np.append( 1766 v, 1767 1000 1768 * np.eye(3, dtype=v.dtype)[None, :, :].repeat( 1769 add_sys, axis=0 1770 ), 1771 axis=0, 1772 ) 1773 else: 1774 output[k] = np.append( 1775 v, 1776 np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype), 1777 axis=0, 1778 ) 1779 output["natoms"] = np.append( 1780 inputs["natoms"], np.zeros(add_sys, dtype=np.int32) 1781 ) 1782 output["species"] = np.append( 1783 species, -1 * np.ones(add_atoms, dtype=species.dtype) 1784 ) 1785 output["batch_index"] = np.append( 1786 inputs["batch_index"], 1787 np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype), 1788 ) 1789 if "system_index" in inputs: 1790 output["system_index"] = np.append( 1791 inputs["system_index"], 1792 np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype), 1793 ) 1794 1795 output["true_atoms"] = output["species"] > 0 1796 output["true_sys"] = np.arange(len(output["natoms"])) < nsys 1797 1798 state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_} 1799 1800 return FrozenDict(state), output
Pad atomic arrays to a fixed size.
1803def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]: 1804 """Remove padding from atomic arrays.""" 1805 if "true_atoms" not in inputs: 1806 return inputs 1807 1808 species = np.asarray(inputs["species"]) 1809 true_atoms = np.asarray(inputs["true_atoms"]) 1810 true_sys = np.asarray(inputs["true_sys"]) 1811 natall = species.shape[0] 1812 nat = np.argmax(species <= 0) 1813 if nat == 0: 1814 return inputs 1815 1816 natoms = inputs["natoms"] 1817 nsysall = len(natoms) 1818 1819 output = {**inputs} 1820 for k, v in inputs.items(): 1821 if isinstance(v, jax.Array) or isinstance(v, np.ndarray): 1822 v = np.asarray(v) 1823 if v.ndim == 0: 1824 output[k] = v 1825 elif v.shape[0] == natall: 1826 output[k] = v[true_atoms] 1827 elif v.shape[0] == nsysall: 1828 output[k] = v[true_sys] 1829 del output["true_sys"] 1830 del output["true_atoms"] 1831 return output
Remove padding from atomic arrays.
1834def check_input(inputs): 1835 """Check the input dictionary for required keys and types.""" 1836 assert "species" in inputs, "species must be provided" 1837 assert "coordinates" in inputs, "coordinates must be provided" 1838 species = inputs["species"].astype(np.int32) 1839 ifake = np.argmax(species <= 0) 1840 if ifake > 0: 1841 assert np.all(species[:ifake] > 0), "species must be positive" 1842 nat = inputs["species"].shape[0] 1843 1844 natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32) 1845 batch_index = inputs.get( 1846 "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms) 1847 ).astype(np.int32) 1848 output = {**inputs, "natoms": natoms, "batch_index": batch_index} 1849 if "cells" in inputs: 1850 cells = inputs["cells"] 1851 if "reciprocal_cells" not in inputs: 1852 reciprocal_cells = np.linalg.inv(cells) 1853 else: 1854 reciprocal_cells = inputs["reciprocal_cells"] 1855 if cells.ndim == 2: 1856 cells = cells[None, :, :] 1857 if reciprocal_cells.ndim == 2: 1858 reciprocal_cells = reciprocal_cells[None, :, :] 1859 output["cells"] = cells 1860 output["reciprocal_cells"] = reciprocal_cells 1861 1862 return output
Check the input dictionary for required keys and types.
1865def convert_to_jax(data): 1866 """Convert a numpy arrays to jax arrays in a pytree.""" 1867 1868 def convert(x): 1869 if isinstance(x, np.ndarray): 1870 # if x.dtype == np.float64: 1871 # return jnp.asarray(x, dtype=jnp.float32) 1872 return jnp.asarray(x) 1873 return x 1874 1875 return jax.tree_util.tree_map(convert, data)
Convert a numpy arrays to jax arrays in a pytree.
1878class JaxConverter(nn.Module): 1879 """Convert numpy arrays to jax arrays in a pytree.""" 1880 1881 def __call__(self, data): 1882 return convert_to_jax(data)
Convert numpy arrays to jax arrays in a pytree.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
1884def convert_to_numpy(data): 1885 """Convert jax arrays to numpy arrays in a pytree.""" 1886 1887 def convert(x): 1888 if isinstance(x, jax.Array): 1889 return np.array(x) 1890 return x 1891 1892 return jax.tree_util.tree_map(convert, data)
Convert jax arrays to numpy arrays in a pytree.
1895@dataclasses.dataclass(frozen=True) 1896class PreprocessingChain: 1897 """Chain of preprocessing layers.""" 1898 1899 layers: Tuple[Callable[..., Dict[str, Any]]] 1900 """Preprocessing layers.""" 1901 use_atom_padding: bool = False 1902 """Add an AtomPadding layer at the beginning of the chain.""" 1903 atom_padder: AtomPadding = AtomPadding() 1904 """AtomPadding layer.""" 1905 1906 def __post_init__(self): 1907 if not isinstance(self.layers, Sequence): 1908 raise ValueError( 1909 f"'layers' must be a sequence, got '{type(self.layers).__name__}'." 1910 ) 1911 if not self.layers: 1912 raise ValueError(f"Error: no Preprocessing layers were provided.") 1913 1914 def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]: 1915 do_check_input = state.get("check_input", True) 1916 if do_check_input: 1917 inputs = check_input(inputs) 1918 new_state = {**state} 1919 if self.use_atom_padding: 1920 s, inputs = self.atom_padder(state["padder_state"], inputs) 1921 new_state["padder_state"] = s 1922 layer_state = state["layers_state"] 1923 new_layer_state = [] 1924 for i,layer in enumerate(self.layers): 1925 s, inputs = layer(layer_state[i], inputs, return_state_update=False) 1926 new_layer_state.append(s) 1927 new_state["layers_state"] = tuple(new_layer_state) 1928 return FrozenDict(new_state), convert_to_jax(inputs) 1929 1930 def check_reallocate(self, state, inputs): 1931 new_state = [] 1932 state_up = [] 1933 layer_state = state["layers_state"] 1934 parent_overflow = False 1935 for i,layer in enumerate(self.layers): 1936 s, s_up, inputs, parent_overflow = layer.check_reallocate( 1937 layer_state[i], inputs, parent_overflow 1938 ) 1939 new_state.append(s) 1940 state_up.append(s_up) 1941 1942 if not parent_overflow: 1943 return state, {}, inputs, False 1944 return ( 1945 FrozenDict({**state, "layers_state": tuple(new_state)}), 1946 state_up, 1947 inputs, 1948 True, 1949 ) 1950 1951 def atom_padding(self, state, inputs): 1952 if self.use_atom_padding: 1953 padder_state,inputs = self.atom_padder(state["padder_state"], inputs) 1954 return FrozenDict({**state,"padder_state": padder_state}), inputs 1955 return state, inputs 1956 1957 @partial(jax.jit, static_argnums=(0, 1)) 1958 def process(self, state, inputs): 1959 layer_state = state["layers_state"] 1960 for i,layer in enumerate(self.layers): 1961 inputs = layer.process(layer_state[i], inputs) 1962 return inputs 1963 1964 @partial(jax.jit, static_argnums=(0)) 1965 def update_skin(self, inputs): 1966 for layer in self.layers: 1967 inputs = layer.update_skin(inputs) 1968 return inputs 1969 1970 def init(self): 1971 state = {"check_input": True} 1972 if self.use_atom_padding: 1973 state["padder_state"] = self.atom_padder.init() 1974 layer_state = [] 1975 for layer in self.layers: 1976 layer_state.append(layer.init()) 1977 state["layers_state"] = tuple(layer_state) 1978 return FrozenDict(state) 1979 1980 def init_with_output(self, inputs): 1981 state = self.init() 1982 return self(state, inputs) 1983 1984 def get_processors(self): 1985 processors = [] 1986 for layer in self.layers: 1987 if hasattr(layer, "get_processor"): 1988 processors.append(layer.get_processor()) 1989 return processors 1990 1991 def get_graphs_properties(self): 1992 properties = {} 1993 for layer in self.layers: 1994 if hasattr(layer, "get_graph_properties"): 1995 properties = deep_update(properties, layer.get_graph_properties()) 1996 return properties
Chain of preprocessing layers.
1930 def check_reallocate(self, state, inputs): 1931 new_state = [] 1932 state_up = [] 1933 layer_state = state["layers_state"] 1934 parent_overflow = False 1935 for i,layer in enumerate(self.layers): 1936 s, s_up, inputs, parent_overflow = layer.check_reallocate( 1937 layer_state[i], inputs, parent_overflow 1938 ) 1939 new_state.append(s) 1940 state_up.append(s_up) 1941 1942 if not parent_overflow: 1943 return state, {}, inputs, False 1944 return ( 1945 FrozenDict({**state, "layers_state": tuple(new_state)}), 1946 state_up, 1947 inputs, 1948 True, 1949 )
1970 def init(self): 1971 state = {"check_input": True} 1972 if self.use_atom_padding: 1973 state["padder_state"] = self.atom_padder.init() 1974 layer_state = [] 1975 for layer in self.layers: 1976 layer_state.append(layer.init()) 1977 state["layers_state"] = tuple(layer_state) 1978 return FrozenDict(state)