fennol.models.embeddings.raster

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from typing import Sequence, Dict, Union, ClassVar, Optional
  5import numpy as np
  6from ...utils.periodic_table import D3_COV_RADII
  7import dataclasses
  8from ..misc.encodings import RadialBasis, SpeciesEncoding
  9from ...utils.spherical_harmonics import generate_spherical_harmonics
 10from ..misc.e3 import  ChannelMixing
 11from ..misc.nets import FullyConnectedNet,BlockIndexNet
 12from ...utils.activations import activation_from_str
 13from ...utils.atomic_units import au
 14from ...utils.initializers import initializer_from_str
 15
class RaSTER(nn.Module):
    """ Range-Separated Transformer with Equivariant Representations

    FID : RASTER

    """

    _graphs_properties: Dict
    dim: int = 176
    """The dimension of the output embedding."""
    nlayers: int = 2
    """The number of message-passing layers."""
    att_dim: int = 16
    """The dimension of the attention heads."""
    scal_heads: int = 16
    """The number of scalar attention heads."""
    tens_heads: int = 4
    """The number of tensor attention heads."""
    lmax: int = 3
    """The maximum angular momentum to consider."""
    normalize_vec: bool = True
    """Whether to normalize the vector features before computing spherical harmonics."""
    att_activation: str = "identity"
    """The activation function to use for the attention coefficients."""
    activation: str = "swish"
    """The activation function to use for the update network."""
    update_hidden: Sequence[int] = ()
    """The hidden layers for the update network."""
    update_bias: bool = True
    """Whether to use bias in the update network."""
    positional_activation: str = "swish"
    """The activation function to use for the positional embedding network."""
    positional_bias: bool = True
    """Whether to use bias in the positional embedding network."""
    switch_before_net: bool = False
    """Whether to apply the switch function to the radial basis before the edge neural network."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
    additive_positional: bool = False
    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
    edge_value: bool = False
    """Whether to use edge values in the attention mechanism."""
    layer_normalization: bool = True
    """Whether to use layer normalization of atomic embeddings."""
    layernorm_shift: bool = True
    """Whether to shift the mean in layer normalization."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    embedding_key: str = "embedding"
    """ The key in the output dictionary that corresponds to the embedding."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: str | dict = dataclasses.field(
        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
    )
    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    graph_lode: Optional[str] = None
    """The key in the input dictionary that corresponds to the long-range graph."""
    lmax_lode: int = 0
    """The maximum angular momentum for the long-range features."""
    lode_rshort: Optional[float] = None
    """The short-range cutoff for the long-range features."""
    lode_dshort: float = 2.0
    """The width of the short-range cutoff for the long-range features."""
    lode_extra_powers: Sequence[int] = ()
    """The extra powers to include in the long-range features."""
    a_lode: float = -1.0
    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
    block_index_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
    lode_channels: int = 1
    """The number of channels for the long-range features."""
    switch_cov_start: float = 0.5
    """The start of close-range covalent switch (in units of covalent radii)."""
    switch_cov_end: float = 0.6
    """The end of close-range covalent switch (in units of covalent radii)."""
    normalize_keys: bool = False
    """Whether to normalize queries and keys in the attention mechanism."""
    normalize_components: bool = False
    """Whether to normalize the components before the update network."""
    keep_all_layers: bool = False
    """Whether to return the stacked scalar embeddings from all message-passing layers."""
    kernel_init: Optional[str] = None
    # Name of the kernel initializer, parsed by `initializer_from_str`
    # (None -> whatever default that helper returns).

    FID: ClassVar[str] = "RASTER"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom scalar (and tensor) embeddings via attention-based
        message passing, optionally augmented with long-range (LODE) features.

        Reads from `inputs`: "species", `self.graph_key` (radial graph dict with
        "distances", "switch", "edge_src", "edge_dst", "vec"), and optionally
        `self.graph_lode` and `self.block_index_key`.

        Returns a copy of `inputs` with added keys:
          - `self.embedding_key`: (natoms, dim) scalar embedding,
          - `self.embedding_key + "_tensor"`: equivariant tensor features,
          - `self.embedding_key + "_layers"` (if `keep_all_layers`): embeddings
            from all layers stacked on axis 1.
        """
        species = inputs["species"]

        ## SETUP LAYER NORMALIZATION
        # Hand-rolled layer norm with a fixed 1e-6 epsilon on the variance;
        # `layernorm_shift` selects mean-centered vs RMS-style normalization.
        if self.layernorm_shift:
            def _layer_norm(x):
                mu = jnp.mean(x, axis=-1, keepdims=True)
                dx = x - mu
                var = jnp.mean(dx**2, axis=-1, keepdims=True)
                sig = (1.0e-6 + var) ** (-0.5)
                return dx * sig
        else:
            def _layer_norm(x):
                var = jnp.mean(x**2, axis=-1, keepdims=True)
                sig = (1.0e-6 + var) ** (-0.5)
                return x * sig

        if self.layer_normalization:
            layer_norm = _layer_norm
        else:
            layer_norm = lambda x: x
        
        # Separate toggle for normalizing queries/keys in the attention.
        if self.normalize_keys:
            ln_qk = _layer_norm
        else:
            ln_qk = lambda x: x

        kernel_init = initializer_from_str(self.kernel_init)

        ## SPECIES ENCODING
        # If `species_encoding` is a string, it names a precomputed encoding in
        # `inputs`; otherwise it is a config dict for a fresh SpeciesEncoding.
        if isinstance(self.species_encoding, str):
            Zi = inputs[self.species_encoding]
        else:
            Zi = SpeciesEncoding(**self.species_encoding)(species)

        ## INITIALIZE SCALAR FEATURES
        xi = layer_norm(nn.Dense(self.dim, use_bias=False,name="species_linear",kernel_init=kernel_init)(Zi))

        # RADIAL GRAPH
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        vec = (
            graph["vec"] / graph["distances"][:, None]
            if self.normalize_vec
            else graph["vec"]
        )
        ## CLOSE-RANGE SWITCH
        # Smoothly zero out edges shorter than a fraction of the pairwise sum
        # of covalent radii (cosine ramp between rstart and rend).
        use_switch_cov = False
        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
            use_switch_cov = True
            assert self.switch_cov_start < self.switch_cov_end, f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
            assert self.switch_cov_start > 0 and self.switch_cov_end < 1, f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
            rc = jnp.array(D3_COV_RADII*au.ANG)[species]
            rcij = rc[edge_src] + rc[edge_dst]
            rstart = rcij * self.switch_cov_start
            rend = rcij * self.switch_cov_end
            # 0 below rstart, cosine ramp on (rstart, rend), 1 above rend.
            switch_short = (distances >= rend) + 0.5*(1-jnp.cos(jnp.pi*(distances - rstart)/(rend-rstart)))*(distances > rstart)*(distances < rend)
            switch = switch * switch_short

        ## COMPUTE SPHERICAL HARMONICS ON EDGES
        # Yij: (nedges, 1, (lmax+1)^2); the middle axis broadcasts over heads.
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[:,None,:]
        # nrep[l] = 2l+1 components per degree; ls maps each component to its l.
        nrep = np.array([2 * l + 1 for l in range((self.lmax + 1))])
        ls = np.arange(self.lmax + 1).repeat(nrep)
            
        # (-1)^l sign flip applied to the source-atom tensor features when
        # building scalar edge invariants (parity of each harmonic degree).
        parity = jnp.array((-1) ** ls[None,None,:])
        if self.ignore_parity:
            # NOTE(review): "ignoring" parity sets ALL components to -1 (not +1)
            # — presumably intentional; confirm against the reference model.
            parity = -jnp.ones_like(parity)

        ## INITIALIZE TENSOR FEATURES
        # Scalar 0. broadcasts on first use; at layer 0 the `layer > 0`
        # branches that index Vi are skipped, so no array is needed yet.
        Vi = 0. #jnp.zeros((Zi.shape[0],self.tens_heads, Yij.shape[1]))

        # RADIAL BASIS
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": f"RadialBasis",
            }
        )(distances)
        # Either apply the full switch before the edge network, or at minimum
        # the close-range covalent switch (already folded into `switch` above).
        if self.switch_before_net:
            radial_terms = radial_terms * switch[:, None]
        elif use_switch_cov:
            radial_terms = radial_terms * switch_short[:, None]

        ## INITIALIZE LODE
        do_lode = self.graph_lode is not None
        if do_lode:
            ## LONG-RANGE GRAPH
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]
            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            assert (
                lmax_lr <= self.lmax
            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            ## PARAMETERS FOR THE LR RADIAL BASIS
            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                # Extra powers are prepended: eij_lr columns [0:nextra_powers]
                # correspond to them (consumed first in the split below).
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            # Damping parameter `a`: fixed (a_lode**2) if positive, otherwise a
            # trainable per-power parameter initialized at |a_lode| (squared to
            # keep it non-negative).
            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            # Damped 1/(r^2+a)^((l+1)/2) with value and first-derivative shifts
            # so the basis goes smoothly to zero at the cutoff rc.
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                # NOTE(review): this reuses the name `switch_short` from the
                # covalent switch above; harmless here (the earlier value is no
                # longer needed) but easy to misread.
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            dim_lr = 1
            if nextra_powers > 0:
                # Split off the extra-power columns (they stay scalar).
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]
                dim_lr += nextra_powers

            if equivariant_lode:
                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
                # Expand each degree to its 2l+1 components, then modulate by
                # the long-range spherical harmonics.
                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                dim_lr += lmax_lr
                eij_lr = eij_lr * Yij_lr
                del Yij_lr
        

        if self.keep_all_layers:
            xis = []

        ### START MESSAGE PASSING ITERATIONS
        for layer in range(self.nlayers):
            ## GATHER SCALAR EDGE FEATURES
            u = [radial_terms]
            if layer > 0:
                ## edge-tensor contraction
                # Invariant per-degree contractions of (dst + parity*src)
                # tensor features against the edge harmonics.
                xij2 = (Vi[edge_dst] + (parity* Vi)[edge_src]) * Yij
                for l in range(self.lmax + 1):
                    # Components of degree l live at flat indices [l^2, (l+1)^2).
                    u.append((xij2[:,:, l**2 : (l + 1) ** 2]).sum(axis=-1))
            ur = jnp.concatenate(u, axis=-1)

            ## BUILD RELATIVE POSITIONAL ENCODING
            # nout=2 when edge values are requested: the net then produces both
            # the positional encoding w and the edge value vij.
            if self.edge_value:
                nout = 2
            else:
                nout = 1
            w = FullyConnectedNet(
                [2 * self.att_dim, nout*self.att_dim],
                activation=self.positional_activation,
                use_bias=self.positional_bias,
                name=f"positional_encoding_{layer}",
            )(ur).reshape(radial_terms.shape[0],nout, self.att_dim)
            if self.edge_value:
                w,vij = jnp.split(w, 2, axis=1)

            # Head layout: the first tens_heads*nls attention heads drive the
            # tensor channels (nls = lmax+1 per-degree coefficients, doubled
            # after layer 0 to also gate Vi[edge_dst]); the remaining
            # scal_heads heads drive the scalar values.
            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)


            ## QUERY, KEY, VALUE
            q = ln_qk(nn.Dense((self.scal_heads + nls*self.tens_heads) * self.att_dim, use_bias=False,name=f"queries_{layer}",kernel_init=kernel_init)(
                xi
            ).reshape(xi.shape[0], self.scal_heads + nls*self.tens_heads, self.att_dim))
            k = nn.Dense((self.scal_heads + nls*self.tens_heads) * self.att_dim, use_bias=False, name=f"keys_{layer}",kernel_init=kernel_init)(
                xi
            ).reshape(xi.shape[0], self.scal_heads + nls*self.tens_heads, self.att_dim)

            v = nn.Dense(self.scal_heads * self.att_dim, use_bias=False, name=f"values_{layer}",kernel_init=kernel_init)(xi).reshape(
                xi.shape[0], self.scal_heads, self.att_dim
            )

            ## ATTENTION COEFFICIENTS
            # The positional encoding modulates the destination keys, either
            # additively or multiplicatively.
            if self.additive_positional:
                wk = ln_qk(w + k[edge_dst])
            else:
                wk = ln_qk(w * k[edge_dst])

            act = activation_from_str(self.att_activation)
            # Scaled dot-product attention per edge, masked by the switch.
            aij = (
                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
                * switch[:, None]
            )

            # Slice the heads: aijl -> per-degree coefficients for new tensor
            # messages; aijl1 (layer>0) -> gating of neighbor tensor features;
            # remaining aij -> scalar-value attention weights.
            aijl = aij[:, : self.tens_heads*(self.lmax + 1)].reshape(-1,self.tens_heads,self.lmax+1).repeat(nrep, axis=-1)
            if layer > 0:
                aijl1 = aij[:, self.tens_heads*(self.lmax + 1) : self.tens_heads*nls].reshape(-1,self.tens_heads,self.lmax+1).repeat(nrep, axis=-1)
            aij = aij[:, self.tens_heads*nls:, None]

            if self.edge_value:
                ## EDGE VALUES
                if self.additive_positional:
                    vij = vij + v[edge_dst]
                else:
                    vij = vij * v[edge_dst]
            else:
                ## MOVE DEST VALUES TO EDGE
                vij = v[edge_dst]

            ## SCALAR ATTENDED FEATURES
            # Weighted sum of neighbor values into each source atom.
            vai = jax.ops.segment_sum(
                aij * vij,
                edge_src,
                num_segments=xi.shape[0],
            )
            vai = vai.reshape(xi.shape[0], -1)

            ### TENSOR ATTENDED FEATURES
            uij = aijl * Yij
            if layer > 0:
                uij = uij + aijl1 * Vi[edge_dst]
            # Residual accumulation of tensor features (Vi is 0. at layer 0).
            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])

            ## SELF SCALAR FEATURES
            si = nn.Dense(self.att_dim, use_bias=False, name=f"self_values_{layer}",kernel_init=kernel_init)(xi)

            components = [si, vai]

            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
            # Per-degree squared norms (after channel mixing when several
            # tensor heads are used), normalized by 1/(2l+1).
            if self.tens_heads == 1:
                Vi2 = Vi**2
            else:
                Vi2 = Vi * ChannelMixing(self.lmax, self.tens_heads, name=f"extract_mixing_{layer}")(Vi)
            for l in range(self.lmax + 1):
                norm = 1.0 / (2 * l + 1)
                components.append(
                    (Vi2[:,:, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm
                )

            ### LODE (~ LONG-RANGE ATTENTION)
            # Only at the last layer: accumulate long-range "fields" from the
            # damped-Coulomb basis and fold their invariants into the update.
            if do_lode and layer == self.nlayers - 1:
                assert self.lode_channels <= self.tens_heads
                zj = nn.Dense(self.lode_channels*dim_lr, use_bias=False, name=f"lode_values_{layer}",kernel_init=kernel_init)(xi).reshape(
                    xi.shape[0], self.lode_channels, dim_lr
                )
                if nextra_powers > 0:
                    zj_extra = zj[:,:, :nextra_powers]
                    zj = zj[:, :, nextra_powers:]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra[:,None,:] * zj_extra[edge_dst_lr],
                        edge_src_lr,
                        species.shape[0],
                    ).reshape(species.shape[0],-1)
                    components.append(xi_lr_extra)
                if equivariant_lode:
                    zj = zj.repeat(nrep_lr, axis=-1)
                Vi_lr = jax.ops.segment_sum(
                    eij_lr[:,None,:] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                # l=0 component is already invariant.
                components.append(Vi_lr[:,: , 0])
                if equivariant_lode:
                    # Multipole interaction: contract short-range tensor
                    # features against the long-range field, degree by degree.
                    Mi_lr = Vi[:,:self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
                    for l in range(1, lmax_lr + 1):
                        norm = 1.0 / (2 * l + 1)
                        components.append(
                            Mi_lr[:, :,l**2 : (l + 1) ** 2].sum(axis=-1)
                            * norm
                        )

            ### CONCATENATE UPDATE COMPONENTS
            components = jnp.concatenate(components, axis=-1)
            if self.normalize_components:
                components = _layer_norm(components)
            ### COMPUTE UPDATE
            if self.block_index_key is not None:
                ## MoE neural network from block index
                block_index = inputs[self.block_index_key]
                updi = BlockIndexNet(
                        output_dim=self.dim + self.tens_heads*(self.lmax + 1),
                        hidden_neurons=self.update_hidden,
                        activation=self.activation,
                        use_bias=self.update_bias,
                        name=f"update_net_{layer}",
                        kernel_init=kernel_init,
                    )((species,components, block_index))
            else:
                updi = FullyConnectedNet(
                        [*self.update_hidden, self.dim + self.tens_heads*(self.lmax + 1)],
                        activation=self.activation,
                        use_bias=self.update_bias,
                        name=f"update_net_{layer}",
                        kernel_init=kernel_init,
                    )(components)
                
            ## UPDATE ATOM FEATURES
            # First `dim` outputs: residual scalar update; the remaining
            # tens_heads*(lmax+1) outputs gate the tensor features per degree
            # (1 + g so an all-zero net output leaves Vi unchanged).
            xi = layer_norm(xi + updi[:,:self.dim])
            Vi = Vi * (1 + updi[:,self.dim:]).reshape(-1,self.tens_heads,self.lmax+1).repeat(nrep, axis=-1)
            if self.tens_heads > 1:
                Vi = ChannelMixing(self.lmax, self.tens_heads,name=f"update_mixing_{layer}")(Vi)

            if self.keep_all_layers:
                ## STORE ALL LAYERS
                xis.append(xi)


        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
        if self.keep_all_layers:
            output[self.embedding_key+'_layers'] = jnp.stack(xis,axis=1)
        return output
class RaSTER(flax.linen.module.Module):
 17class RaSTER(nn.Module):
 18    """ Range-Separated Transformer with Equivariant Representations
 19
 20    FID : RASTER
 21
 22    """
 23
 24    _graphs_properties: Dict
 25    dim: int = 176
 26    """The dimension of the output embedding."""
 27    nlayers: int = 2
 28    """The number of message-passing layers."""
 29    att_dim: int = 16
 30    """The dimension of the attention heads."""
 31    scal_heads: int = 16
 32    """The number of scalar attention heads."""
 33    tens_heads: int = 4
 34    """The number of tensor attention heads."""
 35    lmax: int = 3
 36    """The maximum angular momentum to consider."""
 37    normalize_vec: bool = True
 38    """Whether to normalize the vector features before computing spherical harmonics."""
 39    att_activation: str = "identity"
 40    """The activation function to use for the attention coefficients."""
 41    activation: str = "swish"
 42    """The activation function to use for the update network."""
 43    update_hidden: Sequence[int] = ()
 44    """The hidden layers for the update network."""
 45    update_bias: bool = True
 46    """Whether to use bias in the update network."""
 47    positional_activation: str = "swish"
 48    """The activation function to use for the positional embedding network."""
 49    positional_bias: bool = True
 50    """Whether to use bias in the positional embedding network."""
 51    switch_before_net: bool = False
 52    """Whether to apply the switch function to the radial basis before the edge neural network."""
 53    ignore_parity: bool = False
 54    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
 55    additive_positional: bool = False
 56    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
 57    edge_value: bool = False
 58    """Whether to use edge values in the attention mechanism."""
 59    layer_normalization: bool = True
 60    """Whether to use layer normalization of atomic embeddings."""
 61    layernorm_shift: bool = True
 62    """Whether to shift the mean in layer normalization."""
 63    graph_key: str = "graph"
 64    """ The key in the input dictionary that corresponds to the radial graph."""
 65    embedding_key: str = "embedding"
 66    """ The key in the output dictionary that corresponds to the embedding."""
 67    radial_basis: dict = dataclasses.field(
 68        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
 69    )
 70    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
 71    species_encoding: str | dict = dataclasses.field(
 72        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
 73    )
 74    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 75    graph_lode: Optional[str] = None
 76    """The key in the input dictionary that corresponds to the long-range graph."""
 77    lmax_lode: int = 0
 78    """The maximum angular momentum for the long-range features."""
 79    lode_rshort: Optional[float] = None
 80    """The short-range cutoff for the long-range features."""
 81    lode_dshort: float = 2.0
 82    """The width of the short-range cutoff for the long-range features."""
 83    lode_extra_powers: Sequence[int] = ()
 84    """The extra powers to include in the long-range features."""
 85    a_lode: float = -1.0
 86    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
 87    block_index_key: Optional[str] = None
 88    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
 89    lode_channels: int = 1
 90    """The number of channels for the long-range features."""
 91    switch_cov_start: float = 0.5
 92    """The start of close-range covalent switch (in units of covalent radii)."""
 93    switch_cov_end: float = 0.6
 94    """The end of close-range covalent switch (in units of covalent radii)."""
 95    normalize_keys: bool = False
 96    """Whether to normalize queries and keys in the attention mechanism."""
 97    normalize_components: bool = False
 98    """Whether to normalize the components before the update network."""
 99    keep_all_layers: bool = False
100    """Whether to return the stacked scalar embeddings from all message-passing layers."""
101    kernel_init: Optional[str] = None
102    
103    FID: ClassVar[str] = "RASTER"
104
105    @nn.compact
106    def __call__(self, inputs):
107        species = inputs["species"]
108
109        ## SETUP LAYER NORMALIZATION
110        if self.layernorm_shift:
111            def _layer_norm(x):
112                mu = jnp.mean(x, axis=-1, keepdims=True)
113                dx = x - mu
114                var = jnp.mean(dx**2, axis=-1, keepdims=True)
115                sig = (1.0e-6 + var) ** (-0.5)
116                return dx * sig
117        else:
118            def _layer_norm(x):
119                var = jnp.mean(x**2, axis=-1, keepdims=True)
120                sig = (1.0e-6 + var) ** (-0.5)
121                return x * sig
122
123        if self.layer_normalization:
124            layer_norm = _layer_norm
125        else:
126            layer_norm = lambda x: x
127        
128        if self.normalize_keys:
129            ln_qk = _layer_norm
130        else:
131            ln_qk = lambda x: x
132
133        kernel_init = initializer_from_str(self.kernel_init)
134
135        ## SPECIES ENCODING
136        if isinstance(self.species_encoding, str):
137            Zi = inputs[self.species_encoding]
138        else:
139            Zi = SpeciesEncoding(**self.species_encoding)(species)
140
141        ## INITIALIZE SCALAR FEATURES
142        xi = layer_norm(nn.Dense(self.dim, use_bias=False,name="species_linear",kernel_init=kernel_init)(Zi))
143
144        # RADIAL GRAPH
145        graph = inputs[self.graph_key]
146        distances = graph["distances"]
147        switch = graph["switch"]
148        edge_src = graph["edge_src"]
149        edge_dst = graph["edge_dst"]
150        vec = (
151            graph["vec"] / graph["distances"][:, None]
152            if self.normalize_vec
153            else graph["vec"]
154        )
155        ## CLOSE-RANGE SWITCH
156        use_switch_cov = False
157        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
158            use_switch_cov = True
159            assert self.switch_cov_start < self.switch_cov_end, f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
160            assert self.switch_cov_start > 0 and self.switch_cov_end < 1, f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
161            rc = jnp.array(D3_COV_RADII*au.ANG)[species]
162            rcij = rc[edge_src] + rc[edge_dst]
163            rstart = rcij * self.switch_cov_start
164            rend = rcij * self.switch_cov_end
165            switch_short = (distances >= rend) + 0.5*(1-jnp.cos(jnp.pi*(distances - rstart)/(rend-rstart)))*(distances > rstart)*(distances < rend)
166            switch = switch * switch_short
167
168        ## COMPUTE SPHERICAL HARMONICS ON EDGES
169        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[:,None,:]
170        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
171        ls = np.arange(self.lmax + 1).repeat(nrep)
172            
173        parity = jnp.array((-1) ** ls[None,None,:])
174        if self.ignore_parity:
175            parity = -jnp.ones_like(parity)
176
177        ## INITIALIZE TENSOR FEATURES
178        Vi = 0. #jnp.zeros((Zi.shape[0],self.tens_heads, Yij.shape[1]))
179
180        # RADIAL BASIS
181        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
182        radial_terms = RadialBasis(
183            **{
184                **self.radial_basis,
185                "end": cutoff,
186                "name": f"RadialBasis",
187            }
188        )(distances)
189        if self.switch_before_net:
190            radial_terms = radial_terms * switch[:, None]
191        elif use_switch_cov:
192            radial_terms = radial_terms * switch_short[:, None]
193
194        ## INITIALIZE LODE
195        do_lode = self.graph_lode is not None
196        if do_lode:
197            ## LONG-RANGE GRAPH
198            graph_lode = inputs[self.graph_lode]
199            switch_lode = graph_lode["switch"][:, None]
200            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
201            r = graph_lode["distances"][:, None]
202            rc = self._graphs_properties[self.graph_lode]["cutoff"]
203
204            lmax_lr = self.lmax_lode
205            equivariant_lode = lmax_lr > 0
206            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
207            assert (
208                lmax_lr <= self.lmax
209            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
210            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
211            if equivariant_lode:
212                ls_lr = np.arange(lmax_lr + 1)
213            else:
214                ls_lr = np.array([0])
215
216            ## PARAMETERS FOR THE LR RADIAL BASIS
217            nextra_powers = len(self.lode_extra_powers)
218            if nextra_powers > 0:
219                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])
220
221            if self.a_lode > 0:
222                a = self.a_lode**2
223            else:
224                a = (
225                    self.param(
226                        "a_lr",
227                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
228                            None, :
229                        ],
230                    )
231                    ** 2
232                )
233            rc2a = rc**2 + a
234            ls_lr = 0.5 * (ls_lr[None, :] + 1)
235            ### minimal radial basis for long range (damped coulomb)
236            eij_lr = (
237                1.0 / (r**2 + a) ** ls_lr
238                - 1.0 / rc2a**ls_lr
239                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
240            ) * switch_lode
241
242            if self.lode_rshort is not None:
243                rs = self.lode_rshort
244                d = self.lode_dshort
245                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
246                    r < rs + d
247                ) + (r >= rs + d)
248                eij_lr = eij_lr * switch_short
249
250            dim_lr = 1
251            if nextra_powers > 0:
252                eij_lr_extra = eij_lr[:, :nextra_powers]
253                eij_lr = eij_lr[:, nextra_powers:]
254                dim_lr += nextra_powers
255
256            if equivariant_lode:
257                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
258                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
259                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
260                    graph_lode["vec"] / r
261                )
262                dim_lr += lmax_lr
263                eij_lr = eij_lr * Yij_lr
264                del Yij_lr
265        
266
267        if self.keep_all_layers:
268            xis = []
269
270        ### START MESSAGE PASSING ITERATIONS
271        for layer in range(self.nlayers):
272            ## GATHER SCALAR EDGE FEATURES
273            u = [radial_terms]
274            if layer > 0:
275                ## edge-tensor contraction
276                xij2 = (Vi[edge_dst] + (parity* Vi)[edge_src]) * Yij
277                for l in range(self.lmax + 1):
278                    u.append((xij2[:,:, l**2 : (l + 1) ** 2]).sum(axis=-1))
279            ur = jnp.concatenate(u, axis=-1)
280
281            ## BUILD RELATIVE POSITIONAL ENCODING
282            if self.edge_value:
283                nout = 2
284            else:
285                nout = 1
286            w = FullyConnectedNet(
287                [2 * self.att_dim, nout*self.att_dim],
288                activation=self.positional_activation,
289                use_bias=self.positional_bias,
290                name=f"positional_encoding_{layer}",
291            )(ur).reshape(radial_terms.shape[0],nout, self.att_dim)
292            if self.edge_value:
293                w,vij = jnp.split(w, 2, axis=1)
294
295            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)
296
297
298            ## QUERY, KEY, VALUE
299            q = ln_qk(nn.Dense((self.scal_heads + nls*self.tens_heads) * self.att_dim, use_bias=False,name=f"queries_{layer}",kernel_init=kernel_init)(
300                xi
301            ).reshape(xi.shape[0], self.scal_heads + nls*self.tens_heads, self.att_dim))
302            k = nn.Dense((self.scal_heads + nls*self.tens_heads) * self.att_dim, use_bias=False, name=f"keys_{layer}",kernel_init=kernel_init)(
303                xi
304            ).reshape(xi.shape[0], self.scal_heads + nls*self.tens_heads, self.att_dim)
305
306            v = nn.Dense(self.scal_heads * self.att_dim, use_bias=False, name=f"values_{layer}",kernel_init=kernel_init)(xi).reshape(
307                xi.shape[0], self.scal_heads, self.att_dim
308            )
309
310            ## ATTENTION COEFFICIENTS
311            if self.additive_positional:
312                wk = ln_qk(w + k[edge_dst])
313            else:
314                wk = ln_qk(w * k[edge_dst])
315
316            act = activation_from_str(self.att_activation)
317            aij = (
318                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
319                * switch[:, None]
320            )
321
322            aijl = aij[:, : self.tens_heads*(self.lmax + 1)].reshape(-1,self.tens_heads,self.lmax+1).repeat(nrep, axis=-1)
323            if layer > 0:
324                aijl1 = aij[:, self.tens_heads*(self.lmax + 1) : self.tens_heads*nls].reshape(-1,self.tens_heads,self.lmax+1).repeat(nrep, axis=-1)
325            aij = aij[:, self.tens_heads*nls:, None]
326
327            if self.edge_value:
328                ## EDGE VALUES
329                if self.additive_positional:
330                    vij = vij + v[edge_dst]
331                else:
332                    vij = vij * v[edge_dst]
333            else:
334                ## MOVE DEST VALUES TO EDGE
335                vij = v[edge_dst]
336
337            ## SCALAR ATTENDED FEATURES
338            vai = jax.ops.segment_sum(
339                aij * vij,
340                edge_src,
341                num_segments=xi.shape[0],
342            )
343            vai = vai.reshape(xi.shape[0], -1)
344
345            ### TENSOR ATTENDED FEATURES
346            uij = aijl * Yij
347            if layer > 0:
348                uij = uij + aijl1 * Vi[edge_dst]
349            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])
350
351            ## SELF SCALAR FEATURES
352            si = nn.Dense(self.att_dim, use_bias=False, name=f"self_values_{layer}",kernel_init=kernel_init)(xi)
353
354            components = [si, vai]
355
356            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
357            if self.tens_heads == 1:
358                Vi2 = Vi**2
359            else:
360                Vi2 = Vi * ChannelMixing(self.lmax, self.tens_heads, name=f"extract_mixing_{layer}")(Vi)
361            for l in range(self.lmax + 1):
362                norm = 1.0 / (2 * l + 1)
363                components.append(
364                    (Vi2[:,:, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm
365                )
366
367            ### LODE (~ LONG-RANGE ATTENTION)
368            if do_lode and layer == self.nlayers - 1:
369                assert self.lode_channels <= self.tens_heads
370                zj = nn.Dense(self.lode_channels*dim_lr, use_bias=False, name=f"lode_values_{layer}",kernel_init=kernel_init)(xi).reshape(
371                    xi.shape[0], self.lode_channels, dim_lr
372                )
373                if nextra_powers > 0:
374                    zj_extra = zj[:,:, :nextra_powers]
375                    zj = zj[:, :, nextra_powers:]
376                    xi_lr_extra = jax.ops.segment_sum(
377                        eij_lr_extra[:,None,:] * zj_extra[edge_dst_lr],
378                        edge_src_lr,
379                        species.shape[0],
380                    ).reshape(species.shape[0],-1)
381                    components.append(xi_lr_extra)
382                if equivariant_lode:
383                    zj = zj.repeat(nrep_lr, axis=-1)
384                Vi_lr = jax.ops.segment_sum(
385                    eij_lr[:,None,:] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
386                )
387                components.append(Vi_lr[:,: , 0])
388                if equivariant_lode:
389                    Mi_lr = Vi[:,:self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
390                    for l in range(1, lmax_lr + 1):
391                        norm = 1.0 / (2 * l + 1)
392                        components.append(
393                            Mi_lr[:, :,l**2 : (l + 1) ** 2].sum(axis=-1)
394                            * norm
395                        )
396
397            ### CONCATENATE UPDATE COMPONENTS
398            components = jnp.concatenate(components, axis=-1)
399            if self.normalize_components:
400                components = _layer_norm(components)
401            ### COMPUTE UPDATE
402            if self.block_index_key is not None:
403                ## MoE neural network from block index
404                block_index = inputs[self.block_index_key]
405                updi = BlockIndexNet(
406                        output_dim=self.dim + self.tens_heads*(self.lmax + 1),
407                        hidden_neurons=self.update_hidden,
408                        activation=self.activation,
409                        use_bias=self.update_bias,
410                        name=f"update_net_{layer}",
411                        kernel_init=kernel_init,
412                    )((species,components, block_index))
413            else:
414                updi = FullyConnectedNet(
415                        [*self.update_hidden, self.dim + self.tens_heads*(self.lmax + 1)],
416                        activation=self.activation,
417                        use_bias=self.update_bias,
418                        name=f"update_net_{layer}",
419                        kernel_init=kernel_init,
420                    )(components)
421                
422            ## UPDATE ATOM FEATURES
423            xi = layer_norm(xi + updi[:,:self.dim])
424            Vi = Vi * (1 + updi[:,self.dim:]).reshape(-1,self.tens_heads,self.lmax+1).repeat(nrep, axis=-1)
425            if self.tens_heads > 1:
426                Vi = ChannelMixing(self.lmax, self.tens_heads,name=f"update_mixing_{layer}")(Vi)
427
428            if self.keep_all_layers:
429                ## STORE ALL LAYERS
430                xis.append(xi)
431
432
433        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
434        if self.keep_all_layers:
435            output[self.embedding_key+'_layers'] = jnp.stack(xis,axis=1)
436        return output

Range-Separated Transformer with Equivariant Representations

FID : RASTER

RaSTER( _graphs_properties: Dict, dim: int = 176, nlayers: int = 2, att_dim: int = 16, scal_heads: int = 16, tens_heads: int = 4, lmax: int = 3, normalize_vec: bool = True, att_activation: str = 'identity', activation: str = 'swish', update_hidden: Sequence[int] = (), update_bias: bool = True, positional_activation: str = 'swish', positional_bias: bool = True, switch_before_net: bool = False, ignore_parity: bool = False, additive_positional: bool = False, edge_value: bool = False, layer_normalization: bool = True, layernorm_shift: bool = True, graph_key: str = 'graph', embedding_key: str = 'embedding', radial_basis: dict = <factory>, species_encoding: str | dict = <factory>, graph_lode: Optional[str] = None, lmax_lode: int = 0, lode_rshort: Optional[float] = None, lode_dshort: float = 2.0, lode_extra_powers: Sequence[int] = (), a_lode: float = -1.0, block_index_key: Optional[str] = None, lode_channels: int = 1, switch_cov_start: float = 0.5, switch_cov_end: float = 0.6, normalize_keys: bool = False, normalize_components: bool = False, keep_all_layers: bool = False, kernel_init: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 176

The dimension of the output embedding.

nlayers: int = 2

The number of message-passing layers.

att_dim: int = 16

The dimension of the attention heads.

scal_heads: int = 16

The number of scalar attention heads.

tens_heads: int = 4

The number of tensor attention heads.

lmax: int = 3

The maximum angular momentum to consider.

normalize_vec: bool = True

Whether to normalize the vector features before computing spherical harmonics.

att_activation: str = 'identity'

The activation function to use for the attention coefficients.

activation: str = 'swish'

The activation function to use for the update network.

update_hidden: Sequence[int] = ()

The hidden layers for the update network.

update_bias: bool = True

Whether to use bias in the update network.

positional_activation: str = 'swish'

The activation function to use for the positional embedding network.

positional_bias: bool = True

Whether to use bias in the positional embedding network.

switch_before_net: bool = False

Whether to apply the switch function to the radial basis before the edge neural network.

ignore_parity: bool = False

Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding.

additive_positional: bool = False

Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used.

edge_value: bool = False

Whether to use edge values in the attention mechanism.

layer_normalization: bool = True

Whether to use layer normalization of atomic embeddings.

layernorm_shift: bool = True

Whether to shift the mean in layer normalization.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

embedding_key: str = 'embedding'

The key in the output dictionary that corresponds to the embedding.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.

species_encoding: str | dict

The dictionary of parameters for species encoding. See fennol.models.misc.encodings.SpeciesEncoding.

graph_lode: Optional[str] = None

The key in the input dictionary that corresponds to the long-range graph.

lmax_lode: int = 0

The maximum angular momentum for the long-range features.

lode_rshort: Optional[float] = None

The short-range cutoff for the long-range features.

lode_dshort: float = 2.0

The width of the short-range cutoff for the long-range features.

lode_extra_powers: Sequence[int] = ()

The extra powers to include in the long-range features.

a_lode: float = -1.0

The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode).

block_index_key: Optional[str] = None

The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used.

lode_channels: int = 1

The number of channels for the long-range features.

switch_cov_start: float = 0.5

The start of close-range covalent switch (in units of covalent radii).

switch_cov_end: float = 0.6

The end of close-range covalent switch (in units of covalent radii).

normalize_keys: bool = False

Whether to normalize queries and keys in the attention mechanism.

normalize_components: bool = False

Whether to normalize the components before the update network.

keep_all_layers: bool = False

Whether to return the stacked scalar embeddings from all message-passing layers.

kernel_init: Optional[str] = None

The name of the weight initializer to use for the internal dense layers (parsed with initializer_from_str). If None, the default initializer is used.
FID: ClassVar[str] = 'RASTER'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None