fennol.md.barostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7from enum import Enum
  8
  9from .utils import us
 10
 11
 12def get_barostat(
 13    thermostat,
 14    simulation_parameters,
 15    dt,
 16    system_data,
 17    fprec,
 18    rng_key=None,
 19    restart_data={},
 20):
 21    state = {}
 22
 23    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
 24    """@keyword[fennol_md] barostat
 25    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
 26    Default: NONE
 27    """
 28
 29    kT = system_data.get("kT", None)
 30    assert kT is not None, "kT must be specified for NPT/NPH simulations"
 31    target_pressure = simulation_parameters.get("target_pressure")
 32    """@keyword[fennol_md] target_pressure
 33    Target pressure for NPT ensemble simulations.
 34    Required for barostat != NONE
 35    """
 36    if barostat_name != "NONE":
 37        assert (
 38            target_pressure is not None
 39        ), "target_pressure must be specified for NPT/NPH simulations"
 40
 41    nbeads = system_data.get("nbeads", None)
 42    variable_cell = True
 43
 44    anisotropic = simulation_parameters.get("aniso_barostat", False)
 45    """@keyword[fennol_md] aniso_barostat
 46    Use anisotropic barostat allowing independent cell parameter scaling.
 47    Default: False
 48    """
 49
 50    isotropic = not anisotropic
 51
 52    pbc_data = system_data["pbc"]
 53    start_barostat = simulation_parameters.get("start_barostat", 0.0)
 54    """@keyword[fennol_md] start_barostat
 55    Time delay before starting barostat pressure coupling.
 56    Default: 0.0
 57    """
 58    start_time = restart_data.get("simulation_time_ps", 0.0) / us.PS
 59    start_barostat = max(0.0, start_barostat - start_time)
 60    istart_barostat = int(round(start_barostat / dt))
 61    if istart_barostat > 0 and barostat_name not in ["NONE"]:
 62        print(
 63            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
 64        )
 65    else:
 66        istart_barostat = 0
 67
 68    if barostat_name in ["LGV", "LANGEVIN"]:
 69        assert rng_key is not None, "rng_key must be provided for QTB barostat"
 70        gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
 71        """@keyword[fennol_md] gamma_piston
 72        Piston friction coefficient for Langevin barostat.
 73        Default: 20.0 ps^-1
 74        """
 75        tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
 76        """@keyword[fennol_md] tau_piston
 77        Piston time constant for barostat coupling.
 78        Default: 200.0 fs
 79        """
 80        nat = system_data["nat"]
 81        masspiston = 3 * nat * kT * tau_piston**2
 82        print(
 83            f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2"
 84        )
 85        a1 = math.exp(-gamma * dt)
 86        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5
 87
 88        rng_key, v_key = jax.random.split(rng_key)
 89        if anisotropic:
 90            extvol = pbc_data["cell"]
 91            vextvol = (
 92                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
 93                * (kT / masspiston) ** 0.5
 94            )
 95            vextvol = 0.5 * (vextvol + vextvol.T)
 96
 97            aniso_mask = simulation_parameters.get(
 98                "aniso_mask", [True, True, True, True, True, True]
 99            )
100            """@keyword[fennol_md] aniso_mask
101            Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
102            Default: [True, True, True, True, True, True]
103            """
104            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
105            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
106            ndof_piston = np.sum(aniso_mask)
107            # xx   yy   zz    xy   xz   yz
108            aniso_mask = np.array(
109                [
110                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
111                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
112                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
113                ],
114                dtype=np.int32,
115            )
116        else:
117            extvol = jnp.asarray(pbc_data["volume"])
118            vextvol = (
119                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
120                * (kT / masspiston) ** 0.5
121            )
122            ndof_piston = 1.0
123
124        state["extvol"] = extvol
125        state["vextvol"] = vextvol
126        state["rng_key"] = rng_key
127        state["istep"] = 0
128
129        def barostat(x, vel, system):
130            if nbeads is not None:
131                x, eigx = x[0], x[1:]
132                vel, eigv = vel[0], vel[1:]
133            barostat_state = system["barostat"]
134            extvol = barostat_state["extvol"]
135            vextvol = barostat_state["vextvol"]
136            cell = system["cell"]
137            volume = jnp.abs(jnp.linalg.det(cell))
138
139            istep = barostat_state["istep"] + 1
140            dt_bar = dt * (istep >= istart_barostat)
141
142            # apply B
143            # pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"]
144            ek = system["ek_c"] if nbeads is not None else system["ek"]
145            Pres = (
146                system["pressure_tensor"]
147                + ek * jnp.array(np.eye(3) * (2 / (3 * x.shape[0]))) / volume
148            )
149            if isotropic:
150                dPres = jnp.trace(Pres) - 3 * target_pressure
151            else:
152                dPres = 0.5 * (Pres + Pres.T) - jnp.array(target_pressure * np.eye(3))
153
154            vextvol = vextvol + ((dt_bar / masspiston) * volume) * dPres
155
156            # apply A
157            if isotropic:
158                vdt2 = 0.5 * dt_bar * vextvol
159                scalev = jnp.exp(-vdt2 * (1 + 1.0 / x.shape[0]))
160                vel = vel * scalev
161                scale1 = jnp.exp(vdt2)
162            else:
163                vextvol = aniso_mask * vextvol
164                vdt2 = 0.5 * dt_bar * vextvol
165                l, O = jnp.linalg.eigh(vdt2)
166                lcorr = jnp.trace(vdt2) / (3 * x.shape[0])
167                Dv = jnp.diag(jnp.exp(-(l + lcorr)))
168                Dx = jnp.diag(jnp.exp(l))
169                scalev = O @ Dv @ O.T
170                scale1 = O @ Dx @ O.T
171                vel = vel @ scalev
172
173            # apply O
174            if nbeads is not None:
175                eigv, thermostat_state = thermostat(
176                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
177                )
178                vel, eigv = eigv[0], eigv[1:]
179            else:
180                vel, thermostat_state = thermostat(vel, system["thermostat"])
181            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
182
183            if isotropic:
184                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
185            else:
186                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
187                noise = 0.5 * (noise + noise.T)
188
189            vextvol = a1 * vextvol + a2 * noise
190
191            # apply A
192            if isotropic:
193                vdt2 = 0.5 * dt_bar * vextvol
194                scalev = jnp.exp(-vdt2 * (1 + 1.0 / x.shape[0]))
195                vel = vel * scalev
196                scale2 = jnp.exp(vdt2)
197                x = x * (scale1 * scale2)
198                extvol = extvol * (scale1 * scale2) ** 3
199                cell = cell * (scale1 * scale2)
200            else:
201                vextvol = aniso_mask * vextvol
202                vdt2 = 0.5 * dt_bar * vextvol
203                l, O = jnp.linalg.eigh(vdt2)
204                lcorr = jnp.trace(vdt2) / (3 * x.shape[0])
205                Dv = jnp.diag(jnp.exp(-(l + lcorr)))
206                Dx = jnp.diag(jnp.exp(l))
207                scalev = O @ Dv @ O.T
208                scale = scale1 @ (O @ Dx @ O.T)
209
210                extvol = extvol @ scale
211
212                # ensure cell is lower triangular
213                extvol, rotation_matrix = tril_cell_(extvol)
214                cell = extvol
215                vextvol = rotation_matrix.T @ vextvol @ rotation_matrix
216
217                # scale and rotate positions and velocities
218                vel = vel @ (scalev @ rotation_matrix)
219                x = x @ (scale @ rotation_matrix)
220
221            if nbeads is not None:
222                if not isotropic:
223                    eigx = eigx @ rotation_matrix
224                    eigv = eigv @ rotation_matrix
225                x = jnp.concatenate((x[None], eigx), axis=0)
226                vel = jnp.concatenate((vel[None], eigv), axis=0)
227
228            piston_temperature = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
229                vextvol**2
230            )
231            barostat_state = {
232                **barostat_state,
233                "istep": istep,
234                "rng_key": rng_key,
235                "vextvol": vextvol,
236                "extvol": extvol,
237                "piston_temperature": piston_temperature,
238            }
239            return (
240                x,
241                vel,
242                {
243                    **system,
244                    "barostat": barostat_state,
245                    "cell": cell,
246                    "thermostat": thermostat_state,
247                },
248            )
249
250    elif barostat_name in ["NONE"]:
251        variable_cell = False
252
253        def barostat(x, vel, system):
254            vel, thermostat_state = thermostat(vel, system["thermostat"])
255            return x, vel, {**system, "thermostat": thermostat_state}
256
257    else:
258        raise ValueError(f"Unknown barostat {barostat_name}")
259
260    return barostat, variable_cell, state
261
262
def tril_cell_(cell):
    """Re-orient a cell into lower-triangular form and return the rotation used.

    Lattice vectors are the ROWS of ``cell``. The returned ``cell_tril`` has
    the same lengths and angles as ``cell`` (built by
    ``cell_lengths_angles_to_tril``). ``rotation`` satisfies
    ``cell @ rotation == cell_tril``, so row-vector quantities (positions,
    velocities) are re-oriented as ``v @ rotation`` and rank-2 tensors as
    ``rotation.T @ t @ rotation``. ``rotation`` is orthogonal when ``cell``
    is right-handed (det > 0).

    Args:
        cell: array-like reshaped to (3, 3); rows are lattice vectors.

    Returns:
        ``(cell_tril, rotation)``: the lower-triangular cell and the 3x3
        matrix mapping ``cell`` onto it.
    """
    cell = jnp.asarray(cell, dtype=float).reshape(3, 3)
    a = jnp.linalg.norm(cell[0])
    b = jnp.linalg.norm(cell[1])
    c = jnp.linalg.norm(cell[2])
    cos_alpha = jnp.dot(cell[1], cell[2]) / (b * c)
    cos_beta = jnp.dot(cell[0], cell[2]) / (a * c)
    cos_gamma = jnp.dot(cell[0], cell[1]) / (a * b)
    cell_tril = cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma)
    # Row-vector convention: cell_tril = cell @ R  =>  R = cell^{-1} @ cell_tril.
    # The order matters: cell_tril @ cell^{-1} is only a similarity transform
    # of R and is not orthogonal for non-cubic cells. solve() avoids forming
    # the explicit inverse.
    rotation = jnp.linalg.solve(cell, cell_tril)
    return cell_tril, rotation
274
275
def cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma):
    """Build a lower-triangular 3x3 cell from lattice lengths and angle cosines.

    The rows of the result are the lattice vectors. The first lies along x,
    the second lies in the xy-plane, and the third has a non-negative z
    component. ``gamma`` is the angle between the first two vectors, ``beta``
    between the first and third, and ``alpha`` between the second and third.

    Args:
        a, b, c: lattice vector lengths.
        cos_alpha, cos_beta, cos_gamma: cosines of the cell angles.

    Returns:
        (3, 3) array whose upper off-diagonal entries are zero.
    """
    sin_gamma = jnp.sqrt(1.0 - cos_gamma * cos_gamma)

    # Unit-vector components of the third lattice vector in the new frame:
    # its x part fixes the angle with the first vector, its y part the angle
    # with the second, and z closes the unit norm.
    dir_x = cos_beta
    dir_y = (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    dir_z = jnp.sqrt(1.0 - dir_x * dir_x - dir_y * dir_y)

    lattice_rows = (
        a * jnp.array([1, 0, 0]),
        b * jnp.array([cos_gamma, sin_gamma, 0]),
        c * jnp.array([dir_x, dir_y, dir_z]),
    )
    return jnp.vstack(lattice_rows)
def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the barostat step function used by the MD integrator.

    The returned ``barostat(x, vel, system)`` callable always applies the
    thermostat to the particle velocities. For the LANGEVIN/LGV barostat it
    also propagates a Langevin piston that rescales positions, velocities and
    the simulation cell (isotropically through the volume, or anisotropically
    through the full 3x3 cell matrix).

    Args:
        thermostat: callable ``thermostat(vel, thermostat_state)`` returning
            ``(vel, thermostat_state)``.
        simulation_parameters: dict-like of user keywords (``barostat``,
            ``target_pressure``, ``aniso_barostat``, ``start_barostat``,
            ``gamma_piston``, ``tau_piston``, ``aniso_mask``).
        dt: integration time step, in the same time units as
            ``start_barostat`` (their ratio gives the start step).
        system_data: dict that must contain ``pbc``. ``kT`` and ``nat`` are
            needed when a barostat is active, and ``nbeads`` is read to detect
            path-integral runs.
        fprec: not used by this function.
        rng_key: JAX PRNG key. Required for the LANGEVIN barostat.
        restart_data: optional restart dict. Its ``simulation_time_ps`` entry
            shortens the remaining ``start_barostat`` delay. ``None`` means
            no restart data.

    Returns:
        ``(barostat, variable_cell, state)``: the step function, whether the
        cell may change during the run, and the initial barostat state dict
        (empty for NONE).

    Raises:
        AssertionError: ``kT``, ``target_pressure`` or ``rng_key`` is missing
            while a barostat is requested, or ``aniso_mask`` does not have
            6 entries.
        ValueError: unknown barostat name.
    """
    # Avoid a shared mutable default; restart_data is only read here.
    if restart_data is None:
        restart_data = {}
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
    """@keyword[fennol_md] barostat
    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
    Default: NONE
    """

    kT = system_data.get("kT", None)
    target_pressure = simulation_parameters.get("target_pressure")
    """@keyword[fennol_md] target_pressure
    Target pressure for NPT ensemble simulations.
    Required for barostat != NONE
    """
    if barostat_name != "NONE":
        # kT and the target pressure are only consumed by an active barostat.
        assert kT is not None, "kT must be specified for NPT/NPH simulations"
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    """@keyword[fennol_md] aniso_barostat
    Use anisotropic barostat allowing independent cell parameter scaling.
    Default: False
    """

    isotropic = not anisotropic

    pbc_data = system_data["pbc"]
    start_barostat = simulation_parameters.get("start_barostat", 0.0)
    """@keyword[fennol_md] start_barostat
    Time delay before starting barostat pressure coupling.
    Default: 0.0
    """
    # On restart, only the part of the delay that has not yet elapsed remains.
    start_time = restart_data.get("simulation_time_ps", 0.0) / us.PS
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0 and barostat_name not in ["NONE"]:
        print(
            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
        )
    else:
        istart_barostat = 0

    if barostat_name in ["LGV", "LANGEVIN"]:
        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
        """@keyword[fennol_md] gamma_piston
        Piston friction coefficient for Langevin barostat.
        Default: 20.0 ps^-1
        """
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
        """@keyword[fennol_md] tau_piston
        Piston time constant for barostat coupling.
        Default: 200.0 fs
        """
        nat = system_data["nat"]
        masspiston = 3 * nat * kT * tau_piston**2
        print(
            f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2"
        )
        # Ornstein-Uhlenbeck coefficients for the piston velocity over one dt:
        # v <- a1 * v + a2 * noise.
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        # Initial piston velocity drawn from its thermal distribution.
        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            extvol = pbc_data["cell"]
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            # The cell-velocity matrix is kept symmetric throughout.
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            """@keyword[fennol_md] aniso_mask
            Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
            Default: [True, True, True, True, True, True]
            """
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # Expand the 6 flags into a symmetric 3x3 mask.
            # xx   yy   zz    xy   xz   yz
            aniso_mask = np.array(
                [
                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
                ],
                dtype=np.int32,
            )
        else:
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key
        state["istep"] = 0

        def barostat(x, vel, system):
            """Advance particles, cell and piston by one step.

            Returns ``(x, vel, system)`` with updated ``barostat``, ``cell``
            and ``thermostat`` entries in ``system``.
            """
            # Path-integral runs: entry 0 is propagated with the cell
            # (presumably the centroid mode; confirm). The remaining entries
            # are only thermostatted (and rotated in the anisotropic case).
            if nbeads is not None:
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))

            # Piston motion is frozen (dt_bar = 0) until istart_barostat steps.
            istep = barostat_state["istep"] + 1
            dt_bar = dt * (istep >= istart_barostat)

            # apply B: kick the piston velocity with the pressure mismatch.
            # pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"]
            ek = system["ek_c"] if nbeads is not None else system["ek"]
            # Extra 2*ek/(3N) / V term; presumably the 1/N correction paired
            # with the (1 + 1/N) velocity scaling below (confirm).
            Pres = (
                system["pressure_tensor"]
                + ek * jnp.array(np.eye(3) * (2 / (3 * x.shape[0]))) / volume
            )
            if isotropic:
                dPres = jnp.trace(Pres) - 3 * target_pressure
            else:
                dPres = 0.5 * (Pres + Pres.T) - jnp.array(target_pressure * np.eye(3))

            vextvol = vextvol + ((dt_bar / masspiston) * volume) * dPres

            # apply A (first half): scale particle velocities and keep the
            # position/cell scaling factor for this half-step.
            if isotropic:
                vdt2 = 0.5 * dt_bar * vextvol
                scalev = jnp.exp(-vdt2 * (1 + 1.0 / x.shape[0]))
                vel = vel * scalev
                scale1 = jnp.exp(vdt2)
            else:
                vextvol = aniso_mask * vextvol
                vdt2 = 0.5 * dt_bar * vextvol
                # Matrix exponentials of the symmetric matrix vdt2 via eigh.
                l, O = jnp.linalg.eigh(vdt2)
                lcorr = jnp.trace(vdt2) / (3 * x.shape[0])
                Dv = jnp.diag(jnp.exp(-(l + lcorr)))
                Dx = jnp.diag(jnp.exp(l))
                scalev = O @ Dv @ O.T
                scale1 = O @ Dx @ O.T
                vel = vel @ scalev

            # apply O: thermostat on the particles, then OU update of the piston.
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                noise = 0.5 * (noise + noise.T)

            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half): finish the velocity scaling, then apply
            # the combined position/cell scaling of both half-steps.
            if isotropic:
                vdt2 = 0.5 * dt_bar * vextvol
                scalev = jnp.exp(-vdt2 * (1 + 1.0 / x.shape[0]))
                vel = vel * scalev
                scale2 = jnp.exp(vdt2)
                x = x * (scale1 * scale2)
                extvol = extvol * (scale1 * scale2) ** 3
                cell = cell * (scale1 * scale2)
            else:
                vextvol = aniso_mask * vextvol
                vdt2 = 0.5 * dt_bar * vextvol
                l, O = jnp.linalg.eigh(vdt2)
                lcorr = jnp.trace(vdt2) / (3 * x.shape[0])
                Dv = jnp.diag(jnp.exp(-(l + lcorr)))
                Dx = jnp.diag(jnp.exp(l))
                scalev = O @ Dv @ O.T
                scale = scale1 @ (O @ Dx @ O.T)

                extvol = extvol @ scale

                # ensure cell is lower triangular
                extvol, rotation_matrix = tril_cell_(extvol)
                cell = extvol
                vextvol = rotation_matrix.T @ vextvol @ rotation_matrix

                # scale and rotate positions and velocities
                vel = vel @ (scalev @ rotation_matrix)
                x = x @ (scale @ rotation_matrix)

            if nbeads is not None:
                if not isotropic:
                    eigx = eigx @ rotation_matrix
                    eigv = eigv @ rotation_matrix
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            # Instantaneous piston temperature (masspiston * <v^2> per dof,
            # scaled by us.KELVIN).
            piston_temperature = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "istep": istep,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    **system,
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name in ["NONE"]:
        variable_cell = False

        def barostat(x, vel, system):
            """Fixed-cell step: only the thermostat acts on the velocities."""
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state
def tril_cell_(cell):
    """Re-orient a cell into lower-triangular form and return the rotation used.

    Lattice vectors are the ROWS of ``cell``. The returned ``cell_tril`` has
    the same lengths and angles as ``cell`` (built by
    ``cell_lengths_angles_to_tril``). ``rotation`` satisfies
    ``cell @ rotation == cell_tril``, so row-vector quantities (positions,
    velocities) are re-oriented as ``v @ rotation`` and rank-2 tensors as
    ``rotation.T @ t @ rotation``. ``rotation`` is orthogonal when ``cell``
    is right-handed (det > 0).

    Args:
        cell: array-like reshaped to (3, 3); rows are lattice vectors.

    Returns:
        ``(cell_tril, rotation)``: the lower-triangular cell and the 3x3
        matrix mapping ``cell`` onto it.
    """
    cell = jnp.asarray(cell, dtype=float).reshape(3, 3)
    a = jnp.linalg.norm(cell[0])
    b = jnp.linalg.norm(cell[1])
    c = jnp.linalg.norm(cell[2])
    cos_alpha = jnp.dot(cell[1], cell[2]) / (b * c)
    cos_beta = jnp.dot(cell[0], cell[2]) / (a * c)
    cos_gamma = jnp.dot(cell[0], cell[1]) / (a * b)
    cell_tril = cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma)
    # Row-vector convention: cell_tril = cell @ R  =>  R = cell^{-1} @ cell_tril.
    # The order matters: cell_tril @ cell^{-1} is only a similarity transform
    # of R and is not orthogonal for non-cubic cells. solve() avoids forming
    # the explicit inverse.
    rotation = jnp.linalg.solve(cell, cell_tril)
    return cell_tril, rotation
def cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma):
    """Build a lower-triangular 3x3 cell from lattice lengths and angle cosines.

    The rows of the result are the lattice vectors. The first lies along x,
    the second lies in the xy-plane, and the third has a non-negative z
    component. ``gamma`` is the angle between the first two vectors, ``beta``
    between the first and third, and ``alpha`` between the second and third.

    Args:
        a, b, c: lattice vector lengths.
        cos_alpha, cos_beta, cos_gamma: cosines of the cell angles.

    Returns:
        (3, 3) array whose upper off-diagonal entries are zero.
    """
    sin_gamma = jnp.sqrt(1.0 - cos_gamma * cos_gamma)

    # Unit-vector components of the third lattice vector in the new frame:
    # its x part fixes the angle with the first vector, its y part the angle
    # with the second, and z closes the unit norm.
    dir_x = cos_beta
    dir_y = (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    dir_z = jnp.sqrt(1.0 - dir_x * dir_x - dir_y * dir_y)

    lattice_rows = (
        a * jnp.array([1, 0, 0]),
        b * jnp.array([cos_gamma, sin_gamma, 0]),
        c * jnp.array([dir_x, dir_y, dir_z]),
    )
    return jnp.vstack(lattice_rows)