fennol.models.physics.repulsion

  1import pathlib
  2import jax
  3import jax.numpy as jnp
  4import flax.linen as nn
  5import numpy as np
  6from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  7from ...utils.atomic_units import au
  8
  9
 10class RepulsionZBL(nn.Module):
 11    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 12
 13    FID: REPULSION_ZBL
 14
 15    ### Reference
 16    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 17
 18    """
 19
 20    _graphs_properties: Dict
 21    graph_key: str = "graph"
 22    """The key for the graph input."""
 23    energy_key: Optional[str] = None
 24    """The key for the output energy."""
 25    trainable: bool = True
 26    """Whether the parameters are trainable."""
 27    _energy_unit: str = "Ha"
 28    """The energy unit of the model. **Automatically set by FENNIX**"""
 29    proportional_regularization: bool = True
 30    d: float = 0.46850
 31    p: float = 0.23
 32    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
 33    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
 34    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)
 35
 36    FID: ClassVar[str] = "REPULSION_ZBL"
 37
 38    @nn.compact
 39    def __call__(self, inputs):
 40        species = inputs["species"]
 41        graph = inputs[self.graph_key]
 42        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 43
 44        training = "training" in inputs.get("flags", {})
 45
 46        rijs = graph["distances"]
 47        d_ = self.d
 48        p_ = self.p
 49        assert len(self.alphas) == 4, "alphas must be a sequence of length 4"
 50        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
 51        assert len(self.cs) == 4, "cs must be a sequence of length 4"
 52        cs_ = np.array(self.cs, dtype=rijs.dtype)
 53        cs_ = 0.5 * cs_ / np.sum(cs_)
 54        if self.trainable:
 55            d = jnp.abs(
 56                self.param(
 57                    "d",
 58                    lambda key, d: jnp.asarray(d/au.ANG, dtype=rijs.dtype),
 59                    d_,
 60                )
 61            )*au.ANG
 62            p = jnp.abs(
 63                self.param(
 64                    "p",
 65                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
 66                    p_,
 67                )
 68            )
 69            cs = 0.5 * jax.nn.softmax(
 70                self.param(
 71                    "cs",
 72                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
 73                    np.array(self.cs_logits, dtype=rijs.dtype),
 74                )
 75            )
 76            alphas = jnp.abs(
 77                self.param(
 78                    "alphas",
 79                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
 80                    alphas_,
 81                )
 82            )
 83
 84            if training:
 85                if self.proportional_regularization:
 86                    reg = jnp.asarray(
 87                        ((1 - alphas / alphas_) ** 2).sum()
 88                        + ((1 - cs / cs_) ** 2).sum()
 89                        + (1 - p / p_) ** 2
 90                        + (1 - d / d_) ** 2
 91                    ).reshape(1)
 92                else:
 93                    reg = jnp.asarray(
 94                        ((alphas_ - alphas) ** 2).sum()
 95                        + ((cs_ - cs) ** 2).sum()
 96                        + (p_ - p) ** 2
 97                        + (d_ - d) ** 2
 98                    ).reshape(1)
 99        else:
100            cs = jnp.asarray(cs_)
101            alphas = jnp.asarray(alphas_)
102            d = d_
103            p = p_
104
105        if "alch_group" in inputs:
106            switch = graph["switch_raw"]
107            lambda_v = inputs["alch_vlambda"]
108            alch_group = inputs["alch_group"]
109            alch_m = inputs.get("alch_m", 2)
110
111            mask = alch_group[edge_src] == alch_group[edge_dst]
112
113            if "alch_softcore_rep" in inputs:
114                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
115                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
116            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
117            switch = jnp.where(
118                mask,
119                switch,
120                (lambda_v**alch_m) * switch,
121            )
122        else:
123            switch = graph["switch"]
124
125        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
126        Zij = Z[edge_src] * Z[edge_dst]
127        Zp = Z**p / d
128        x = rijs * (Zp[edge_src] + Zp[edge_dst])
129        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)
130
131        ereppair = Zij * phi / rijs * switch
132
133        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])
134
135        energy_unit = au.ANG*au.get_multiplier(self._energy_unit)
136        energy_key = self.energy_key if self.energy_key is not None else self.name
137        output = {**inputs, energy_key: erep_atomic * energy_unit}
138        if self.trainable and training:
139            output[energy_key + "_regularization"] = reg
140
141        return output
142
143
class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    trainable: bool = False
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        # Load the tabulated coefficients; each row of the data file is
        # (z1, z2, a1, b1, a2, b2, a3, b3).
        path = str(pathlib.Path(__file__).parent.resolve()) + "/nlh_coeffs.dat"
        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8))
        zmax = int(np.max(DATA_NLH[:, 0]))
        # Dense symmetric lookup table indexed by z1 + stride * z2.
        # BUGFIX: the row stride must be zmax + 1 (not zmax), otherwise
        # distinct pairs alias: (zmax, k) and (0, k + 1) map to the same row.
        # This was latent only because Z = 0 entries are masked out below.
        stride = zmax + 1
        AB = np.zeros((stride**2, 6), dtype=np.float32)
        for i in range(DATA_NLH.shape[0]):
            z1 = int(DATA_NLH[i, 0])
            z2 = int(DATA_NLH[i, 1])
            AB[z1 + stride * z2] = DATA_NLH[i, 2:8]
            AB[z2 + stride * z1] = DATA_NLH[i, 2:8]
        AB = AB.reshape(stride**2, 3, 2)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            # Learnable multiplicative corrections on the tabulated values,
            # kept positive via abs().
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize so the prefactors of each pair still sum to 1.
            CS = CS / jnp.sum(CS, axis=1, keepdims=True)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Gather the per-edge coefficients from the pair table.
        s12 = species[edge_src] + stride * species[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical transformation: only cross-group edges are scaled.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                # Softcore distance for cross-group pairs.
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Screened-Coulomb pair energy; padding atoms (species <= 0) are
        # masked out via Z = 0.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        # The 0.5 assumes each pair appears as two directed edges (matches the
        # 0.5 folded into cs in RepulsionZBL); au.ANG converts the distance
        # unit into atomic units before applying the output energy unit.
        energy_unit = au.get_multiplier(self._energy_unit)
        erep_atomic = (energy_unit * 0.5 * au.ANG) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # Analytical derivative of the pair energy, projected on the edge
            # vectors and accumulated per atom.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.ANG) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output
class RepulsionZBL(flax.linen.module.Module):
 11class RepulsionZBL(nn.Module):
 12    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 13
 14    FID: REPULSION_ZBL
 15
 16    ### Reference
 17    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 18
 19    """
 20
 21    _graphs_properties: Dict
 22    graph_key: str = "graph"
 23    """The key for the graph input."""
 24    energy_key: Optional[str] = None
 25    """The key for the output energy."""
 26    trainable: bool = True
 27    """Whether the parameters are trainable."""
 28    _energy_unit: str = "Ha"
 29    """The energy unit of the model. **Automatically set by FENNIX**"""
 30    proportional_regularization: bool = True
 31    d: float = 0.46850
 32    p: float = 0.23
 33    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
 34    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
 35    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)
 36
 37    FID: ClassVar[str] = "REPULSION_ZBL"
 38
 39    @nn.compact
 40    def __call__(self, inputs):
 41        species = inputs["species"]
 42        graph = inputs[self.graph_key]
 43        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 44
 45        training = "training" in inputs.get("flags", {})
 46
 47        rijs = graph["distances"]
 48        d_ = self.d
 49        p_ = self.p
 50        assert len(self.alphas) == 4, "alphas must be a sequence of length 4"
 51        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
 52        assert len(self.cs) == 4, "cs must be a sequence of length 4"
 53        cs_ = np.array(self.cs, dtype=rijs.dtype)
 54        cs_ = 0.5 * cs_ / np.sum(cs_)
 55        if self.trainable:
 56            d = jnp.abs(
 57                self.param(
 58                    "d",
 59                    lambda key, d: jnp.asarray(d/au.ANG, dtype=rijs.dtype),
 60                    d_,
 61                )
 62            )*au.ANG
 63            p = jnp.abs(
 64                self.param(
 65                    "p",
 66                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
 67                    p_,
 68                )
 69            )
 70            cs = 0.5 * jax.nn.softmax(
 71                self.param(
 72                    "cs",
 73                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
 74                    np.array(self.cs_logits, dtype=rijs.dtype),
 75                )
 76            )
 77            alphas = jnp.abs(
 78                self.param(
 79                    "alphas",
 80                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
 81                    alphas_,
 82                )
 83            )
 84
 85            if training:
 86                if self.proportional_regularization:
 87                    reg = jnp.asarray(
 88                        ((1 - alphas / alphas_) ** 2).sum()
 89                        + ((1 - cs / cs_) ** 2).sum()
 90                        + (1 - p / p_) ** 2
 91                        + (1 - d / d_) ** 2
 92                    ).reshape(1)
 93                else:
 94                    reg = jnp.asarray(
 95                        ((alphas_ - alphas) ** 2).sum()
 96                        + ((cs_ - cs) ** 2).sum()
 97                        + (p_ - p) ** 2
 98                        + (d_ - d) ** 2
 99                    ).reshape(1)
100        else:
101            cs = jnp.asarray(cs_)
102            alphas = jnp.asarray(alphas_)
103            d = d_
104            p = p_
105
106        if "alch_group" in inputs:
107            switch = graph["switch_raw"]
108            lambda_v = inputs["alch_vlambda"]
109            alch_group = inputs["alch_group"]
110            alch_m = inputs.get("alch_m", 2)
111
112            mask = alch_group[edge_src] == alch_group[edge_dst]
113
114            if "alch_softcore_rep" in inputs:
115                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
116                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
117            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
118            switch = jnp.where(
119                mask,
120                switch,
121                (lambda_v**alch_m) * switch,
122            )
123        else:
124            switch = graph["switch"]
125
126        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
127        Zij = Z[edge_src] * Z[edge_dst]
128        Zp = Z**p / d
129        x = rijs * (Zp[edge_src] + Zp[edge_dst])
130        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)
131
132        ereppair = Zij * phi / rijs * switch
133
134        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])
135
136        energy_unit = au.ANG*au.get_multiplier(self._energy_unit)
137        energy_key = self.energy_key if self.energy_key is not None else self.name
138        output = {**inputs, energy_key: erep_atomic * energy_unit}
139        if self.trainable and training:
140            output[energy_key + "_regularization"] = reg
141
142        return output

Repulsion energy based on the Ziegler-Biersack-Littmark potential

FID: REPULSION_ZBL

Reference

J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

RepulsionZBL( _graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, trainable: bool = True, _energy_unit: str = 'Ha', proportional_regularization: bool = True, d: float = 0.4685, p: float = 0.23, alphas: Sequence[float] = (3.1998, 0.94229, 0.4029, 0.20162), cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697), cs_logits: Sequence[float] = (0.113, 1.1445, 0.5459, -1.7514), parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = True

Whether the parameters are trainable.

proportional_regularization: bool = True
d: float = 0.4685
p: float = 0.23
alphas: Sequence[float] = (3.1998, 0.94229, 0.4029, 0.20162)
cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
cs_logits: Sequence[float] = (0.113, 1.1445, 0.5459, -1.7514)
FID: ClassVar[str] = 'REPULSION_ZBL'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class RepulsionNLH(flax.linen.module.Module):
145class RepulsionNLH(nn.Module):
146    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92
147
148    FID: REPULSION_NLH
149
150    ### Reference
151    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
152    https://doi.org/10.1103/PhysRevA.111.032818
153    """
154
155    _graphs_properties: Dict
156    graph_key: str = "graph"
157    """The key for the graph input."""
158    energy_key: Optional[str] = None
159    """The key for the output energy."""
160    _energy_unit: str = "Ha"
161    """The energy unit of the model. **Automatically set by FENNIX**"""
162    trainable: bool = False
163    direct_forces_key: Optional[str] = None
164
165    FID: ClassVar[str] = "REPULSION_NLH"
166
167    @nn.compact
168    def __call__(self, inputs):
169
170        path = str(pathlib.Path(__file__).parent.resolve()) + "/nlh_coeffs.dat"
171        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8))
172        zmax = int(np.max(DATA_NLH[:, 0]))
173        AB = np.zeros(((zmax + 1) ** 2, 6), dtype=np.float32)
174        for i in range(DATA_NLH.shape[0]):
175            z1 = int(DATA_NLH[i, 0])
176            z2 = int(DATA_NLH[i, 1])
177            AB[z1 + zmax * z2] = DATA_NLH[i, 2:8]
178            AB[z2 + zmax * z1] = DATA_NLH[i, 2:8]
179        AB = AB.reshape((zmax + 1) ** 2, 3, 2)
180
181        species = inputs["species"]
182        graph = inputs[self.graph_key]
183        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
184        rijs = graph["distances"]
185
186        # coefficients (a1,a2,a3)
187        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
188        # exponents (b1,b2,b3)
189        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)
190
191        if self.trainable:
192            cfact = jnp.abs(
193                self.param(
194                    "c_fact",
195                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
196                )
197            )
198            CS = CS * cfact[None, :]
199            CS = CS / jnp.sum(CS, axis=1, keepdims=True)
200            alphas_fact = jnp.abs(
201                self.param(
202                    "alpha_fact",
203                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
204                )
205            )
206            ALPHAS = ALPHAS * alphas_fact[None, :]
207
208        s12 = species[edge_src] + zmax * species[edge_dst]
209        cs = CS[s12]
210        alphas = ALPHAS[s12]
211
212        if "alch_group" in inputs:
213            switch = graph["switch_raw"]
214            lambda_v = inputs["alch_vlambda"]
215            alch_group = inputs["alch_group"]
216            alch_m = inputs.get("alch_m", 2)
217
218            mask = alch_group[edge_src] == alch_group[edge_dst]
219            if "alch_softcore_rep" in inputs:
220                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
221                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
222            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
223            switch = jnp.where(
224                mask,
225                switch,
226                (lambda_v**alch_m) * switch,
227            )
228            # alphas = jnp.where(
229            #     mask[:,None],
230            #     alphas,
231            #     lambda_v * alphas ,
232            # )
233        else:
234            switch = graph["switch"]
235
236        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
237        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
238        Zij = Z[edge_src] * Z[edge_dst] * switch
239
240        ereppair = Zij * phi / rijs
241
242        energy_unit = au.get_multiplier(self._energy_unit)
243        erep_atomic = (energy_unit * 0.5 * au.ANG) * jax.ops.segment_sum(
244            ereppair, edge_src, species.shape[0]
245        )
246
247        energy_key = self.energy_key if self.energy_key is not None else self.name
248        output = {**inputs, energy_key: erep_atomic}
249
250        if self.direct_forces_key is not None:
251            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
252            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
253            dedij = (dedr / rijs)[:, None] * graph["vec"]
254            fi = (energy_unit * au.ANG) * jax.ops.segment_sum(
255                dedij, edge_src, species.shape[0]
256            )
257            output[self.direct_forces_key] = fi
258
259        return output

NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

FID: REPULSION_NLH

Reference

K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818 https://doi.org/10.1103/PhysRevA.111.032818

RepulsionNLH( _graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, _energy_unit: str = 'Ha', trainable: bool = False, direct_forces_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = False
direct_forces_key: Optional[str] = None
FID: ClassVar[str] = 'REPULSION_NLH'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None