fennol.models.physics.dispersion
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils.atomic_units import au
from ...utils.periodic_table import (
    POLARIZABILITIES,
    C6_FREE,
)
import pathlib
import pickle


class VdwOQDO(nn.Module):
    """Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """

    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between AIM volume and free-atom volume.
         If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs):
        """Add the OQDO dispersion (and optionally exchange) energies to `inputs`.

        Reads `inputs["species"]` and, from `inputs[self.graph_key]`, the entries
        `edge_src`, `edge_dst`, `distances` and `switch`. When `ratiovol_key` is
        set, `inputs[ratiovol_key]` rescales the free-atom C6 and polarizability.

        Returns a new dict holding every entry of `inputs` plus, under
        `energy_key` (or the module name), the energy summed per source atom.
        When `include_exchange` is True, the keys `<key>_dispersion` and
        `<key>_exchange` are added as well and `<key>` holds their sum.
        """
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        n_atoms = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # NOTE(review): dividing by au.ANG presumably converts to atomic units - confirm.
        rij = graph["distances"] / au.ANG
        switch = graph["switch"]

        # free-atom reference values, looked up by species index
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # small offset keeps the ratio strictly non-zero
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # higher-order coefficients derived from C6 and mu*omega
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # each damping factor removes one more term of the truncated
            # exponential series in z = mu*omega*r^2/2
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # the factor 0.5 splits each pair term; presumably every pair appears
        # in both directions in the edge list - TODO confirm
        edisp = (-0.5 * energy_unit) * jax.ops.segment_sum(
            epair * switch, edge_src, n_atoms
        )

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        q2 = alphaij * muw * w
        if self.damped:
            # damped case: damping factors and their radial derivatives,
            # all evaluated at the equilibrium distance Re
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
        else:
            # undamped case: closed-form prefactor
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
        ez = jnp.exp(-0.5 * muw * rij**2)

        exij = A * q2 * ez / rij
        ex = (0.5 * energy_unit) * jax.ops.segment_sum(
            exij * switch, edge_src, n_atoms
        )

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }


# Cache for the unpickled D3 reference tables (filled on first use).
_D3_REFERENCE_DATA = None


def _load_d3_reference_data():
    """Return the D3 reference tables, reading `ref_data_d3.pkl` only once."""
    global _D3_REFERENCE_DATA
    if _D3_REFERENCE_DATA is None:
        path = pathlib.Path(__file__).parent.resolve() / "ref_data_d3.pkl"
        # NOTE: pickle executes arbitrary code on load; this is only acceptable
        # because the file sits next to this module. Never load untrusted data here.
        with open(path, "rb") as f:
            _D3_REFERENCE_DATA = pickle.load(f)
    return _D3_REFERENCE_DATA


class DispersionD3(nn.Module):
    """Pairwise D3-type dispersion energy with C6 and C8 terms.

    C6 coefficients are interpolated between per-element references using
    coordination-number-dependent weights, and both terms use the rational
    damping `1 / (r**n + r0**n)` with `r0 = a1 * sqrt(C8 / C6) + a2`.

    FID : DISPERSION_D3

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """ The key for the graph input."""
    output_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    s6: float = 1.0
    """ Scaling factor of the C6 term (initial value when `trainable`)."""
    s8: float = 1.0
    """ Scaling factor of the C8 term (initial value when `trainable`)."""
    a1: float = 0.4
    """ Multiplicative damping-radius parameter (initial value when `trainable`)."""
    a2: float = 5.0
    """ Additive damping-radius parameter (initial value when `trainable`)."""
    trainable: bool = False
    """ Whether s6, s8, a1 and a2 are registered as trainable parameters."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    @nn.compact
    def __call__(self, inputs):
        """Add the D3 dispersion energy to `inputs`.

        Reads `inputs["species"]` and, from `inputs[self.graph_key]`, the entries
        `edge_src`, `edge_dst`, `distances` and `switch`. Returns a new dict with
        every entry of `inputs` plus the energy summed per source atom, stored
        under `output_key` (or the module name).
        """
        DATA_D3 = _load_d3_reference_data()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]
        n_atoms = species.shape[0]

        # lower clip guards the division rcij / rij below against zero distances
        rij = jnp.clip(graph["distances"] / au.ANG, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(DATA_D3["COV_D3"])[species]
        r4r2 = jnp.array(DATA_D3["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, n_atoms)

        ## WEIGHTS
        ref_cn_table = np.asarray(DATA_D3["REF_CN"])
        refcn = jnp.array(ref_cn_table)[species]
        # negative reference CNs are excluded (presumably unused padding slots)
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        # fallback: one-hot on the reference with the largest coordination number
        exweight = np.zeros_like(ref_cn_table)
        exweight[np.arange(ref_cn_table.shape[0]), np.argmax(ref_cn_table, axis=1)] = 1.0
        exweight = jnp.array(exweight)[species]

        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        REF_C6 = DATA_D3["REF_C6"]
        nz = REF_C6.shape[0]
        nref = REF_C6.shape[-1]
        # flatten the two element axes so each pair is addressed by one index
        REF_C6 = jnp.array(REF_C6.reshape((nz * nz, nref, nref)))
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]
        c8 = c6 * qq

        if self.trainable:

            def positive_param(name, init):
                # abs() keeps the effective value non-negative during training
                return jnp.abs(
                    self.param(name, lambda key: jnp.array(init, dtype=c6.dtype))
                )

            s6 = positive_param("s6", self.s6)
            s8 = positive_param("s8", self.s8)
            a1 = positive_param("a1", self.a1)
            a2 = positive_param("a2", self.a2)
        else:
            s6 = self.s6
            s8 = self.s8
            a1 = self.a1
            a2 = self.a2

        r0 = a1 * jnp.sqrt(qq) + a2

        t6 = s6 / (rij**6 + r0**6)
        t8 = s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        energy = (-0.5 * energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, n_atoms
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}
class VdwOQDO(nn.Module):
    """Pairwise dispersion and exchange energies from the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """

    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between AIM volume and free-atom volume.
         If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs):
        """Return `inputs` extended with the OQDO energy terms.

        The species come from `inputs["species"]`; the edge lists, distances and
        switching values come from `inputs[self.graph_key]`. The result always
        holds the energy under `energy_key` (module name if None). With
        `include_exchange` it also holds `<key>_dispersion` and `<key>_exchange`,
        and `<key>` is their sum.
        """
        unit_scale = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]
        # NOTE(review): dividing by au.ANG presumably converts to atomic units - confirm.
        rij = graph["distances"] / au.ANG
        switch = graph["switch"]
        n_atoms = species.shape[0]

        # Per-atom free-atom parameters, optionally rescaled by the volume ratio.
        c6_atom = jnp.asarray(C6_FREE)[species]
        alpha_atom = jnp.asarray(POLARIZABILITIES)[species]
        if self.ratiovol_key is not None:
            vol_ratio = inputs[self.ratiovol_key] + 1.0e-6
            if vol_ratio.shape[-1] == 1:
                vol_ratio = jnp.squeeze(vol_ratio, axis=-1)
            c6_atom = c6_atom * vol_ratio**2
            alpha_atom = alpha_atom * vol_ratio

        # Per-edge quantities from the combination rules.
        c6_i = c6_atom[src]
        c6_j = c6_atom[dst]
        alpha_i = alpha_atom[src]
        alpha_j = alpha_atom[dst]
        alpha_pair = 0.5 * (alpha_i + alpha_j)
        c6_pair = (
            2 * alpha_i * alpha_j * c6_i * c6_j
            / (c6_i * alpha_j**2 + c6_j * alpha_i**2)
        )

        # Equilibrium distance and its powers.
        r_eq = (alpha_pair * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        r_eq2 = r_eq**2
        r_eq4 = r_eq**4

        # mu*omega: rational fit to the largest root of eq (S33) of the reference;
        # the damped and undamped models use different fit coefficients.
        if self.damped:
            mu_omega = (
                4.83053463e-01
                - 3.76191669e-02 * r_eq
                + 1.27066988e-03 * r_eq2
                - 7.21940151e-07 * r_eq4
            ) / (3.84212120e-02 - 3.16915319e-02 * r_eq + 2.37410890e-02 * r_eq2)
        else:
            mu_omega = (
                3.66316787e01
                - 5.79579187 * r_eq
                + 3.02674813e-01 * r_eq2
                - 3.65461255e-04 * r_eq4
            ) / (-1.46169102e01 + 7.32461225 * r_eq)

        c8_pair = 5 * c6_pair / mu_omega
        c10_pair = 245 * c6_pair / (8 * mu_omega**2)

        # Dispersion pair energy (sign and unit applied at the reduction below).
        if self.damped:
            z = 0.5 * mu_omega * rij**2
            exp_mz = jnp.exp(-z)
            damp6 = 1.0 - exp_mz * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            damp8 = damp6 - (1.0 / 24.0) * exp_mz * z**4
            damp10 = damp8 - (1.0 / 120.0) * exp_mz * z**5
            pair_disp = (
                damp6 * c6_pair / rij**6
                + damp8 * c8_pair / rij**8
                + damp10 * c10_pair / rij**10
            )
        else:
            pair_disp = c6_pair / rij**6 + c8_pair / rij**8 + c10_pair / rij**10

        e_disp = (-0.5 * unit_scale) * jax.ops.segment_sum(
            pair_disp * switch, src, n_atoms
        )

        out_key = self.energy_key if self.energy_key is not None else self.name

        if not self.include_exchange:
            return {**inputs, out_key: e_disp}

        # Exchange term.
        omega = 4 * c6_pair / (3 * alpha_pair**2)
        # q = (alpha_pair * mu*omega**2)**0.5
        q_sq = alpha_pair * mu_omega * omega

        if not self.damped:
            # undamped case: closed-form prefactor
            prefactor = (
                0.5
                + 2 * c8_pair / (3 * c6_pair * r_eq2)
                + 5 * c10_pair / (6 * c6_pair * r_eq4)
            )
        else:
            # damped case: damping factors and radial derivatives at r_eq
            z_eq = 0.5 * mu_omega * r_eq2
            exp_mz_eq = jnp.exp(-z_eq)
            mw_r = mu_omega * r_eq

            tail6 = exp_mz_eq * (
                1.0 + z_eq + 0.5 * z_eq**2 + (1.0 / 6.0) * z_eq**3
            )
            damp6_eq = 1.0 - tail6
            ddamp6_eq = mw_r * tail6 - exp_mz_eq * (
                mw_r + 0.5 * r_eq * mw_r**2 + (1.0 / 8.0) * r_eq2 * mw_r**3
            )

            tail8 = (1.0 / 24.0) * exp_mz_eq * z_eq**4
            damp8_eq = damp6_eq - tail8
            ddamp8_eq = (
                ddamp6_eq
                + mw_r * tail8
                - (1.0 / 48.0) * exp_mz_eq * r_eq2 * r_eq * mw_r**4
            )

            tail10 = (1.0 / 120.0) * exp_mz_eq * z_eq**5
            damp10_eq = damp8_eq - tail10
            ddamp10_eq = (
                ddamp8_eq
                + mw_r * tail10
                - (1.0 / 384.0) * exp_mz_eq * r_eq2 * r_eq2 * mw_r**5
            )

            denom = 2 * c6_pair * r_eq2 * (6 * damp6_eq - r_eq * ddamp6_eq)
            prefactor = (
                0.5
                + c8_pair * (8 * damp8_eq - r_eq * ddamp8_eq) / denom
                + c10_pair * (10 * damp10_eq - r_eq * ddamp10_eq) / (denom * r_eq2)
            )

        gaussian = jnp.exp(-0.5 * mu_omega * rij**2)
        pair_exch = prefactor * q_sq * gaussian / rij
        e_exch = (0.5 * unit_scale) * jax.ops.segment_sum(
            pair_exch * switch, src, n_atoms
        )

        return {
            **inputs,
            out_key + "_dispersion": e_disp,
            out_key + "_exchange": e_exch,
            out_key: e_disp + e_exch,
        }
Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
FID : VDW_OQDO
Reference
A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators, J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
The key for the ratio between AIM volume and free-atom volume. If None, the volume ratio is assumed to be 1.0.
The key for the output energy. If None, the name of the module is used.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links, which can lead to accidental OOMs in eager mode due to slow garbage collection, as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
# Cache for the unpickled D3 reference tables (filled on first use).
_D3_REFERENCE_DATA = None


def _load_d3_reference_data():
    """Return the D3 reference tables, reading `ref_data_d3.pkl` only once."""
    global _D3_REFERENCE_DATA
    if _D3_REFERENCE_DATA is None:
        path = pathlib.Path(__file__).parent.resolve() / "ref_data_d3.pkl"
        # NOTE: pickle executes arbitrary code on load; this is only acceptable
        # because the file sits next to this module. Never load untrusted data here.
        with open(path, "rb") as f:
            _D3_REFERENCE_DATA = pickle.load(f)
    return _D3_REFERENCE_DATA


class DispersionD3(nn.Module):
    """Pairwise D3-type dispersion energy with C6 and C8 terms.

    C6 coefficients are interpolated between per-element references using
    coordination-number-dependent weights, and both terms use the rational
    damping `1 / (r**n + r0**n)` with `r0 = a1 * sqrt(C8 / C6) + a2`.

    FID : DISPERSION_D3

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """ The key for the graph input."""
    output_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    s6: float = 1.0
    """ Scaling factor of the C6 term (initial value when `trainable`)."""
    s8: float = 1.0
    """ Scaling factor of the C8 term (initial value when `trainable`)."""
    a1: float = 0.4
    """ Multiplicative damping-radius parameter (initial value when `trainable`)."""
    a2: float = 5.0
    """ Additive damping-radius parameter (initial value when `trainable`)."""
    trainable: bool = False
    """ Whether s6, s8, a1 and a2 are registered as trainable parameters."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    @nn.compact
    def __call__(self, inputs):
        """Add the D3 dispersion energy to `inputs`.

        Reads `inputs["species"]` and, from `inputs[self.graph_key]`, the entries
        `edge_src`, `edge_dst`, `distances` and `switch`. Returns a new dict with
        every entry of `inputs` plus the energy summed per source atom, stored
        under `output_key` (or the module name).
        """
        DATA_D3 = _load_d3_reference_data()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]
        n_atoms = species.shape[0]

        # lower clip guards the division rcij / rij below against zero distances
        rij = jnp.clip(graph["distances"] / au.ANG, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(DATA_D3["COV_D3"])[species]
        r4r2 = jnp.array(DATA_D3["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, n_atoms)

        ## WEIGHTS
        ref_cn_table = np.asarray(DATA_D3["REF_CN"])
        refcn = jnp.array(ref_cn_table)[species]
        # negative reference CNs are excluded (presumably unused padding slots)
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        # fallback: one-hot on the reference with the largest coordination number
        exweight = np.zeros_like(ref_cn_table)
        exweight[np.arange(ref_cn_table.shape[0]), np.argmax(ref_cn_table, axis=1)] = 1.0
        exweight = jnp.array(exweight)[species]

        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        REF_C6 = DATA_D3["REF_C6"]
        nz = REF_C6.shape[0]
        nref = REF_C6.shape[-1]
        # flatten the two element axes so each pair is addressed by one index
        REF_C6 = jnp.array(REF_C6.reshape((nz * nz, nref, nref)))
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]
        c8 = c6 * qq

        if self.trainable:

            def positive_param(name, init):
                # abs() keeps the effective value non-negative during training
                return jnp.abs(
                    self.param(name, lambda key: jnp.array(init, dtype=c6.dtype))
                )

            s6 = positive_param("s6", self.s6)
            s8 = positive_param("s8", self.s8)
            a1 = positive_param("a1", self.a1)
            a2 = positive_param("a2", self.a2)
        else:
            s6 = self.s6
            s8 = self.s8
            a1 = self.a1
            a2 = self.a2

        r0 = a1 * jnp.sqrt(qq) + a2

        t6 = s6 / (rij**6 + r0**6)
        t8 = s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        energy = (-0.5 * energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, n_atoms
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}
FID : DISPERSION_D3
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links, which can lead to accidental OOMs in eager mode due to slow garbage collection, as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.