fennol.models.physics.dispersion

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import numpy as np
  5from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  6from ...utils.atomic_units import au
  7from ...utils.periodic_table import (
  8    POLARIZABILITIES,
  9    C6_FREE,
 10)
 11import pathlib
 12import pickle
 13
class VdwOQDO(nn.Module):
    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    Per-atom C6 coefficients and polarizabilities are taken from the free-atom
    tables (``C6_FREE``, ``POLARIZABILITIES``) and, when ``ratiovol_key`` is
    given, rescaled by the volume ratio v (C6 * v**2, alpha * v). Pair
    quantities are built with combination rules; C8 and C10 are derived from
    C6 and the oscillator parameter ``muw``. Each edge contribution is
    multiplied by ``graph["switch"]`` and half of it is accumulated on the
    edge's source atom (presumably because the graph lists each pair in both
    directions -- confirm against the graph builder).

    Outputs added to ``inputs`` (per-atom arrays):
      - ``<key>``: dispersion (+ exchange if ``include_exchange``) energy,
      - ``<key>_dispersion`` / ``<key>_exchange``: only if ``include_exchange``,
    where ``<key>`` is ``energy_key`` or, if None, the module name.

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """
    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between AIM volume and free-atom volume. 
         If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str]  = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs):
        # Multiplier applied to the final energies; the formulas below are
        # evaluated in atomic units (presumably Hartree -> self._energy_unit).
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # Distances divided by au.ANG, i.e. converted to atomic units
        # (assumes graph distances are in Angstrom -- confirm).
        rij = graph["distances"] / au.ANG
        switch = graph["switch"]

        # Free-atom reference values gathered per atom.
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # Small offset, presumably to keep the ratio strictly positive.
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            # Accept either (n_atoms, 1) or (n_atoms,) shaped inputs.
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            # Volume scaling: C6 ~ v**2, alpha ~ v.
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        # alpha_ij: arithmetic mean; C6_ij: 2 a_i a_j C6_i C6_j / (C6_i a_j^2 + C6_j a_i^2)
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        # Re = (128 * alpha_ij / FSC**(4/3))**(1/7), with au.FSC presumably the
        # fine-structure constant (relation from the reference paper).
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        # (separate rational fits are used for the damped and undamped models)
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # Higher-order dispersion coefficients from C6 and muw.
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # Damping functions f_2n = 1 - exp(-z) * sum_{k=0}^{n} z^k / k!
            # with z = muw * r^2 / 2 (sum up to z^3, z^4, z^5 for f6, f8, f10).
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # Accumulate (switched) pair terms on source atoms; the 0.5 avoids
        # double counting when each pair appears as both i->j and j->i.
        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        # muw = mu*w, so alphaij * muw * w == alphaij * mu * w**2 == q**2
        q2 = alphaij * muw * w
        # Prefactor A of the exchange term. Presumably fixed by requiring the
        # total pair potential to have its minimum at Re (see reference); the
        # damped expression reduces to the undamped one when f=1, df=0.
        if self.damped:
            # Damping functions evaluated at R = Re (ze = muw * Re^2 / 2) ...
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re
            # ... and their analytic derivatives d f_2n / dR at R = Re.
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
        else:
            # undamped case
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
            # In the damped branch `ez` = exp(-muw r^2 / 2) was already computed
            # for the dispersion damping; compute it here for the undamped case.
            ez = jnp.exp(-0.5 * muw * rij**2)

        # Pair exchange: A * q^2 * exp(-muw r^2 / 2) / r (repulsive, positive).
        exij = A * q2 * ez / rij
        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }
152
153
class DispersionD3(nn.Module):
    """Grimme D3 dispersion energy with Becke-Johnson (rational) damping.

    FID : DISPERSION_D3

    Pairwise C6 coefficients are interpolated from the tabulated reference
    values in ``ref_data_d3.pkl`` with Gaussian weights on the atomic
    coordination numbers, and ``C8 = 3 * C6 * R4R2_i * R4R2_j``. Each edge
    contributes ``-(s6*C6/(r^6 + r0^6) + s8*C8/(r^8 + r0^8))`` with
    ``r0 = a1*sqrt(C8/C6) + a2``, multiplied by ``graph["switch"]``; half of
    it is accumulated on the edge's source atom (assumes each pair appears in
    both directions in the graph -- confirm against the graph builder).

    Returns ``inputs`` plus the per-atom energies under ``output_key`` (the
    module name when ``output_key`` is None). When ``trainable`` is True,
    ``s6``, ``s8``, ``a1`` and ``a2`` become parameters initialized from the
    attributes and used through their absolute value.
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    output_key: Optional[str] = None
    s6: float = 1.0
    s8: float = 1.0
    a1: float = 0.4
    a2: float = 5.0
    trainable: bool = False
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    # Process-wide cache of the D3 reference tables: the pickle file is read
    # once instead of on every call/trace of the module.
    _D3_DATA_CACHE: ClassVar[Dict[str, Any]] = {}

    @staticmethod
    def _load_d3_data() -> Dict[str, Any]:
        """Return the D3 reference tables, loading ``ref_data_d3.pkl`` on first use."""
        cache = DispersionD3._D3_DATA_CACHE
        if "data" not in cache:
            path = pathlib.Path(__file__).parent.resolve() / "ref_data_d3.pkl"
            with open(path, "rb") as f:
                # NOTE: unpickling is only acceptable because this file ships
                # with the package; never point this at user-supplied data.
                cache["data"] = pickle.load(f)
        return cache["data"]

    @nn.compact
    def __call__(self, inputs):
        DATA_D3 = self._load_d3_data()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]

        # Distances in atomic units (bohr); clipped to avoid division by zero.
        rij = jnp.clip(graph["distances"] / au.ANG, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(DATA_D3["COV_D3"])[species]
        r4r2 = jnp.array(DATA_D3["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        # Smooth neighbor count: sigmoid(KCN * (Rcov_ij / r_ij - 1)) summed per atom.
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, species.shape[0])

        ## WEIGHTS
        # Gaussian weights exp(-KW * (CN_ref - CN)^2) over the valid reference
        # entries (negative REF_CN entries mark padding), normalized per atom.
        refcn = jnp.array(DATA_D3["REF_CN"])[species]
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        # Fallback: put all weight on the reference with the largest CN for
        # atoms whose Gaussian weights all underflow.
        ref_cn_table = np.asarray(DATA_D3["REF_CN"])
        exweight = np.zeros_like(ref_cn_table)
        exweight[
            np.arange(ref_cn_table.shape[0]), np.argmax(ref_cn_table, axis=1)
        ] = 1.0
        exweight = jnp.array(exweight)[species]

        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        # REF_C6 has shape (nz, nz, nref, nref); flatten the species pair axes
        # so each edge can gather its (nref, nref) table with a single index.
        REF_C6 = DATA_D3["REF_C6"]
        nz = REF_C6.shape[0]
        nref = REF_C6.shape[-1]
        REF_C6 = jnp.array(REF_C6.reshape((nz * nz, nref, nref)))
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION

        # qq = C8 / C6
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]
        c8 = c6 * qq

        if self.trainable:
            s6 = jnp.abs(
                self.param(
                    "s6",
                    lambda key: jnp.array(self.s6, dtype=c6.dtype),
                )
            )
            s8 = jnp.abs(
                self.param(
                    "s8",
                    lambda key: jnp.array(self.s8, dtype=c6.dtype),
                )
            )
            a1 = jnp.abs(
                self.param(
                    "a1",
                    lambda key: jnp.array(self.a1, dtype=c6.dtype),
                )
            )
            a2 = jnp.abs(
                self.param(
                    "a2",
                    lambda key: jnp.array(self.a2, dtype=c6.dtype),
                )
            )
        else:
            s6 = self.s6
            s8 = self.s8
            a1 = self.a1
            a2 = self.a2

        # Becke-Johnson damping radius.
        r0 = a1 * jnp.sqrt(qq) + a2

        t6 = s6 / (rij**6 + r0**6)
        t8 = s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        # 0.5: each pair is presumably listed in both edge directions.
        energy = (-0.5*energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, species.shape[0]
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}
class VdwOQDO(flax.linen.module.Module):
class VdwOQDO(nn.Module):
    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    Per-atom C6 coefficients and polarizabilities are taken from the free-atom
    tables (``C6_FREE``, ``POLARIZABILITIES``) and, when ``ratiovol_key`` is
    given, rescaled by the volume ratio v (C6 * v**2, alpha * v). Pair
    quantities are built with combination rules; C8 and C10 are derived from
    C6 and the oscillator parameter ``muw``. Each edge contribution is
    multiplied by ``graph["switch"]`` and half of it is accumulated on the
    edge's source atom (presumably because the graph lists each pair in both
    directions -- confirm against the graph builder).

    Outputs added to ``inputs`` (per-atom arrays):
      - ``<key>``: dispersion (+ exchange if ``include_exchange``) energy,
      - ``<key>_dispersion`` / ``<key>_exchange``: only if ``include_exchange``,
    where ``<key>`` is ``energy_key`` or, if None, the module name.

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """
    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between AIM volume and free-atom volume. 
         If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str]  = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs):
        # Multiplier applied to the final energies; the formulas below are
        # evaluated in atomic units (presumably Hartree -> self._energy_unit).
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # Distances divided by au.ANG, i.e. converted to atomic units
        # (assumes graph distances are in Angstrom -- confirm).
        rij = graph["distances"] / au.ANG
        switch = graph["switch"]

        # Free-atom reference values gathered per atom.
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # Small offset, presumably to keep the ratio strictly positive.
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            # Accept either (n_atoms, 1) or (n_atoms,) shaped inputs.
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            # Volume scaling: C6 ~ v**2, alpha ~ v.
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        # alpha_ij: arithmetic mean; C6_ij: 2 a_i a_j C6_i C6_j / (C6_i a_j^2 + C6_j a_i^2)
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        # Re = (128 * alpha_ij / FSC**(4/3))**(1/7), with au.FSC presumably the
        # fine-structure constant (relation from the reference paper).
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        # (separate rational fits are used for the damped and undamped models)
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # Higher-order dispersion coefficients from C6 and muw.
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # Damping functions f_2n = 1 - exp(-z) * sum_{k=0}^{n} z^k / k!
            # with z = muw * r^2 / 2 (sum up to z^3, z^4, z^5 for f6, f8, f10).
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # Accumulate (switched) pair terms on source atoms; the 0.5 avoids
        # double counting when each pair appears as both i->j and j->i.
        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        # muw = mu*w, so alphaij * muw * w == alphaij * mu * w**2 == q**2
        q2 = alphaij * muw * w
        # Prefactor A of the exchange term. Presumably fixed by requiring the
        # total pair potential to have its minimum at Re (see reference); the
        # damped expression reduces to the undamped one when f=1, df=0.
        if self.damped:
            # Damping functions evaluated at R = Re (ze = muw * Re^2 / 2) ...
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re
            # ... and their analytic derivatives d f_2n / dR at R = Re.
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
        else:
            # undamped case
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
            # In the damped branch `ez` = exp(-muw r^2 / 2) was already computed
            # for the dispersion damping; compute it here for the undamped case.
            ez = jnp.exp(-0.5 * muw * rij**2)

        # Pair exchange: A * q^2 * exp(-muw r^2 / 2) / r (repulsive, positive).
        exij = A * q2 * ez / rij
        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }

Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

FID : VDW_OQDO

Reference

A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators, J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

VdwOQDO( graph_key: str = 'graph', include_exchange: bool = True, ratiovol_key: Optional[str] = None, energy_key: Optional[str] = None, damped: bool = True, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

include_exchange: bool = True

Whether to compute the exchange part.

ratiovol_key: Optional[str] = None

The key for the ratio between the atom-in-molecule (AIM) volume and the free-atom volume. If None, the volume ratio is assumed to be 1.0.

energy_key: Optional[str] = None

The key for the output energy. If None, the name of the module is used.

damped: bool = True

Whether to use short-range damping.

FID: ClassVar[str] = 'VDW_OQDO'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class DispersionD3(flax.linen.module.Module):
class DispersionD3(nn.Module):
    """Grimme D3 dispersion energy with Becke-Johnson (rational) damping.

    FID : DISPERSION_D3

    Pairwise C6 coefficients are interpolated from the tabulated reference
    values in ``ref_data_d3.pkl`` with Gaussian weights on the atomic
    coordination numbers, and ``C8 = 3 * C6 * R4R2_i * R4R2_j``. Each edge
    contributes ``-(s6*C6/(r^6 + r0^6) + s8*C8/(r^8 + r0^8))`` with
    ``r0 = a1*sqrt(C8/C6) + a2``, multiplied by ``graph["switch"]``; half of
    it is accumulated on the edge's source atom (assumes each pair appears in
    both directions in the graph -- confirm against the graph builder).

    Returns ``inputs`` plus the per-atom energies under ``output_key`` (the
    module name when ``output_key`` is None). When ``trainable`` is True,
    ``s6``, ``s8``, ``a1`` and ``a2`` become parameters initialized from the
    attributes and used through their absolute value.
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    output_key: Optional[str] = None
    s6: float = 1.0
    s8: float = 1.0
    a1: float = 0.4
    a2: float = 5.0
    trainable: bool = False
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    # Process-wide cache of the D3 reference tables: the pickle file is read
    # once instead of on every call/trace of the module.
    _D3_DATA_CACHE: ClassVar[Dict[str, Any]] = {}

    @staticmethod
    def _load_d3_data() -> Dict[str, Any]:
        """Return the D3 reference tables, loading ``ref_data_d3.pkl`` on first use."""
        cache = DispersionD3._D3_DATA_CACHE
        if "data" not in cache:
            path = pathlib.Path(__file__).parent.resolve() / "ref_data_d3.pkl"
            with open(path, "rb") as f:
                # NOTE: unpickling is only acceptable because this file ships
                # with the package; never point this at user-supplied data.
                cache["data"] = pickle.load(f)
        return cache["data"]

    @nn.compact
    def __call__(self, inputs):
        DATA_D3 = self._load_d3_data()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]

        # Distances in atomic units (bohr); clipped to avoid division by zero.
        rij = jnp.clip(graph["distances"] / au.ANG, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(DATA_D3["COV_D3"])[species]
        r4r2 = jnp.array(DATA_D3["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        # Smooth neighbor count: sigmoid(KCN * (Rcov_ij / r_ij - 1)) summed per atom.
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, species.shape[0])

        ## WEIGHTS
        # Gaussian weights exp(-KW * (CN_ref - CN)^2) over the valid reference
        # entries (negative REF_CN entries mark padding), normalized per atom.
        refcn = jnp.array(DATA_D3["REF_CN"])[species]
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        # Fallback: put all weight on the reference with the largest CN for
        # atoms whose Gaussian weights all underflow.
        ref_cn_table = np.asarray(DATA_D3["REF_CN"])
        exweight = np.zeros_like(ref_cn_table)
        exweight[
            np.arange(ref_cn_table.shape[0]), np.argmax(ref_cn_table, axis=1)
        ] = 1.0
        exweight = jnp.array(exweight)[species]

        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        # REF_C6 has shape (nz, nz, nref, nref); flatten the species pair axes
        # so each edge can gather its (nref, nref) table with a single index.
        REF_C6 = DATA_D3["REF_C6"]
        nz = REF_C6.shape[0]
        nref = REF_C6.shape[-1]
        REF_C6 = jnp.array(REF_C6.reshape((nz * nz, nref, nref)))
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION

        # qq = C8 / C6
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]
        c8 = c6 * qq

        if self.trainable:
            s6 = jnp.abs(
                self.param(
                    "s6",
                    lambda key: jnp.array(self.s6, dtype=c6.dtype),
                )
            )
            s8 = jnp.abs(
                self.param(
                    "s8",
                    lambda key: jnp.array(self.s8, dtype=c6.dtype),
                )
            )
            a1 = jnp.abs(
                self.param(
                    "a1",
                    lambda key: jnp.array(self.a1, dtype=c6.dtype),
                )
            )
            a2 = jnp.abs(
                self.param(
                    "a2",
                    lambda key: jnp.array(self.a2, dtype=c6.dtype),
                )
            )
        else:
            s6 = self.s6
            s8 = self.s8
            a1 = self.a1
            a2 = self.a2

        # Becke-Johnson damping radius.
        r0 = a1 * jnp.sqrt(qq) + a2

        t6 = s6 / (rij**6 + r0**6)
        t8 = s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        # 0.5: each pair is presumably listed in both edge directions.
        energy = (-0.5*energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, species.shape[0]
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}

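Grimme D3 dispersion correction with Becke–Johnson (rational) damping, using coordination-number-dependent C6 coefficients interpolated from reference data.

FID : DISPERSION_D3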

DispersionD3( _graphs_properties: Dict, graph_key: str = 'graph', output_key: Optional[str] = None, s6: float = 1.0, s8: float = 1.0, a1: float = 0.4, a2: float = 5.0, trainable: bool = False, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'
output_key: Optional[str] = None
s6: float = 1.0
s8: float = 1.0
a1: float = 0.4
a2: float = 5.0
trainable: bool = False
FID: ClassVar[str] = 'DISPERSION_D3'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None