"""fennol.models.embeddings.raster

Range-separated transformer embedding with equivariant (spherical-harmonic)
features. Defines the RaSTER flax module.
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import D3_COV_RADII
import dataclasses
from ..misc.encodings import RadialBasis, SpeciesEncoding
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.e3 import ChannelMixing
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ...utils.activations import activation_from_str
from ...utils.atomic_units import au
from ...utils.initializers import initializer_from_str


# NOTE(review): the original file contained two identical copies of this class
# definition (an accidental duplication); only one is kept here.
class RaSTER(nn.Module):
    """Range-Separated Transformer with Equivariant Representations

    FID : RASTER

    """

    _graphs_properties: Dict
    dim: int = 176
    """The dimension of the output embedding."""
    nlayers: int = 2
    """The number of message-passing layers."""
    att_dim: int = 16
    """The dimension of the attention heads."""
    scal_heads: int = 16
    """The number of scalar attention heads."""
    tens_heads: int = 4
    """The number of tensor attention heads."""
    lmax: int = 3
    """The maximum angular momentum to consider."""
    normalize_vec: bool = True
    """Whether to normalize the vector features before computing spherical harmonics."""
    att_activation: str = "identity"
    """The activation function to use for the attention coefficients."""
    activation: str = "swish"
    """The activation function to use for the update network."""
    update_hidden: Sequence[int] = ()
    """The hidden layers for the update network."""
    update_bias: bool = True
    """Whether to use bias in the update network."""
    positional_activation: str = "swish"
    """The activation function to use for the positional embedding network."""
    positional_bias: bool = True
    """Whether to use bias in the positional embedding network."""
    switch_before_net: bool = False
    """Whether to apply the switch function to the radial basis before the edge neural network."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
    additive_positional: bool = False
    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
    edge_value: bool = False
    """Whether to use edge values in the attention mechanism."""
    layer_normalization: bool = True
    """Whether to use layer normalization of atomic embeddings."""
    layernorm_shift: bool = True
    """Whether to shift the mean in layer normalization."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    embedding_key: str = "embedding"
    """ The key in the output dictionary that corresponds to the embedding."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: str | dict = dataclasses.field(
        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
    )
    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    graph_lode: Optional[str] = None
    """The key in the input dictionary that corresponds to the long-range graph."""
    lmax_lode: int = 0
    """The maximum angular momentum for the long-range features."""
    lode_rshort: Optional[float] = None
    """The short-range cutoff for the long-range features."""
    lode_dshort: float = 2.0
    """The width of the short-range cutoff for the long-range features."""
    lode_extra_powers: Sequence[int] = ()
    """The extra powers to include in the long-range features."""
    a_lode: float = -1.0
    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
    block_index_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
    lode_channels: int = 1
    """The number of channels for the long-range features."""
    switch_cov_start: float = 0.5
    """The start of close-range covalent switch (in units of covalent radii)."""
    switch_cov_end: float = 0.6
    """The end of close-range covalent switch (in units of covalent radii)."""
    normalize_keys: bool = False
    """Whether to normalize queries and keys in the attention mechanism."""
    normalize_components: bool = False
    """Whether to normalize the components before the update network."""
    keep_all_layers: bool = False
    """Whether to return the stacked scalar embeddings from all message-passing layers."""
    kernel_init: Optional[str] = None
    """Name of the kernel initializer passed to `initializer_from_str` for all Dense layers."""

    FID: ClassVar[str] = "RASTER"

    @nn.compact
    def __call__(self, inputs):
        """Compute atomic embeddings from species and graph data.

        Reads ``inputs["species"]`` and the graph stored under ``graph_key``
        (and optionally ``graph_lode``/``block_index_key``). Returns a copy of
        ``inputs`` augmented with the scalar embedding under ``embedding_key``,
        the tensor embedding under ``embedding_key + "_tensor"`` and, when
        ``keep_all_layers`` is set, the per-layer stack under
        ``embedding_key + "_layers"``.
        """
        species = inputs["species"]

        ## SETUP LAYER NORMALIZATION
        # _layer_norm normalizes over the feature axis; the eps 1.0e-6 guards
        # against division by zero for constant features.
        if self.layernorm_shift:

            def _layer_norm(x):
                mu = jnp.mean(x, axis=-1, keepdims=True)
                dx = x - mu
                var = jnp.mean(dx**2, axis=-1, keepdims=True)
                sig = (1.0e-6 + var) ** (-0.5)
                return dx * sig

        else:

            def _layer_norm(x):
                var = jnp.mean(x**2, axis=-1, keepdims=True)
                sig = (1.0e-6 + var) ** (-0.5)
                return x * sig

        if self.layer_normalization:
            layer_norm = _layer_norm
        else:
            layer_norm = lambda x: x

        if self.normalize_keys:
            ln_qk = _layer_norm
        else:
            ln_qk = lambda x: x

        kernel_init = initializer_from_str(self.kernel_init)

        ## SPECIES ENCODING
        # species_encoding may name a precomputed encoding in `inputs`, or be a
        # parameter dict for a SpeciesEncoding submodule.
        if isinstance(self.species_encoding, str):
            Zi = inputs[self.species_encoding]
        else:
            Zi = SpeciesEncoding(**self.species_encoding)(species)

        ## INITIALIZE SCALAR FEATURES
        xi = layer_norm(
            nn.Dense(
                self.dim, use_bias=False, name="species_linear", kernel_init=kernel_init
            )(Zi)
        )

        # RADIAL GRAPH
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        vec = (
            graph["vec"] / graph["distances"][:, None]
            if self.normalize_vec
            else graph["vec"]
        )

        ## CLOSE-RANGE SWITCH
        # Smoothly turns edges off below a fraction of the summed covalent
        # radii of the two atoms (cosine switch between rstart and rend).
        use_switch_cov = False
        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
            use_switch_cov = True
            assert (
                self.switch_cov_start < self.switch_cov_end
            ), f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
            assert (
                self.switch_cov_start > 0 and self.switch_cov_end < 1
            ), f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
            rc = jnp.array(D3_COV_RADII * au.ANG)[species]
            rcij = rc[edge_src] + rc[edge_dst]
            rstart = rcij * self.switch_cov_start
            rend = rcij * self.switch_cov_end
            switch_short = (distances >= rend) + 0.5 * (
                1 - jnp.cos(jnp.pi * (distances - rstart) / (rend - rstart))
            ) * (distances > rstart) * (distances < rend)
            switch = switch * switch_short

        ## COMPUTE SPHERICAL HARMONICS ON EDGES
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[
            :, None, :
        ]
        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        ls = np.arange(self.lmax + 1).repeat(nrep)

        parity = jnp.array((-1) ** ls[None, None, :])
        if self.ignore_parity:
            parity = -jnp.ones_like(parity)

        ## INITIALIZE TENSOR FEATURES
        Vi = 0.0  # jnp.zeros((Zi.shape[0],self.tens_heads, Yij.shape[1]))

        # RADIAL BASIS
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)
        if self.switch_before_net:
            radial_terms = radial_terms * switch[:, None]
        elif use_switch_cov:
            radial_terms = radial_terms * switch_short[:, None]

        ## INITIALIZE LODE
        do_lode = self.graph_lode is not None
        if do_lode:
            ## LONG-RANGE GRAPH
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]
            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            assert (
                lmax_lr <= self.lmax
            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            ## PARAMETERS FOR THE LR RADIAL BASIS
            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            # a_lode > 0: fixed damping; otherwise one trainable damping
            # parameter per angular channel, initialized to |a_lode|.
            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            # Shifted so that value and first derivative vanish at the cutoff.
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                # Remove the short-range part of the long-range interaction.
                rs = self.lode_rshort
                d = self.lode_dshort
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            dim_lr = 1
            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]
                dim_lr += nextra_powers

            if equivariant_lode:
                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                dim_lr += lmax_lr
                eij_lr = eij_lr * Yij_lr
                del Yij_lr

        if self.keep_all_layers:
            xis = []

        ### START MESSAGE PASSING ITERATIONS
        for layer in range(self.nlayers):
            ## GATHER SCALAR EDGE FEATURES
            u = [radial_terms]
            if layer > 0:
                ## edge-tensor contraction
                xij2 = (Vi[edge_dst] + (parity * Vi)[edge_src]) * Yij
                for l in range(self.lmax + 1):
                    u.append((xij2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1))
            ur = jnp.concatenate(u, axis=-1)

            ## BUILD RELATIVE POSITIONAL ENCODING
            # One extra output slot when the positional net also produces an
            # edge value term.
            if self.edge_value:
                nout = 2
            else:
                nout = 1
            w = FullyConnectedNet(
                [2 * self.att_dim, nout * self.att_dim],
                activation=self.positional_activation,
                use_bias=self.positional_bias,
                name=f"positional_encoding_{layer}",
            )(ur).reshape(radial_terms.shape[0], nout, self.att_dim)
            if self.edge_value:
                w, vij = jnp.split(w, 2, axis=1)

            # At layer 0 only Yij channels are attended; afterwards Vi channels
            # are attended as well, doubling the number of tensor head groups.
            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)

            ## QUERY, KEY, VALUE
            q = ln_qk(
                nn.Dense(
                    (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                    use_bias=False,
                    name=f"queries_{layer}",
                    kernel_init=kernel_init,
                )(xi).reshape(
                    xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim
                )
            )
            k = nn.Dense(
                (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                use_bias=False,
                name=f"keys_{layer}",
                kernel_init=kernel_init,
            )(xi).reshape(
                xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim
            )
            v = nn.Dense(
                self.scal_heads * self.att_dim,
                use_bias=False,
                name=f"values_{layer}",
                kernel_init=kernel_init,
            )(xi).reshape(xi.shape[0], self.scal_heads, self.att_dim)

            ## ATTENTION COEFFICIENTS
            if self.additive_positional:
                wk = ln_qk(w + k[edge_dst])
            else:
                wk = ln_qk(w * k[edge_dst])

            act = activation_from_str(self.att_activation)
            aij = (
                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
                * switch[:, None]
            )

            # Split attention heads: first tens_heads*(lmax+1) for Yij,
            # then (layer>0) tens_heads*(lmax+1) for Vi, the rest for scalars.
            aijl = (
                aij[:, : self.tens_heads * (self.lmax + 1)]
                .reshape(-1, self.tens_heads, self.lmax + 1)
                .repeat(nrep, axis=-1)
            )
            if layer > 0:
                aijl1 = (
                    aij[:, self.tens_heads * (self.lmax + 1) : self.tens_heads * nls]
                    .reshape(-1, self.tens_heads, self.lmax + 1)
                    .repeat(nrep, axis=-1)
                )
            aij = aij[:, self.tens_heads * nls :, None]

            if self.edge_value:
                ## EDGE VALUES
                if self.additive_positional:
                    vij = vij + v[edge_dst]
                else:
                    vij = vij * v[edge_dst]
            else:
                ## MOVE DEST VALUES TO EDGE
                vij = v[edge_dst]

            ## SCALAR ATTENDED FEATURES
            vai = jax.ops.segment_sum(
                aij * vij,
                edge_src,
                num_segments=xi.shape[0],
            )
            vai = vai.reshape(xi.shape[0], -1)

            ### TENSOR ATTENDED FEATURES
            uij = aijl * Yij
            if layer > 0:
                uij = uij + aijl1 * Vi[edge_dst]
            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])

            ## SELF SCALAR FEATURES
            si = nn.Dense(
                self.att_dim,
                use_bias=False,
                name=f"self_values_{layer}",
                kernel_init=kernel_init,
            )(xi)

            components = [si, vai]

            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
            if self.tens_heads == 1:
                Vi2 = Vi**2
            else:
                Vi2 = Vi * ChannelMixing(
                    self.lmax, self.tens_heads, name=f"extract_mixing_{layer}"
                )(Vi)
            for l in range(self.lmax + 1):
                norm = 1.0 / (2 * l + 1)
                components.append(
                    (Vi2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm
                )

            ### LODE (~ LONG-RANGE ATTENTION)
            # Only applied at the last message-passing layer.
            if do_lode and layer == self.nlayers - 1:
                assert self.lode_channels <= self.tens_heads
                zj = nn.Dense(
                    self.lode_channels * dim_lr,
                    use_bias=False,
                    name=f"lode_values_{layer}",
                    kernel_init=kernel_init,
                )(xi).reshape(xi.shape[0], self.lode_channels, dim_lr)
                if nextra_powers > 0:
                    zj_extra = zj[:, :, :nextra_powers]
                    zj = zj[:, :, nextra_powers:]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra[:, None, :] * zj_extra[edge_dst_lr],
                        edge_src_lr,
                        species.shape[0],
                    ).reshape(species.shape[0], -1)
                    components.append(xi_lr_extra)
                if equivariant_lode:
                    zj = zj.repeat(nrep_lr, axis=-1)
                Vi_lr = jax.ops.segment_sum(
                    eij_lr[:, None, :] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                components.append(Vi_lr[:, :, 0])
                if equivariant_lode:
                    # Multipole contraction of short-range tensors with
                    # long-range tensors.
                    Mi_lr = Vi[:, : self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
                    for l in range(1, lmax_lr + 1):
                        norm = 1.0 / (2 * l + 1)
                        components.append(
                            Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1) * norm
                        )

            ### CONCATENATE UPDATE COMPONENTS
            components = jnp.concatenate(components, axis=-1)
            if self.normalize_components:
                components = _layer_norm(components)

            ### COMPUTE UPDATE
            if self.block_index_key is not None:
                ## MoE neural network from block index
                block_index = inputs[self.block_index_key]
                updi = BlockIndexNet(
                    output_dim=self.dim + self.tens_heads * (self.lmax + 1),
                    hidden_neurons=self.update_hidden,
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                    kernel_init=kernel_init,
                )((species, components, block_index))
            else:
                updi = FullyConnectedNet(
                    [*self.update_hidden, self.dim + self.tens_heads * (self.lmax + 1)],
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                    kernel_init=kernel_init,
                )(components)

            ## UPDATE ATOM FEATURES
            xi = layer_norm(xi + updi[:, : self.dim])
            # Gate tensor features per (head, l) channel with the update net.
            Vi = Vi * (1 + updi[:, self.dim :]).reshape(
                -1, self.tens_heads, self.lmax + 1
            ).repeat(nrep, axis=-1)
            if self.tens_heads > 1:
                Vi = ChannelMixing(
                    self.lmax, self.tens_heads, name=f"update_mixing_{layer}"
                )(Vi)

            if self.keep_all_layers:
                ## STORE ALL LAYERS
                xis.append(xi)

        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output
Range-Separated Transformer with Equivariant Representations
FID : RASTER
Whether to normalize the vector features before computing spherical harmonics.
The activation function to use for the positional embedding network.
Whether to apply the switch function to the radial basis before the edge neural network.
Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding.
Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used.
The key in the output dictionary that corresponds to the embedding.
The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.
The dictionary of parameters for species encoding. See fennol.models.misc.encodings.SpeciesEncoding.
The key in the input dictionary that corresponds to the long-range graph.
The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode).
The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used.
The start of close-range covalent switch (in units of covalent radii).
Whether to return the stacked scalar embeddings from all message-passing layers.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.