fennol.models.embeddings.mace

  1import functools
  2import math
  3import jax
  4import jax.numpy as jnp
  5import flax.linen as nn
  6from typing import Sequence, Dict, Union, ClassVar, Optional, Set
  7import dataclasses
  8
  9from ..misc.encodings import RadialBasis
 10from ...utils.activations import activation_from_str
 11
 12
 13try:
 14    import e3nn_jax as e3nn
 15
 16    E3NN_AVAILABLE = True
 17    E3NN_EXCEPTION = None
 18    Irreps = e3nn.Irreps
 19    Irrep = e3nn.Irrep
 20except Exception as e:
 21    E3NN_AVAILABLE = False
 22    E3NN_EXCEPTION = e
 23    e3nn = None
 24
 25    class Irreps(tuple):
 26        pass
 27
 28    class Irrep(tuple):
 29        pass
 30
 31
class MACE(nn.Module):
    """MACE equivariant message passing neural network.

    adapted from MACE-jax github repo by M. Geiger and I. Batatia
    
    T. Plé reordered some operations and changed defaults to match the recent mace-torch version 
    -> compatibility with pretrained torch models requires some work on the parameters:
        - normalization of activation functions in e3nn differ between jax and pytorch => need rescaling
        - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
        - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters
    
    References:
        - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436.
        https://doi.org/10.48550/arXiv.2206.07697
        - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
        https://doi.org/10.48550/arXiv.2205.06643

    """
    _graphs_properties: Dict
    output_irreps: Union[Irreps, str] = "1x0e"
    """The output irreps of the model."""
    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
    """The hidden irreps of the model."""
    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
    """The hidden irreps of the readout MLP."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph to use."""
    output_key: Optional[str] = None
    """The key of the embedding in the output dictionary."""
    avg_num_neighbors: float = 1.0
    """The expected average number of neighbors."""
    ninteractions: int = 2
    """The number of interaction layers."""
    num_features: Optional[int] = None
    """The number of features per node. default gcd of hidden_irreps multiplicities"""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    lmax: int = 1
    """The maximum angular momentum to consider."""
    correlation: int = 3
    """The correlation order at each layer."""
    activation: str = "silu"
    """The activation function to use."""
    symmetric_tensor_product_basis: bool = False
    """Whether to use the symmetric tensor product basis."""
    # Irreps kept after the message convolution: "o3_restricted" (spherical
    # harmonics up to lmax), "o3_full" (all irreps up to lmax), or an explicit
    # Irreps string.
    interaction_irreps: Union[Irreps, str] = "o3_restricted"
    # If True, the first interaction layer also gets the species-indexed
    # linear skip connection ("skip_tp") before the convolution.
    skip_connection_first_layer: bool = True
    # Hidden layer sizes of the radial MLP that mixes the radial basis into
    # the per-edge messages.
    radial_network_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    # If True, only the scalar ("0e") part of each layer readout is kept and
    # the output is a plain jnp array instead of an IrrepsArray.
    scalar_output: bool = False
    zmax: int = 86
    """The maximum atomic number to consider."""
    # Message convolution variant:
    #   0 -> full tensor product of messages with Y_lm for l = 0..lmax
    #   1 -> concatenate filtered messages with tensor product against Y_lm for l = 1..lmax
    #   2 -> e3nn.tensor_product_with_spherical_harmonics (no precomputed Y)
    convolution_mode: int = 1
    # Optional input key holding a precomputed species encoding; when None a
    # learned per-species embedding table is used instead.
    species_encoding_key: Optional[str] = None

    FID: ClassVar[str] = "MACE"

    @nn.compact
    def __call__(self, inputs):
        """Compute MACE node embeddings for the molecular graph in `inputs`.

        Returns the input dictionary augmented with:
          - `output_key`: per-layer readouts stacked along axis 1 (jnp array if
            `scalar_output`, otherwise an e3nn.IrrepsArray),
          - `output_key + "_node_feats"`: concatenated scalar ("0e") node
            features from every interaction layer.
        """
        # The equivariant machinery is unavailable: re-raise the import error
        # captured at module load time.
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        # Atomic numbers are used directly as indices into species-dependent
        # parameter tables of size num_species = zmax + 2.
        species_indices = inputs["species"]
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        # Edge vectors wrapped as l=1 odd (vector) irreps.
        vec = e3nn.IrrepsArray("1o", graph["vec"])
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        output_irreps = e3nn.Irreps(self.output_irreps)
        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)

        # extract or set num_features
        if self.num_features is None:
            # Default: factor out the gcd of the hidden multiplicities as the
            # feature dimension, leaving reduced multiplicities in hidden_irreps.
            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
            hidden_irreps = e3nn.Irreps(
                [(mul // num_features, ir) for mul, ir in hidden_irreps]
            )
        else:
            num_features = self.num_features

        # get interaction irreps
        if self.interaction_irreps == "o3_restricted":
            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
        elif self.interaction_irreps == "o3_full":
            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
        else:
            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
        convol_irreps = num_features * interaction_irreps

        # convert species to internal indices
        # maxidx = max(PERIODIC_TABLE_REV_IDX.values())
        # conv_tensor = [0] * (maxidx + 2)
        # if isinstance(self.species_order, str):
        #     species_order = [el.strip() for el in self.species_order.split(",")]
        # else:
        #     species_order = [el for el in self.species_order]
        # for i, s in enumerate(species_order):
        #     conv_tensor[PERIODIC_TABLE_REV_IDX[s]] = i
        # species_indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
        # NOTE(review): zmax + 2 presumably leaves room for padding/dummy
        # species entries — TODO confirm against the rest of the package.
        num_species = self.zmax + 2

        # species encoding
        # Only the scalar even ("0e") part of the hidden irreps is used for
        # the initial species embedding.
        encoding_irreps: e3nn.Irreps = (
            (num_features * hidden_irreps).filter("0e").regroup()
        )
        if self.species_encoding_key is not None:
            species_encoding = nn.Dense(encoding_irreps.dim,use_bias=False)(inputs[self.species_encoding_key])
        else:
            # Learned embedding table, initialized as standardized Gaussian rows.
            species_encoding = self.param(
                "species_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (num_species, encoding_irreps.dim),
            )[species_indices]
        # convert to IrrepsArray
        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)

        # radial embedding
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # Radial basis evaluated on edge distances, smoothly switched off at
        # the cutoff via the graph's switch function.
        radial_embedding = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": f"RadialBasis",
                }
            )(distances)
            * switch[:, None]
        )

        # spherical harmonics
        assert self.convolution_mode in [0,1,2], "convolution_mode must be 0, 1 or 2"
        if self.convolution_mode == 0:
            # Include the constant l=0 channel in the harmonics.
            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
        elif self.convolution_mode == 1:
            # Start at l=1; the l=0 contribution is handled by concatenating
            # the filtered messages themselves (see below).
            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)

        outputs = []
        node_feats_all = []
        for layer in range(self.ninteractions):
            first = layer == 0
            last = layer == self.ninteractions - 1

            # Last layer keeps only the irreps compatible with the output.
            layer_irreps = num_features * (
                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
            )

            # Linear skip connection
            # Species-indexed linear ("skip_tp"): each species has its own
            # weights, selected by species_indices.
            sc = None
            if not first or self.skip_connection_first_layer:
                sc = e3nn.flax.Linear(
                    layer_irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                    force_irreps_out=True,
                )(species_indices, node_feats)

            ################################################
            # Interaction block (Message passing convolution)
            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
                node_feats
            )


            # Gather sender features onto edges.
            messages = node_feats[edge_src]
            if self.convolution_mode == 0:
                messages = e3nn.tensor_product(
                    messages,
                    Yij,
                    filter_ir_out=convol_irreps,
                    regroup_output=True,
                )
            elif self.convolution_mode == 1:
                messages = e3nn.concatenate(
                    [
                        messages.filter(convol_irreps),
                        e3nn.tensor_product(
                            messages,
                            Yij,
                            filter_ir_out=convol_irreps,
                        ),
                        # e3nn.tensor_product_with_spherical_harmonics(
                        #     messages, vectors, self.max_ell
                        # ).filter(convol_irreps),
                    ]
                ).regroup()
            else:
                messages = e3nn.tensor_product_with_spherical_harmonics(
                    messages, vec, self.lmax
                ).filter(convol_irreps).regroup()

            # mix = FullyConnectedNet(
            #     [*self.radial_network_hidden, messages.irreps.num_irreps],
            #     activation=activation_from_str(self.activation),
            #     name=f"radial_network_{layer}",
            #     use_bias=False,
            # )(radial_embedding)
            # Radial MLP producing one mixing weight per message irrep.
            mix = e3nn.flax.MultiLayerPerceptron(
                [*self.radial_network_hidden, messages.irreps.num_irreps],
                act=activation_from_str(self.activation),
                output_activation=False,
                name=f"radial_network_{layer}",
                gradient_normalization="element",
            )(
                radial_embedding
            )

            # Weight messages radially, then scatter-add onto destination nodes.
            messages = messages * mix
            node_feats = (
                e3nn.IrrepsArray.zeros(
                    messages.irreps, node_feats.shape[:1], messages.dtype
                )
                .at[edge_dst]
                .add(messages)
            )
            # print("irreps_mid jax",node_feats.irreps)
            # jax.debug.print("node_feats={n}", n=jnp.sum(node_feats.array,axis=0)[550:570])

            # Project down to the convolution irreps and normalize by the
            # expected number of neighbors.
            node_feats = (
                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
                / self.avg_num_neighbors
            )

            # When the first layer has no skip connection, the species-indexed
            # linear is applied in-line here instead.
            if first and not self.skip_connection_first_layer:
                node_feats = e3nn.flax.Linear(
                    node_feats.irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                )(species_indices, node_feats)

            ################################################
            # Equivariant product basis block

            # symmetric contractions
            node_feats = SymmetricContraction(
                keep_irrep_out={ir for _, ir in layer_irreps},
                correlation=self.correlation,
                num_species=num_species,
                gradient_normalization="element",  # NOTE: This is to copy mace-torch
                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
            )(
                node_feats, species_indices
            )


            node_feats = e3nn.flax.Linear(
                layer_irreps, name=f"linear_contraction_{layer}"
            )(node_feats)


            if sc is not None:
                # add skip connection
                node_feats = node_feats + sc


            ################################################
            
            # Readout block
            if last:
                # Gated nonlinear readout: one extra "0e" gate scalar is
                # produced for each non-scalar irrep of the readout MLP.
                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
                layer_out = e3nn.flax.Linear(
                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
                    name=f"hidden_linear_readout_last",
                )(node_feats)
                layer_out = e3nn.gate(
                    layer_out,
                    even_act=activation_from_str(self.activation),
                    even_gate_act=None,
                )
                layer_out = e3nn.flax.Linear(
                    output_irreps, name=f"linear_readout_last"
                )(layer_out)
            else:
                # Intermediate layers use a plain linear readout.
                layer_out = e3nn.flax.Linear(
                    output_irreps,
                    name=f"linear_readout_{layer}",
                )(node_feats)

            if self.scalar_output:
                layer_out = layer_out.filter("0e").array

            outputs.append(layer_out)
            node_feats_all.append(node_feats.filter("0e").array)

        # Stack per-layer readouts along a new layer axis.
        if self.scalar_output:
            output = jnp.stack(outputs, axis=1)
        else:
            output = e3nn.stack(outputs, axis=1)

        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: output,
            output_key + "_node_feats": node_feats_all,
        }
336
337
class SymmetricContraction(nn.Module):
    """Species-dependent symmetric tensor contraction (MACE product basis).

    Contracts a node feature tensor with itself up to `correlation` times,
    using precomputed reduced tensor-product bases (U matrices) and
    species-indexed weights, in a Horner-like scheme
    ``((w3 x + w2) x + w1) x`` (see comments in the inner function).
    """

    # Maximum body order of the contraction (number of times `input` is
    # contracted with itself).
    correlation: int
    # Output irreps to keep (set of Irrep, or an Irreps string with unit
    # multiplicities).
    keep_irrep_out: Set[Irrep]
    # Number of rows of the species-indexed weight tables.
    num_species: int
    # "element" (0.0) or "path" (1.0), or a float interpolating between the
    # two weight-initialization conventions.
    gradient_normalization: Union[str, float]
    # If True use e3nn's symmetric reduced basis instead of the full one.
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        # Fail loudly if e3nn-jax could not be imported at module load time.
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        if self.gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        else:
            gradient_normalization = self.gradient_normalization
        if isinstance(gradient_normalization, str):
            # Map the named conventions onto the interpolation exponent.
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            assert all(mul == 1 for mul, _ in keep_irrep_out)

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        # Move the multiplicity into a leading feature axis so each feature
        # channel is contracted independently.
        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # has been numerically checked to be correct.

            # TODO(mario) implement norm_p
            Us.append(U)

            wsorder = []
            for (mul, ir_out), u in zip(U.irreps, U.list):
                u = u.astype(input.array.dtype)
                # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                # One weight tensor per (order, output irrep), indexed by
                # species: [num_species, multiplicity, num_features].
                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                # Split the mul**-0.5 normalization between initialization and
                # forward pass according to gradient_normalization.
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # - This operation is parallel on the feature dimension (but each feature has its own parameters)
            # This operation is an efficient implementation of
            # vmap(lambda w, x: FunctionalLinear(irreps_out)(w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])))(w, x)
            # up to x power self.correlation
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            for order in range(self.correlation, 0, -1):  # correlation, ..., 1

                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for ii, ((mul, ir_out), u) in enumerate(zip(U.irreps, U.list)):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    # Select this node's species-specific weights.
                    w = ws[order - 1][ii][index]
                    if ir_out not in out:
                        # First (highest-order) term: contract one power of x
                        # immediately; the "special" tag skips the generic
                        # x-contraction below for this entry.
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(oder-1), ir_out.dim]
                    else:
                        # Lower-order term added before the next x-contraction.
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(oder-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #           out

            # out[irrep_out] : [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        # Fold the feature axis back into the irreps multiplicities.
        return fn_mapped(input, index).axis_to_mul()
473
474
475# class SymmetricContraction(nn.Module):
476
477#     correlation: int
478#     keep_irrep_out: Set[Irrep]
479#     num_species: int
480#     gradient_normalization: Union[str, float]
481#     symmetric_tensor_product_basis: bool
482
483#     @nn.compact
484#     def __call__(self, input: IrrepsArray, index: jnp.ndarray):
485#         if not E3NN_AVAILABLE:
486#             raise E3NN_EXCEPTION
487
488#         if self.gradient_normalization is None:
489#             gradient_normalization = e3nn.config("gradient_normalization")
490#         else:
491#             gradient_normalization = self.gradient_normalization
492#         if isinstance(gradient_normalization, str):
493#             gradient_normalization = {"element": 0.0, "path": 1.0}[
494#                 gradient_normalization
495#             ]
496
497#         keep_irrep_out = self.keep_irrep_out
498#         if isinstance(keep_irrep_out, str):
499#             keep_irrep_out = e3nn.Irreps(keep_irrep_out)
500#             assert all(mul == 1 for mul, _ in keep_irrep_out)
501
502#         keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}
503
504#         onehot = jnp.eye(self.num_species)[index]
505
506#         ### PREPARE WEIGHTS
507#         ws = []
508#         us = []
509#         for ir_out in keep_irrep_out:
510#             usorder = []
511#             wsorder = []
512#             for order in range(1, self.correlation + 1):  # correlation, ..., 1
513#                 if self.symmetric_tensor_product_basis:
514#                     U = e3nn.reduced_symmetric_tensor_product_basis(
515#                         input.irreps, order, keep_ir=[ir_out]
516#                     )
517#                 else:
518#                     U = e3nn.reduced_tensor_product_basis(
519#                         [input.irreps] * order, keep_ir=[ir_out]
520#                     )
521#                 u = jnp.moveaxis(U.list[0].astype(input.array.dtype), -1, 0)
522#                 usorder.append(u)
523
524#                 mul, _ = U.irreps[0]
525#                 w = self.param(
526#                     f"w{order}_{ir_out}",
527#                     nn.initializers.normal(
528#                         stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
529#                     ),
530#                     (self.num_species, mul, input.shape[-2]),
531#                 )
532#                 w = w * (mul**-0.5) ** gradient_normalization
533#                 wsorder.append(w)
534#             ws.append(wsorder)
535#             us.append(usorder)
536
537#         x = input.array
538
539#         outs = []
540#         for i, ir in enumerate(keep_irrep_out):
541#             w = ws[i][-1]  # [index]
542#             u = us[i][-1]
543#             out = jnp.einsum("...jk,ekc,bcj,be->bc...", u, w, x, onehot)
544
545#             for order in range(self.correlation - 1, 0, -1):
546#                 w = ws[i][order - 1]  # [index]
547#                 u = us[i][order - 1]
548
549#                 c_tensor = jnp.einsum("...k,ekc,be->bc...", u, w, onehot) + out
550#                 out = jnp.einsum("bc...j,bcj->bc...", c_tensor, x)
551
552#             outs.append(out.reshape(x.shape[0], -1))
553
554#         out = jnp.concatenate(outs, axis=-1)
555
556#         return e3nn.IrrepsArray(input.shape[1] * e3nn.Irreps(keep_irrep_out), out)
class MACE(flax.linen.module.Module):
 33class MACE(nn.Module):
 34    """MACE equivariant message passing neural network.
 35
 36    adapted from MACE-jax github repo by M. Geiger and I. Batatia
 37    
 38    T. Plé reordered some operations and changed defaults to match the recent mace-torch version 
 39    -> compatibility with pretrained torch models requires some work on the parameters:
 40        - normalization of activation functions in e3nn differ between jax and pytorch => need rescaling
 41        - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
 42        - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters
 43    
 44    References:
 45        - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436.
 46        https://doi.org/10.48550/arXiv.2206.07697
 47        - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
 48        https://doi.org/10.48550/arXiv.2205.06643
 49
 50    """
 51    _graphs_properties: Dict
 52    output_irreps: Union[Irreps, str] = "1x0e"
 53    """The output irreps of the model."""
 54    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
 55    """The hidden irreps of the model."""
 56    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
 57    """The hidden irreps of the readout MLP."""
 58    graph_key: str = "graph"
 59    """The key in the input dictionary that corresponds to the molecular graph to use."""
 60    output_key: Optional[str] = None
 61    """The key of the embedding in the output dictionary."""
 62    avg_num_neighbors: float = 1.0
 63    """The expected average number of neighbors."""
 64    ninteractions: int = 2
 65    """The number of interaction layers."""
 66    num_features: Optional[int] = None
 67    """The number of features per node. default gcd of hidden_irreps multiplicities"""
 68    radial_basis: dict = dataclasses.field(
 69        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
 70    )
 71    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
 72    lmax: int = 1
 73    """The maximum angular momentum to consider."""
 74    correlation: int = 3
 75    """The correlation order at each layer."""
 76    activation: str = "silu"
 77    """The activation function to use."""
 78    symmetric_tensor_product_basis: bool = False
 79    """Whether to use the symmetric tensor product basis."""
 80    interaction_irreps: Union[Irreps, str] = "o3_restricted"
 81    skip_connection_first_layer: bool = True
 82    radial_network_hidden: Sequence[int] = dataclasses.field(
 83        default_factory=lambda: [64, 64, 64]
 84    )
 85    scalar_output: bool = False
 86    zmax: int = 86
 87    """The maximum atomic number to consider."""
 88    convolution_mode: int = 1
 89    species_encoding_key: Optional[str] = None
 90
 91    FID: ClassVar[str] = "MACE"
 92
 93    @nn.compact
 94    def __call__(self, inputs):
 95        if not E3NN_AVAILABLE:
 96            raise E3NN_EXCEPTION
 97
 98        species_indices = inputs["species"]
 99        graph = inputs[self.graph_key]
100        distances = graph["distances"]
101        vec = e3nn.IrrepsArray("1o", graph["vec"])
102        switch = graph["switch"]
103        edge_src = graph["edge_src"]
104        edge_dst = graph["edge_dst"]
105
106        output_irreps = e3nn.Irreps(self.output_irreps)
107        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
108        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)
109
110        # extract or set num_features
111        if self.num_features is None:
112            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
113            hidden_irreps = e3nn.Irreps(
114                [(mul // num_features, ir) for mul, ir in hidden_irreps]
115            )
116        else:
117            num_features = self.num_features
118
119        # get interaction irreps
120        if self.interaction_irreps == "o3_restricted":
121            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
122        elif self.interaction_irreps == "o3_full":
123            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
124        else:
125            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
126        convol_irreps = num_features * interaction_irreps
127
128        # convert species to internal indices
129        # maxidx = max(PERIODIC_TABLE_REV_IDX.values())
130        # conv_tensor = [0] * (maxidx + 2)
131        # if isinstance(self.species_order, str):
132        #     species_order = [el.strip() for el in self.species_order.split(",")]
133        # else:
134        #     species_order = [el for el in self.species_order]
135        # for i, s in enumerate(species_order):
136        #     conv_tensor[PERIODIC_TABLE_REV_IDX[s]] = i
137        # species_indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
138        num_species = self.zmax + 2
139
140        # species encoding
141        encoding_irreps: e3nn.Irreps = (
142            (num_features * hidden_irreps).filter("0e").regroup()
143        )
144        if self.species_encoding_key is not None:
145            species_encoding = nn.Dense(encoding_irreps.dim,use_bias=False)(inputs[self.species_encoding_key])
146        else:
147            species_encoding = self.param(
148                "species_encoding",
149                lambda key, shape: jax.nn.standardize(
150                    jax.random.normal(key, shape, dtype=jnp.float32)
151                ),
152                (num_species, encoding_irreps.dim),
153            )[species_indices]
154        # convert to IrrepsArray
155        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)
156
157        # radial embedding
158        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
159        radial_embedding = (
160            RadialBasis(
161                **{
162                    **self.radial_basis,
163                    "end": cutoff,
164                    "name": f"RadialBasis",
165                }
166            )(distances)
167            * switch[:, None]
168        )
169
170        # spherical harmonics
171        assert self.convolution_mode in [0,1,2], "convolution_mode must be 0, 1 or 2"
172        if self.convolution_mode == 0:
173            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
174        elif self.convolution_mode == 1:
175            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)
176
177        outputs = []
178        node_feats_all = []
179        for layer in range(self.ninteractions):
180            first = layer == 0
181            last = layer == self.ninteractions - 1
182
183            layer_irreps = num_features * (
184                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
185            )
186
187            # Linear skip connection
188            sc = None
189            if not first or self.skip_connection_first_layer:
190                sc = e3nn.flax.Linear(
191                    layer_irreps,
192                    num_indexed_weights=num_species,
193                    name=f"skip_tp_{layer}",
194                    force_irreps_out=True,
195                )(species_indices, node_feats)
196
197            ################################################
198            # Interaction block (Message passing convolution)
199            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
200                node_feats
201            )
202
203
204            messages = node_feats[edge_src]
205            if self.convolution_mode == 0:
206                messages = e3nn.tensor_product(
207                    messages,
208                    Yij,
209                    filter_ir_out=convol_irreps,
210                    regroup_output=True,
211                )
212            elif self.convolution_mode == 1:
213                messages = e3nn.concatenate(
214                    [
215                        messages.filter(convol_irreps),
216                        e3nn.tensor_product(
217                            messages,
218                            Yij,
219                            filter_ir_out=convol_irreps,
220                        ),
221                        # e3nn.tensor_product_with_spherical_harmonics(
222                        #     messages, vectors, self.max_ell
223                        # ).filter(convol_irreps),
224                    ]
225                ).regroup()
226            else:
227                messages = e3nn.tensor_product_with_spherical_harmonics(
228                    messages, vec, self.lmax
229                ).filter(convol_irreps).regroup()
230
231            # mix = FullyConnectedNet(
232            #     [*self.radial_network_hidden, messages.irreps.num_irreps],
233            #     activation=activation_from_str(self.activation),
234            #     name=f"radial_network_{layer}",
235            #     use_bias=False,
236            # )(radial_embedding)
237            mix = e3nn.flax.MultiLayerPerceptron(
238                [*self.radial_network_hidden, messages.irreps.num_irreps],
239                act=activation_from_str(self.activation),
240                output_activation=False,
241                name=f"radial_network_{layer}",
242                gradient_normalization="element",
243            )(
244                radial_embedding
245            )
246
247            messages = messages * mix
248            node_feats = (
249                e3nn.IrrepsArray.zeros(
250                    messages.irreps, node_feats.shape[:1], messages.dtype
251                )
252                .at[edge_dst]
253                .add(messages)
254            )
255            # print("irreps_mid jax",node_feats.irreps)
256            # jax.debug.print("node_feats={n}", n=jnp.sum(node_feats.array,axis=0)[550:570])
257
258            node_feats = (
259                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
260                / self.avg_num_neighbors
261            )
262
263            if first and not self.skip_connection_first_layer:
264                node_feats = e3nn.flax.Linear(
265                    node_feats.irreps,
266                    num_indexed_weights=num_species,
267                    name=f"skip_tp_{layer}",
268                )(species_indices, node_feats)
269
270            ################################################
271            # Equivariant product basis block
272
273            # symmetric contractions
274            node_feats = SymmetricContraction(
275                keep_irrep_out={ir for _, ir in layer_irreps},
276                correlation=self.correlation,
277                num_species=num_species,
278                gradient_normalization="element",  # NOTE: This is to copy mace-torch
279                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
280            )(
281                node_feats, species_indices
282            )
283
284
285            node_feats = e3nn.flax.Linear(
286                layer_irreps, name=f"linear_contraction_{layer}"
287            )(node_feats)
288
289
290            if sc is not None:
291                # add skip connection
292                node_feats = node_feats + sc
293
294
295            ################################################
296            
297            # Readout block
298            if last:
299                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
300                layer_out = e3nn.flax.Linear(
301                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
302                    name=f"hidden_linear_readout_last",
303                )(node_feats)
304                layer_out = e3nn.gate(
305                    layer_out,
306                    even_act=activation_from_str(self.activation),
307                    even_gate_act=None,
308                )
309                layer_out = e3nn.flax.Linear(
310                    output_irreps, name=f"linear_readout_last"
311                )(layer_out)
312            else:
313                layer_out = e3nn.flax.Linear(
314                    output_irreps,
315                    name=f"linear_readout_{layer}",
316                )(node_feats)
317
318            if self.scalar_output:
319                layer_out = layer_out.filter("0e").array
320
321            outputs.append(layer_out)
322            node_feats_all.append(node_feats.filter("0e").array)
323
324        if self.scalar_output:
325            output = jnp.stack(outputs, axis=1)
326        else:
327            output = e3nn.stack(outputs, axis=1)
328
329        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)
330
331        output_key = self.output_key if self.output_key is not None else self.name
332        return {
333            **inputs,
334            output_key: output,
335            output_key + "_node_feats": node_feats_all,
336        }

MACE equivariant message passing neural network.

adapted from MACE-jax github repo by M. Geiger and I. Batatia

T. Plé reordered some operations and changed defaults to match the recent mace-torch version. Compatibility with pretrained torch models therefore requires some work on the parameters:
- normalization of activation functions in e3nn differs between jax and pytorch => need rescaling
- multiplicity ordering and signs of the U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
- we use a maximum Z instead of a list of species => need to adapt species-dependent parameters

References:
- I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436. https://doi.org/10.48550/arXiv.2206.07697
- I. Batatia et al. "The design space of E(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022). https://doi.org/10.48550/arXiv.2205.06643

MACE( _graphs_properties: Dict, output_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '1x0e', hidden_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '128x0e + 128x1o', readout_mlp_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '16x0e', graph_key: str = 'graph', output_key: Optional[str] = None, avg_num_neighbors: float = 1.0, ninteractions: int = 2, num_features: Optional[int] = None, radial_basis: dict = <factory>, lmax: int = 1, correlation: int = 3, activation: str = 'silu', symmetric_tensor_product_basis: bool = False, interaction_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = 'o3_restricted', skip_connection_first_layer: bool = True, radial_network_hidden: Sequence[int] = <factory>, scalar_output: bool = False, zmax: int = 86, convolution_mode: int = 1, species_encoding_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
output_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '1x0e'

The output irreps of the model.

hidden_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '128x0e + 128x1o'

The hidden irreps of the model.

readout_mlp_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '16x0e'

The hidden irreps of the readout MLP.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the molecular graph to use.

output_key: Optional[str] = None

The key of the embedding in the output dictionary.

avg_num_neighbors: float = 1.0

The expected average number of neighbors.

ninteractions: int = 2

The number of interaction layers.

num_features: Optional[int] = None

The number of features per node. Defaults to the greatest common divisor (GCD) of the hidden_irreps multiplicities.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.

lmax: int = 1

The maximum angular momentum to consider.

correlation: int = 3

The correlation order at each layer.

activation: str = 'silu'

The activation function to use.

symmetric_tensor_product_basis: bool = False

Whether to use the symmetric tensor product basis.

interaction_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = 'o3_restricted'
skip_connection_first_layer: bool = True
radial_network_hidden: Sequence[int]
scalar_output: bool = False
zmax: int = 86

The maximum atomic number to consider.

convolution_mode: int = 1
species_encoding_key: Optional[str] = None
FID: ClassVar[str] = 'MACE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SymmetricContraction(flax.linen.module.Module):
class SymmetricContraction(nn.Module):
    """Species-indexed symmetric tensor contraction (MACE product basis).

    Builds the higher-order equivariant features of MACE by contracting an
    input IrrepsArray with itself up to ``correlation`` times, using
    precomputed reduced tensor-product bases (the U matrices) together with
    per-species learnable weights.

    Adapted from the MACE-jax implementation by M. Geiger and I. Batatia.
    """

    # Maximum correlation order: number of times the input is contracted
    # with itself (the "nu" parameter of MACE).
    correlation: int
    # Output irreps to keep (e.g. {0e, 1o}); may also be passed as an
    # Irreps string, in which case all multiplicities must be 1.
    keep_irrep_out: Set[Irrep]
    # Number of species; first axis of every weight tensor, selected by `index`.
    num_species: int
    # "element" (-> 0.0) or "path" (-> 1.0), or a float interpolating between
    # the two conventions; None falls back to the e3nn global config.
    gradient_normalization: Union[str, float]
    # If True, use e3nn.reduced_symmetric_tensor_product_basis (smaller basis);
    # otherwise use the full reduced_tensor_product_basis.
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        """Contract ``input`` with itself up to ``correlation`` times.

        Args:
            input: e3nn.IrrepsArray of node features; the multiplicity axis is
                split off internally via ``mul_to_axis()`` so the last two
                axes become [num_features, irreps].
            index: integer array of species indices, broadcastable against the
                leading (batch) axes of ``input``; selects per-species weights.

        Returns:
            e3nn.IrrepsArray whose irreps are restricted to ``keep_irrep_out``,
            with the multiplicity axis merged back (``axis_to_mul()``).

        Raises:
            The stored e3nn import error if e3nn_jax is not available.
        """
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        if self.gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        else:
            gradient_normalization = self.gradient_normalization
        # Map the string convention to its numeric exponent.
        if isinstance(gradient_normalization, str):
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            assert all(mul == 1 for mul, _ in keep_irrep_out)

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        # Split multiplicities into their own axis: [..., mul*irreps] -> [..., mul, irreps].
        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        # For each correlation order, precompute the reduced tensor-product
        # basis U and create one per-species weight tensor per output irrep.
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # have been numerically checked to be correct.

            # TODO(mario) implement norm_p
            Us.append(U)

            wsorder = []
            for (mul, ir_out), u in zip(U.irreps, U.list):
                u = u.astype(input.array.dtype)
                # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                # Per-species weights; the init stddev and the post-scale
                # together implement the chosen gradient-normalization
                # convention (their product is always mul**-0.5).
                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # - This operation is parallel on the feature dimension (but each feature has its own parameters)
            # This operation is an efficient implementation of
            # vmap(lambda w, x: FunctionalLinear(irreps_out)(w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])))(w, x)
            # up to x power self.correlation
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            # Horner-style evaluation of the polynomial in x_, from the
            # highest correlation order down to 1.
            for order in range(self.correlation, 0, -1):  # correlation, ..., 1

                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for ii, ((mul, ir_out), u) in enumerate(zip(U.irreps, U.list)):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    # Select this atom's species-specific weights.
                    w = ws[order - 1][ii][index]
                    if ir_out not in out:
                        # First contribution for this irrep: fuse the weight
                        # contraction with one power of x_ in a single einsum.
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]
                    else:
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    # Contract one power of x_ into the accumulated result.
                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #           out

            # out[irrep_out] : [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        # fn handles a single [num_features, irreps] node; vmap it over every
        # leading batch axis.
        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        # Merge multiplicities back: [..., mul, irreps] -> [..., mul*irreps].
        return fn_mapped(input, index).axis_to_mul()
SymmetricContraction( correlation: int, keep_irrep_out: Set[e3nn_jax._src.irreps.Irrep], num_species: int, gradient_normalization: Union[str, float], symmetric_tensor_product_basis: bool, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
correlation: int
keep_irrep_out: Set[e3nn_jax._src.irreps.Irrep]
num_species: int
gradient_normalization: Union[str, float]
symmetric_tensor_product_basis: bool
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None