fennol.models.physics.repulsion
import pathlib
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils.atomic_units import au


class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    The pair energy is ``Z_i * Z_j / r_ij * phi(x_ij) * switch_ij`` with the
    screening function ``phi(x) = sum_k cs_k * exp(-alphas_k * x)`` evaluated
    at the reduced distance ``x_ij = r_ij * (Z_i**p + Z_j**p) / d``. Pair
    energies are summed onto the ``edge_src`` atom of each edge.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    proportional_regularization: bool = True
    """If True, regularize relative deviations from the reference parameters; otherwise absolute deviations."""
    d: float = 0.46850
    """Screening-length constant of the reduced distance (same length unit as the graph distances)."""
    p: float = 0.23
    """Exponent applied to the atomic numbers in the reduced distance."""
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    """Reference exponents of the four screening-function terms."""
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    """Reference coefficients of the four screening-function terms (rescaled internally to sum to 0.5)."""
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)
    """Initial logits of the trainable coefficients (``cs = 0.5 * softmax(cs_logits)``); the defaults approximately reproduce the default ``cs``."""

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        # Explicit checks instead of ``assert`` so they also run under ``python -O``.
        if len(self.alphas) != 4:
            raise ValueError("alphas must be a sequence of length 4")
        if len(self.cs) != 4:
            raise ValueError("cs must be a sequence of length 4")
        if self.trainable and len(self.cs_logits) != 4:
            # A length-1 sequence would otherwise broadcast silently.
            raise ValueError("cs_logits must be a sequence of length 4")

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        # The regularization output is only produced when the "training" flag is
        # set (and the module is trainable).
        training = "training" in inputs.get("flags", {})

        rijs = graph["distances"]
        # Reference (untrained) parameter values.
        d_ = self.d
        p_ = self.p
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        # Normalize the reference coefficients to sum to 0.5; the trainable
        # branch below also yields coefficients summing to 0.5 via softmax.
        cs_ = 0.5 * cs_ / np.sum(cs_)
        if self.trainable:
            # ``d`` is stored divided by au.ANG and rescaled afterwards; abs()
            # keeps the effective parameters non-negative.
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d / au.ANG, dtype=rijs.dtype),
                    d_,
                )
            ) * au.ANG
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                # Penalize deviations of the trained parameters from the references.
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            # Alchemical mode: edges whose endpoints belong to different
            # ``alch_group`` values are softened and scaled by ``alch_vlambda``.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            if "alch_softcore_rep" in inputs:
                # Soft-core distance sqrt(r^2 + s^2 (1 - lambda)) for inter-group edges.
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            # Smooth cosine ramp of lambda, raised to the power alch_m.
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species (presumably padding) get zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src] * Z[edge_dst]
        Zp = Z**p / d
        # Reduced distance x = r * (Z_i^p + Z_j^p) / d.
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.ANG * au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output


class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    The pair energy is ``Z_i * Z_j / r_ij * phi_ij(r_ij) * switch_ij`` with the
    pair-specific screening function ``phi_ij(r) = sum_k a_k * exp(-b_k * r)``.
    The ``(a_k, b_k)`` are read from ``nlh_coeffs.dat`` located next to this
    module. Each atom receives half of the summed energies of its outgoing edges.

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    trainable: bool = False
    """Whether to learn global (pair-independent) scale factors for the coefficients and exponents."""
    direct_forces_key: Optional[str] = None
    """If set, analytical forces are also written to the output under this key."""

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        # Pair table file. Columns: Z1, Z2, a1, b1, a2, b2, a3, b3.
        path = pathlib.Path(__file__).parent.resolve() / "nlh_coeffs.dat"
        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8))
        # Largest atomic number listed in either Z column.
        zmax = int(np.max(DATA_NLH[:, :2]))
        # Flattened (zmax+1) x (zmax+1) table: pair (z1, z2) lives at z1 + nz*z2.
        # A stride of zmax+1 (not zmax) gives Z=0 its own row/column, which stays
        # zero unless the data file lists Z=0, instead of aliasing a real pair.
        nz = zmax + 1
        AB = np.zeros((nz * nz, 6), dtype=np.float32)
        z1s = DATA_NLH[:, 0].astype(int)
        z2s = DATA_NLH[:, 1].astype(int)
        for z1, z2, coeffs in zip(z1s, z2s, DATA_NLH[:, 2:8]):
            # The table is filled symmetrically in (z1, z2).
            AB[z1 + nz * z2] = coeffs
            AB[z2 + nz * z1] = coeffs
        AB = AB.reshape(nz * nz, 3, 2)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            # One non-negative scale factor per term, shared by all pairs.
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize each pair's coefficients to sum to one. Rows of pairs
            # absent from the file are all zero; guard them against 0/0 = NaN.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm != 0, norm, 1.0)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Map non-positive species (presumably padding) to the Z=0 slots instead
        # of letting them alias another pair or wrap via negative indexing.
        # NOTE(review): species > zmax are not handled here -- confirm inputs.
        sp = jnp.where(species > 0, species, 0)
        s12 = sp[edge_src] + nz * sp[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical mode: edges whose endpoints belong to different
            # ``alch_group`` values are softened and scaled by ``alch_vlambda``.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                # Soft-core distance sqrt(r^2 + s^2 (1 - lambda)) for inter-group edges.
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            # Smooth cosine ramp of lambda, raised to the power alch_m.
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        # Z_i Z_j with the switching function folded in (reused by the forces).
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        # The 0.5 splits each pair energy between the two edge directions
        # (presumably every pair is listed as both i->j and j->i).
        erep_atomic = (energy_unit * 0.5 * au.ANG) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # d(phi)/dr and d(Zij * phi / r)/dr for each edge.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            # NOTE(review): the sign assumes graph["vec"] points from edge_src to
            # edge_dst -- confirm. ``switch`` is treated as constant here (its
            # r-derivative is not included), so these forces match -dE/dx only
            # where the switching function is flat.
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.ANG) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output
class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    The pair energy is ``Z_i * Z_j / r_ij * phi(x_ij) * switch_ij`` with the
    screening function ``phi(x) = sum_k cs_k * exp(-alphas_k * x)`` evaluated
    at the reduced distance ``x_ij = r_ij * (Z_i**p + Z_j**p) / d``. Pair
    energies are summed onto the ``edge_src`` atom of each edge.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    proportional_regularization: bool = True
    """If True, regularize relative deviations from the reference parameters; otherwise absolute deviations."""
    d: float = 0.46850
    """Screening-length constant of the reduced distance (same length unit as the graph distances)."""
    p: float = 0.23
    """Exponent applied to the atomic numbers in the reduced distance."""
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    """Reference exponents of the four screening-function terms."""
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    """Reference coefficients of the four screening-function terms (rescaled internally to sum to 0.5)."""
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)
    """Initial logits of the trainable coefficients (``cs = 0.5 * softmax(cs_logits)``); the defaults approximately reproduce the default ``cs``."""

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        # Explicit checks instead of ``assert`` so they also run under ``python -O``.
        if len(self.alphas) != 4:
            raise ValueError("alphas must be a sequence of length 4")
        if len(self.cs) != 4:
            raise ValueError("cs must be a sequence of length 4")
        if self.trainable and len(self.cs_logits) != 4:
            # A length-1 sequence would otherwise broadcast silently.
            raise ValueError("cs_logits must be a sequence of length 4")

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        # The regularization output is only produced when the "training" flag is
        # set (and the module is trainable).
        training = "training" in inputs.get("flags", {})

        rijs = graph["distances"]
        # Reference (untrained) parameter values.
        d_ = self.d
        p_ = self.p
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        # Normalize the reference coefficients to sum to 0.5; the trainable
        # branch below also yields coefficients summing to 0.5 via softmax.
        cs_ = 0.5 * cs_ / np.sum(cs_)
        if self.trainable:
            # ``d`` is stored divided by au.ANG and rescaled afterwards; abs()
            # keeps the effective parameters non-negative.
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d / au.ANG, dtype=rijs.dtype),
                    d_,
                )
            ) * au.ANG
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                # Penalize deviations of the trained parameters from the references.
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            # Alchemical mode: edges whose endpoints belong to different
            # ``alch_group`` values are softened and scaled by ``alch_vlambda``.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            if "alch_softcore_rep" in inputs:
                # Soft-core distance sqrt(r^2 + s^2 (1 - lambda)) for inter-group edges.
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            # Smooth cosine ramp of lambda, raised to the power alch_m.
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species (presumably padding) get zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src] * Z[edge_dst]
        Zp = Z**p / d
        # Reduced distance x = r * (Z_i^p + Z_j^p) / d.
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.ANG * au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output
Repulsion energy based on the Ziegler-Biersack-Littmark potential
FID: REPULSION_ZBL
Reference
J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    The pair energy is ``Z_i * Z_j / r_ij * phi_ij(r_ij) * switch_ij`` with the
    pair-specific screening function ``phi_ij(r) = sum_k a_k * exp(-b_k * r)``.
    The ``(a_k, b_k)`` are read from ``nlh_coeffs.dat`` located next to this
    module. Each atom receives half of the summed energies of its outgoing edges.

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    trainable: bool = False
    """Whether to learn global (pair-independent) scale factors for the coefficients and exponents."""
    direct_forces_key: Optional[str] = None
    """If set, analytical forces are also written to the output under this key."""

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        # Pair table file. Columns: Z1, Z2, a1, b1, a2, b2, a3, b3.
        path = pathlib.Path(__file__).parent.resolve() / "nlh_coeffs.dat"
        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8))
        # Largest atomic number listed in either Z column.
        zmax = int(np.max(DATA_NLH[:, :2]))
        # Flattened (zmax+1) x (zmax+1) table: pair (z1, z2) lives at z1 + nz*z2.
        # A stride of zmax+1 (not zmax) gives Z=0 its own row/column, which stays
        # zero unless the data file lists Z=0, instead of aliasing a real pair.
        nz = zmax + 1
        AB = np.zeros((nz * nz, 6), dtype=np.float32)
        z1s = DATA_NLH[:, 0].astype(int)
        z2s = DATA_NLH[:, 1].astype(int)
        for z1, z2, coeffs in zip(z1s, z2s, DATA_NLH[:, 2:8]):
            # The table is filled symmetrically in (z1, z2).
            AB[z1 + nz * z2] = coeffs
            AB[z2 + nz * z1] = coeffs
        AB = AB.reshape(nz * nz, 3, 2)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            # One non-negative scale factor per term, shared by all pairs.
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize each pair's coefficients to sum to one. Rows of pairs
            # absent from the file are all zero; guard them against 0/0 = NaN.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm != 0, norm, 1.0)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Map non-positive species (presumably padding) to the Z=0 slots instead
        # of letting them alias another pair or wrap via negative indexing.
        # NOTE(review): species > zmax are not handled here -- confirm inputs.
        sp = jnp.where(species > 0, species, 0)
        s12 = sp[edge_src] + nz * sp[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical mode: edges whose endpoints belong to different
            # ``alch_group`` values are softened and scaled by ``alch_vlambda``.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                # Soft-core distance sqrt(r^2 + s^2 (1 - lambda)) for inter-group edges.
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            # Smooth cosine ramp of lambda, raised to the power alch_m.
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        # Z_i Z_j with the switching function folded in (reused by the forces).
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        # The 0.5 splits each pair energy between the two edge directions
        # (presumably every pair is listed as both i->j and j->i).
        erep_atomic = (energy_unit * 0.5 * au.ANG) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # d(phi)/dr and d(Zij * phi / r)/dr for each edge.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            # NOTE(review): the sign assumes graph["vec"] points from edge_src to
            # edge_dst -- confirm. ``switch`` is treated as constant here (its
            # r-derivative is not included), so these forces match -dE/dx only
            # where the switching function is flat.
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.ANG) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output
NLH pairwise repulsive potential with pair-specific coefficients up to Z=92
FID: REPULSION_NLH
Reference
K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818 https://doi.org/10.1103/PhysRevA.111.032818
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.