fennol.models.embeddings.mace
import functools
import math
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional, Set
import dataclasses

from ..misc.encodings import RadialBasis
from ...utils.activations import activation_from_str


# e3nn_jax is an optional dependency: importing this module must not fail when it
# is absent.  The exception is stored and re-raised lazily the first time a MACE
# module is actually called.
try:
    import e3nn_jax as e3nn

    E3NN_AVAILABLE = True
    E3NN_EXCEPTION = None
    Irreps = e3nn.Irreps
    Irrep = e3nn.Irrep
except Exception as e:
    E3NN_AVAILABLE = False
    E3NN_EXCEPTION = e
    e3nn = None

    # Minimal stand-ins so the type annotations below stay importable
    # even without e3nn_jax installed.
    class Irreps(tuple):
        pass

    class Irrep(tuple):
        pass


class MACE(nn.Module):
    """MACE equivariant message passing neural network.

    adapted from MACE-jax github repo by M. Geiger and I. Batatia

    T. Plé reordered some operations and changed defaults to match the recent mace-torch version
    -> compatibility with pretrained torch models requires some work on the parameters:
        - normalization of activation functions in e3nn differ between jax and pytorch => need rescaling
        - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
        - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters

    References:
        - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436.
          https://doi.org/10.48550/arXiv.2206.07697
        - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
          https://doi.org/10.48550/arXiv.2205.06643

    """

    _graphs_properties: Dict
    output_irreps: Union[Irreps, str] = "1x0e"
    """The output irreps of the model."""
    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
    """The hidden irreps of the model."""
    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
    """The hidden irreps of the readout MLP."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph to use."""
    output_key: Optional[str] = None
    """The key of the embedding in the output dictionary."""
    avg_num_neighbors: float = 1.0
    """The expected average number of neighbors."""
    ninteractions: int = 2
    """The number of interaction layers."""
    num_features: Optional[int] = None
    """The number of features per node. Default: the gcd of the hidden_irreps multiplicities."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    lmax: int = 1
    """The maximum angular momentum to consider."""
    correlation: int = 3
    """The correlation order at each layer."""
    activation: str = "silu"
    """The activation function to use."""
    symmetric_tensor_product_basis: bool = False
    """Whether to use the symmetric tensor product basis."""
    # "o3_restricted" (spherical harmonics irreps), "o3_full" (all irreps up to
    # lmax), or an explicit irreps string.
    interaction_irreps: Union[Irreps, str] = "o3_restricted"
    # If False, the first layer applies its species-indexed linear after the
    # convolution instead of as a residual skip connection.
    skip_connection_first_layer: bool = True
    # Hidden layer sizes of the radial MLP producing the per-edge channel weights.
    radial_network_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    # If True, outputs are plain scalar jnp arrays instead of IrrepsArrays.
    scalar_output: bool = False
    zmax: int = 86
    """The maximum atomic number to consider."""
    # 0, 1 or 2: selects how messages are built from node features and
    # spherical harmonics (see the convolution section below).
    convolution_mode: int = 1
    # Optional key of a precomputed species encoding to use instead of the
    # learned per-species embedding table.
    species_encoding_key: Optional[str] = None

    FID: ClassVar[str] = "MACE"

    @nn.compact
    def __call__(self, inputs):
        # Re-raise the import error lazily if e3nn_jax is missing.
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        # NOTE(review): species are used directly as embedding indices, so
        # inputs["species"] is assumed to hold atomic numbers in
        # [0, zmax + 1] — confirm against the calling pipeline.
        species_indices = inputs["species"]
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        # Edge vectors wrapped as l=1 odd-parity irreps.
        vec = e3nn.IrrepsArray("1o", graph["vec"])
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        output_irreps = e3nn.Irreps(self.output_irreps)
        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)

        # extract or set num_features: factor the common multiplicity out of
        # hidden_irreps so irreps below carry per-feature multiplicities.
        if self.num_features is None:
            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
            hidden_irreps = e3nn.Irreps(
                [(mul // num_features, ir) for mul, ir in hidden_irreps]
            )
        else:
            num_features = self.num_features

        # get interaction irreps
        if self.interaction_irreps == "o3_restricted":
            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
        elif self.interaction_irreps == "o3_full":
            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
        else:
            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
        convol_irreps = num_features * interaction_irreps

        # NOTE: a commented-out species-reindexing scheme (mapping an explicit
        # species_order list through PERIODIC_TABLE_REV_IDX) was removed here;
        # atomic numbers are now used directly with a zmax-sized table.
        num_species = self.zmax + 2

        # species encoding: scalar (0e) part of the hidden representation.
        encoding_irreps: e3nn.Irreps = (
            (num_features * hidden_irreps).filter("0e").regroup()
        )
        if self.species_encoding_key is not None:
            # Project an externally provided encoding to the right dimension.
            species_encoding = nn.Dense(encoding_irreps.dim, use_bias=False)(
                inputs[self.species_encoding_key]
            )
        else:
            # Learned per-species table, standardized at initialization.
            species_encoding = self.param(
                "species_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (num_species, encoding_irreps.dim),
            )[species_indices]
        # convert to IrrepsArray
        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)

        # radial embedding, smoothly switched off at the cutoff.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_embedding = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": f"RadialBasis",
                }
            )(distances)
            * switch[:, None]
        )

        # spherical harmonics of the edge vectors (not needed for mode 2,
        # which fuses them into the tensor product).
        assert self.convolution_mode in [0, 1, 2], "convolution_mode must be 0, 1 or 2"
        if self.convolution_mode == 0:
            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
        elif self.convolution_mode == 1:
            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)

        outputs = []
        node_feats_all = []
        for layer in range(self.ninteractions):
            first = layer == 0
            last = layer == self.ninteractions - 1

            # The last layer only keeps the irreps needed by the output.
            layer_irreps = num_features * (
                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
            )

            # Linear skip connection (species-indexed weights).
            sc = None
            if not first or self.skip_connection_first_layer:
                sc = e3nn.flax.Linear(
                    layer_irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                    force_irreps_out=True,
                )(species_indices, node_feats)

            ################################################
            # Interaction block (Message passing convolution)
            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
                node_feats
            )

            messages = node_feats[edge_src]
            if self.convolution_mode == 0:
                # Full tensor product with Y_0..Y_lmax.
                messages = e3nn.tensor_product(
                    messages,
                    Yij,
                    filter_ir_out=convol_irreps,
                    regroup_output=True,
                )
            elif self.convolution_mode == 1:
                # Identity path (Y_0 is a constant) concatenated with the
                # tensor product against Y_1..Y_lmax.
                messages = e3nn.concatenate(
                    [
                        messages.filter(convol_irreps),
                        e3nn.tensor_product(
                            messages,
                            Yij,
                            filter_ir_out=convol_irreps,
                        ),
                    ]
                ).regroup()
            else:
                # Fused tensor product with spherical harmonics.
                messages = e3nn.tensor_product_with_spherical_harmonics(
                    messages, vec, self.lmax
                ).filter(convol_irreps).regroup()

            # Radial MLP producing one weight per message irrep channel.
            mix = e3nn.flax.MultiLayerPerceptron(
                [*self.radial_network_hidden, messages.irreps.num_irreps],
                act=activation_from_str(self.activation),
                output_activation=False,
                name=f"radial_network_{layer}",
                gradient_normalization="element",
            )(
                radial_embedding
            )

            # Weight messages, then scatter-sum them onto destination nodes.
            messages = messages * mix
            node_feats = (
                e3nn.IrrepsArray.zeros(
                    messages.irreps, node_feats.shape[:1], messages.dtype
                )
                .at[edge_dst]
                .add(messages)
            )

            # Mix channels and normalize by the expected neighbor count.
            node_feats = (
                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
                / self.avg_num_neighbors
            )

            # First-layer variant: species-indexed linear applied in place of
            # the residual skip connection.
            if first and not self.skip_connection_first_layer:
                node_feats = e3nn.flax.Linear(
                    node_feats.irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                )(species_indices, node_feats)

            ################################################
            # Equivariant product basis block

            # symmetric contractions up to the configured correlation order
            node_feats = SymmetricContraction(
                keep_irrep_out={ir for _, ir in layer_irreps},
                correlation=self.correlation,
                num_species=num_species,
                gradient_normalization="element",  # NOTE: This is to copy mace-torch
                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
            )(
                node_feats, species_indices
            )

            node_feats = e3nn.flax.Linear(
                layer_irreps, name=f"linear_contraction_{layer}"
            )(node_feats)

            if sc is not None:
                # add skip connection
                node_feats = node_feats + sc

            ################################################

            # Readout block
            if last:
                # Gated nonlinearity: one extra 0e gate scalar per non-scalar irrep.
                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
                layer_out = e3nn.flax.Linear(
                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
                    name=f"hidden_linear_readout_last",
                )(node_feats)
                layer_out = e3nn.gate(
                    layer_out,
                    even_act=activation_from_str(self.activation),
                    even_gate_act=None,
                )
                layer_out = e3nn.flax.Linear(
                    output_irreps, name=f"linear_readout_last"
                )(layer_out)
            else:
                # Intermediate layers read out with a plain linear.
                layer_out = e3nn.flax.Linear(
                    output_irreps,
                    name=f"linear_readout_{layer}",
                )(node_feats)

            if self.scalar_output:
                layer_out = layer_out.filter("0e").array

            outputs.append(layer_out)
            node_feats_all.append(node_feats.filter("0e").array)

        # Stack per-layer outputs along a new axis 1.
        if self.scalar_output:
            output = jnp.stack(outputs, axis=1)
        else:
            output = e3nn.stack(outputs, axis=1)

        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: output,
            output_key + "_node_feats": node_feats_all,
        }


class SymmetricContraction(nn.Module):
    """Symmetric contraction of node features up to a given correlation order.

    Implements the equivariant product basis of MACE: for each kept output
    irrep, contracts `correlation` copies of the input features against
    precomputed (symmetric) tensor-product basis matrices U, with
    species-indexed weights — equivalent to evaluating the Horner-style
    polynomial ((w_n x + w_{n-1}) x + ...) x per feature channel.
    """

    # Maximum correlation (polynomial) order of the contraction.
    correlation: int
    # Output irreps to keep (set of Irrep, or an irreps string with mul 1).
    keep_irrep_out: Set[Irrep]
    # Number of rows of the species-indexed weight tables.
    num_species: int
    # "element" (0.0) or "path" (1.0), or a float in between; None falls back
    # to the e3nn global config.
    gradient_normalization: Union[str, float]
    # Use e3nn's reduced *symmetric* tensor product basis instead of the full one.
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        # Re-raise the import error lazily if e3nn_jax is missing.
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        if self.gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        else:
            gradient_normalization = self.gradient_normalization
        if isinstance(gradient_normalization, str):
            # Map the named conventions to their numeric exponents.
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            assert all(mul == 1 for mul, _ in keep_irrep_out)

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        # Move the multiplicity into a leading axis: [..., num_features, irreps.dim]
        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        # ws[order-1][ii]: weights for the ii-th output irrep at that order,
        # shape [num_species, mul, num_features].
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # has been numerically checked to be correct.

            # TODO(mario) implement norm_p
            Us.append(U)

            wsorder = []
            for (mul, ir_out), u in zip(U.irreps, U.list):
                u = u.astype(input.array.dtype)
                # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                # stddev and post-scale split the mul**-0.5 factor between
                # initialization and forward pass per gradient_normalization.
                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # - This operation is parallel on the feature dimension (but each feature has its own parameters)
            # This operation is an efficient implementation of
            # vmap(lambda w, x: FunctionalLinear(irreps_out)(w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])))(w, x)
            # up to x power self.correlation
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            for order in range(self.correlation, 0, -1):  # correlation, ..., 1

                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for ii, ((mul, ir_out), u) in enumerate(zip(U.irreps, U.list)):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    w = ws[order - 1][ii][index]
                    if ir_out not in out:
                        # Highest order seen for this irrep: contract u, w and
                        # one power of x in a single einsum (special case).
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]
                    else:
                        # Lower orders: accumulate u·w; the power of x is
                        # applied once for all irreps below.
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #          out

            # out[irrep_out] : [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        # Fold the feature axis back into the irreps multiplicities.
        return fn_mapped(input, index).axis_to_mul()


# NOTE: a commented-out legacy implementation of SymmetricContraction (using a
# one-hot species encoding and per-irrep einsum loops) previously followed here;
# it was superseded by the vmap-based implementation above.
class MACE(nn.Module):
    """MACE equivariant message passing neural network.

    (duplicate extracted copy of the MACE class, reformatted in place)

    adapted from MACE-jax github repo by M. Geiger and I. Batatia

    T. Plé reordered some operations and changed defaults to match the recent mace-torch version
    -> compatibility with pretrained torch models requires some work on the parameters:
        - normalization of activation functions in e3nn differ between jax and pytorch => need rescaling
        - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
        - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters

    References:
        - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436.
          https://doi.org/10.48550/arXiv.2206.07697
        - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
          https://doi.org/10.48550/arXiv.2205.06643

    """

    _graphs_properties: Dict
    output_irreps: Union[Irreps, str] = "1x0e"
    """The output irreps of the model."""
    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
    """The hidden irreps of the model."""
    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
    """The hidden irreps of the readout MLP."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph to use."""
    output_key: Optional[str] = None
    """The key of the embedding in the output dictionary."""
    avg_num_neighbors: float = 1.0
    """The expected average number of neighbors."""
    ninteractions: int = 2
    """The number of interaction layers."""
    num_features: Optional[int] = None
    """The number of features per node. Default: the gcd of the hidden_irreps multiplicities."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    lmax: int = 1
    """The maximum angular momentum to consider."""
    correlation: int = 3
    """The correlation order at each layer."""
    activation: str = "silu"
    """The activation function to use."""
    symmetric_tensor_product_basis: bool = False
    """Whether to use the symmetric tensor product basis."""
    # "o3_restricted", "o3_full", or an explicit irreps string.
    interaction_irreps: Union[Irreps, str] = "o3_restricted"
    skip_connection_first_layer: bool = True
    radial_network_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    scalar_output: bool = False
    zmax: int = 86
    """The maximum atomic number to consider."""
    convolution_mode: int = 1
    species_encoding_key: Optional[str] = None

    FID: ClassVar[str] = "MACE"

    @nn.compact
    def __call__(self, inputs):
        # Lazily re-raise the import error if e3nn_jax is missing.
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        species_indices = inputs["species"]
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        # Edge vectors as l=1 odd-parity irreps.
        vec = e3nn.IrrepsArray("1o", graph["vec"])
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        output_irreps = e3nn.Irreps(self.output_irreps)
        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)

        # extract or set num_features
        if self.num_features is None:
            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
            hidden_irreps = e3nn.Irreps(
                [(mul // num_features, ir) for mul, ir in hidden_irreps]
            )
        else:
            num_features = self.num_features

        # get interaction irreps
        if self.interaction_irreps == "o3_restricted":
            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
        elif self.interaction_irreps == "o3_full":
            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
        else:
            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
        convol_irreps = num_features * interaction_irreps

        # Atomic numbers are used directly as species indices.
        num_species = self.zmax + 2

        # species encoding (scalar part of the hidden representation)
        encoding_irreps: e3nn.Irreps = (
            (num_features * hidden_irreps).filter("0e").regroup()
        )
        if self.species_encoding_key is not None:
            species_encoding = nn.Dense(encoding_irreps.dim, use_bias=False)(
                inputs[self.species_encoding_key]
            )
        else:
            species_encoding = self.param(
                "species_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (num_species, encoding_irreps.dim),
            )[species_indices]
        # convert to IrrepsArray
        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)

        # radial embedding, switched off smoothly at the cutoff
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_embedding = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": f"RadialBasis",
                }
            )(distances)
            * switch[:, None]
        )

        # spherical harmonics
        assert self.convolution_mode in [0, 1, 2], "convolution_mode must be 0, 1 or 2"
        if self.convolution_mode == 0:
            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
        elif self.convolution_mode == 1:
            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)

        outputs = []
        node_feats_all = []
        for layer in range(self.ninteractions):
            first = layer == 0
            last = layer == self.ninteractions - 1

            layer_irreps = num_features * (
                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
            )

            # Linear skip connection (species-indexed)
            sc = None
            if not first or self.skip_connection_first_layer:
                sc = e3nn.flax.Linear(
                    layer_irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                    force_irreps_out=True,
                )(species_indices, node_feats)

            ################################################
            # Interaction block (Message passing convolution)
            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
                node_feats
            )

            messages = node_feats[edge_src]
            if self.convolution_mode == 0:
                messages = e3nn.tensor_product(
                    messages,
                    Yij,
                    filter_ir_out=convol_irreps,
                    regroup_output=True,
                )
            elif self.convolution_mode == 1:
                messages = e3nn.concatenate(
                    [
                        messages.filter(convol_irreps),
                        e3nn.tensor_product(
                            messages,
                            Yij,
                            filter_ir_out=convol_irreps,
                        ),
                    ]
                ).regroup()
            else:
                messages = e3nn.tensor_product_with_spherical_harmonics(
                    messages, vec, self.lmax
                ).filter(convol_irreps).regroup()

            # Radial MLP: one weight per message irrep channel.
            mix = e3nn.flax.MultiLayerPerceptron(
                [*self.radial_network_hidden, messages.irreps.num_irreps],
                act=activation_from_str(self.activation),
                output_activation=False,
                name=f"radial_network_{layer}",
                gradient_normalization="element",
            )(
                radial_embedding
            )

            # Weight and scatter-sum messages onto destination nodes.
            messages = messages * mix
            node_feats = (
                e3nn.IrrepsArray.zeros(
                    messages.irreps, node_feats.shape[:1], messages.dtype
                )
                .at[edge_dst]
                .add(messages)
            )

            node_feats = (
                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
                / self.avg_num_neighbors
            )

            if first and not self.skip_connection_first_layer:
                node_feats = e3nn.flax.Linear(
                    node_feats.irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                )(species_indices, node_feats)

            ################################################
            # Equivariant product basis block

            # symmetric contractions
            node_feats = SymmetricContraction(
                keep_irrep_out={ir for _, ir in layer_irreps},
                correlation=self.correlation,
                num_species=num_species,
                gradient_normalization="element",  # NOTE: This is to copy mace-torch
                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
            )(
                node_feats, species_indices
            )

            node_feats = e3nn.flax.Linear(
                layer_irreps, name=f"linear_contraction_{layer}"
            )(node_feats)

            if sc is not None:
                # add skip connection
                node_feats = node_feats + sc

            ################################################

            # Readout block
            if last:
                # Gated nonlinearity: extra 0e gate scalars for vectors.
                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
                layer_out = e3nn.flax.Linear(
                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
                    name=f"hidden_linear_readout_last",
                )(node_feats)
                layer_out = e3nn.gate(
                    layer_out,
                    even_act=activation_from_str(self.activation),
                    even_gate_act=None,
                )
                layer_out = e3nn.flax.Linear(
                    output_irreps, name=f"linear_readout_last"
                )(layer_out)
            else:
                layer_out = e3nn.flax.Linear(
                    output_irreps,
                    name=f"linear_readout_{layer}",
                )(node_feats)

            if self.scalar_output:
                layer_out = layer_out.filter("0e").array

            outputs.append(layer_out)
            node_feats_all.append(node_feats.filter("0e").array)

        if self.scalar_output:
            output = jnp.stack(outputs, axis=1)
        else:
            output = e3nn.stack(outputs, axis=1)

        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: output,
            output_key + "_node_feats": node_feats_all,
        }
MACE equivariant message passing neural network.
adapted from MACE-jax github repo by M. Geiger and I. Batatia
T. Plé reordered some operations and changed defaults to match the recent mace-torch version -> compatibility with pretrained torch models requires some work on the parameters: - normalization of activation functions in e3nn differ between jax and pytorch => need rescaling - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters
References: - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436. https://doi.org/10.48550/arXiv.2206.07697 - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022). https://doi.org/10.48550/arXiv.2205.06643
The hidden irreps of the readout MLP.
The key in the input dictionary that corresponds to the molecular graph to use.
The number of features per node. default gcd of hidden_irreps multiplicities
The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class SymmetricContraction(nn.Module):
    """Symmetric contraction of node features up to a given correlation order.

    (duplicate extracted copy, reformatted in place)

    Contracts `correlation` powers of the input features against precomputed
    (symmetric) tensor-product basis matrices U with species-indexed weights,
    evaluated Horner-style: ((w_n x + w_{n-1}) x + ...) x.
    """

    # Maximum correlation (polynomial) order of the contraction.
    correlation: int
    # Output irreps to keep.
    keep_irrep_out: Set[Irrep]
    # Number of rows of the species-indexed weight tables.
    num_species: int
    # "element" (0.0) / "path" (1.0) or a float; None -> e3nn global config.
    gradient_normalization: Union[str, float]
    # Use the reduced *symmetric* tensor product basis.
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        if self.gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        else:
            gradient_normalization = self.gradient_normalization
        if isinstance(gradient_normalization, str):
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            assert all(mul == 1 for mul, _ in keep_irrep_out)

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        # [..., num_features, irreps.dim]
        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # has been numerically checked to be correct.

            # TODO(mario) implement norm_p
            Us.append(U)

            wsorder = []
            for (mul, ir_out), u in zip(U.irreps, U.list):
                u = u.astype(input.array.dtype)
                # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # - This operation is parallel on the feature dimension (but each feature has its own parameters)
            # This operation is an efficient implementation of
            # vmap(lambda w, x: FunctionalLinear(irreps_out)(w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])))(w, x)
            # up to x power self.correlation
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            for order in range(self.correlation, 0, -1):  # correlation, ..., 1

                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for ii, ((mul, ir_out), u) in enumerate(zip(U.irreps, U.list)):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    w = ws[order - 1][ii][index]
                    if ir_out not in out:
                        # Highest order for this irrep: fuse u, w and one power
                        # of x in a single einsum.
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]
                    else:
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #          out

            # out[irrep_out] : [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        # Fold the feature axis back into irreps multiplicities.
        return fn_mapped(input, index).axis_to_mul()
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.