fennol.md.barostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
from enum import Enum

from .utils import us


def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data={},
):
    """Build the barostat step function for an MD simulation.

    Parameters
    ----------
    thermostat : callable
        Called as ``thermostat(vel, thermostat_state)`` and expected to return
        ``(vel, thermostat_state)``. It is invoked once per step from inside
        the barostat step.
    simulation_parameters : dict-like
        Read with ``.get`` for the ``@keyword[fennol_md]`` keywords below.
    dt : float
        Integration time step (internal units).
    system_data : dict
        Must contain ``"kT"`` and ``"pbc"``. ``"nat"`` is required by the
        Langevin barostat. ``"nbeads"`` is optional; when present, positions
        and velocities are stacked with the centroid at index 0.
    fprec :
        Not used in this function.
    rng_key : jax PRNG key, optional
        Required for the Langevin barostat.
    restart_data : dict, optional
        Only read, never mutated. ``"simulation_time_ps"`` shortens the
        ``start_barostat`` delay on restart.

    Returns
    -------
    (barostat, variable_cell, state)
        ``barostat(x, vel, system) -> (x, vel, system)``. ``variable_cell``
        is False only for ``NONE``. ``state`` is the initial barostat state
        dict (empty for ``NONE``).

    Raises
    ------
    ValueError
        If the barostat name is not recognized.
    """
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
    """@keyword[fennol_md] barostat
    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
    Default: NONE
    """

    kT = system_data.get("kT", None)
    assert kT is not None, "kT must be specified for NPT/NPH simulations"
    target_pressure = simulation_parameters.get("target_pressure")
    """@keyword[fennol_md] target_pressure
    Target pressure for NPT ensemble simulations.
    Required for barostat != NONE
    """
    if barostat_name != "NONE":
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    """@keyword[fennol_md] aniso_barostat
    Use anisotropic barostat allowing independent cell parameter scaling.
    Default: False
    """

    isotropic = not anisotropic

    pbc_data = system_data["pbc"]
    start_barostat = simulation_parameters.get("start_barostat", 0.0)
    """@keyword[fennol_md] start_barostat
    Time delay before starting barostat pressure coupling.
    Default: 0.0
    """
    # On restart, the time already simulated is subtracted from the delay.
    start_time = restart_data.get("simulation_time_ps", 0.0) / us.PS
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0 and barostat_name not in ["NONE"]:
        print(
            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
        )
    else:
        istart_barostat = 0

    if barostat_name in ["LGV", "LANGEVIN"]:
        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
        """@keyword[fennol_md] gamma_piston
        Piston friction coefficient for Langevin barostat.
        Default: 20.0 ps^-1
        """
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
        """@keyword[fennol_md] tau_piston
        Piston time constant for barostat coupling.
        Default: 200.0 fs
        """
        nat = system_data["nat"]
        masspiston = 3 * nat * kT * tau_piston**2
        print(
            f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2"
        )
        # Exact Ornstein-Uhlenbeck coefficients for the piston "O" step.
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            extvol = pbc_data["cell"]
            # Symmetric initial piston velocity matrix.
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            """@keyword[fennol_md] aniso_mask
            Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
            Default: [True, True, True, True, True, True]
            """
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # Expand the 6 flags (xx yy zz xy xz yz) into a symmetric 3x3 mask.
            aniso_mask = np.array(
                [
                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
                ],
                dtype=np.int32,
            )
        else:
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key
        state["istep"] = 0

        def _piston_propagators(vextvol, dt_bar, natoms):
            """Return ``(vextvol, scale_v, scale_x)`` for one half "A" step.

            ``scale_v`` multiplies the velocities and ``scale_x`` the
            positions/cell. In the anisotropic case the piston velocity is
            first projected on the allowed degrees of freedom, and the matrix
            exponentials are evaluated via ``eigh`` of the symmetric matrix
            ``0.5 * dt_bar * vextvol``.
            """
            if isotropic:
                vdt2 = 0.5 * dt_bar * vextvol
                scale_v = jnp.exp(-vdt2 * (1 + 1.0 / natoms))
                scale_x = jnp.exp(vdt2)
                return vextvol, scale_v, scale_x
            vextvol = aniso_mask * vextvol
            vdt2 = 0.5 * dt_bar * vextvol
            l, O = jnp.linalg.eigh(vdt2)
            lcorr = jnp.trace(vdt2) / (3 * natoms)
            scale_v = O @ jnp.diag(jnp.exp(-(l + lcorr))) @ O.T
            scale_x = O @ jnp.diag(jnp.exp(l)) @ O.T
            return vextvol, scale_v, scale_x

        def barostat(x, vel, system):
            """One B-A-O-A barostat step; returns updated ``(x, vel, system)``."""
            if nbeads is not None:
                # Index 0 is handled as the centroid (paired with system["ek_c"]).
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            natoms = x.shape[0]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))

            istep = barostat_state["istep"] + 1
            # Zero effective step for the cell until the barostat has started.
            dt_bar = dt * (istep >= istart_barostat)

            # apply B: kick the piston with the pressure mismatch.
            # The isotropic kinetic term 2*ek/(3*natoms*V) added to the pressure
            # tensor matches the 1/natoms correction in the velocity scaling.
            ek = system["ek_c"] if nbeads is not None else system["ek"]
            Pres = (
                system["pressure_tensor"]
                + ek * jnp.array(np.eye(3) * (2 / (3 * natoms))) / volume
            )
            if isotropic:
                dPres = jnp.trace(Pres) - 3 * target_pressure
            else:
                dPres = 0.5 * (Pres + Pres.T) - jnp.array(target_pressure * np.eye(3))

            vextvol = vextvol + ((dt_bar / masspiston) * volume) * dPres

            # apply A (first half): only velocities are scaled here.
            vextvol, scalev, scale1 = _piston_propagators(vextvol, dt_bar, natoms)
            vel = vel * scalev if isotropic else vel @ scalev

            # apply O: particle thermostat, then Langevin friction/noise on the piston.
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                noise = 0.5 * (noise + noise.T)

            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half): scale velocities, then positions and cell
            # with the combined scaling of both half steps.
            vextvol, scalev, scale2 = _piston_propagators(vextvol, dt_bar, natoms)
            if isotropic:
                vel = vel * scalev
                scale = scale1 * scale2
                x = x * scale
                extvol = extvol * scale**3
                cell = cell * scale
            else:
                scale = scale1 @ scale2
                extvol = extvol @ scale

                # Keep the cell lower triangular and rotate every vector
                # quantity (and the piston velocity) into the same frame.
                extvol, rotation_matrix = tril_cell_(extvol)
                cell = extvol
                vextvol = rotation_matrix.T @ vextvol @ rotation_matrix

                vel = vel @ (scalev @ rotation_matrix)
                x = x @ (scale @ rotation_matrix)

            if nbeads is not None:
                if not isotropic:
                    eigx = eigx @ rotation_matrix
                    eigv = eigv @ rotation_matrix
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            piston_temperature = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "istep": istep,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    **system,
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name in ["NONE"]:
        variable_cell = False

        def barostat(x, vel, system):
            """Fixed-cell step: only apply the thermostat."""
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state


def tril_cell_(cell):
    """Rotate a cell (lattice vectors as rows) into lower-triangular form.

    Returns ``(cell_tril, rotation)``. ``cell_tril`` has the same lengths and
    angles as ``cell``, with ``a`` along x and ``b`` in the xy plane.
    ``rotation`` acts on row vectors so that ``cell @ rotation == cell_tril``;
    row-stored positions and velocities are brought into the new frame with
    ``coords @ rotation``. It is a proper rotation when ``cell`` is
    right-handed.
    """
    cell = jnp.asarray(cell, dtype=float).reshape(3, 3)
    a = jnp.linalg.norm(cell[0])
    b = jnp.linalg.norm(cell[1])
    c = jnp.linalg.norm(cell[2])
    cos_alpha = jnp.dot(cell[1], cell[2]) / (b * c)
    cos_beta = jnp.dot(cell[0], cell[2]) / (a * c)
    cos_gamma = jnp.dot(cell[0], cell[1]) / (a * b)
    cell_tril = cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma)
    # Row-vector convention: cell @ rotation = cell_tril, hence
    # rotation = inv(cell) @ cell_tril. (cell_tril @ inv(cell) is not a
    # rotation unless the matrices commute.) solve() avoids forming inv(cell).
    rotation = jnp.linalg.solve(cell, cell_tril)
    return cell_tril, rotation


def cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma):
    """Build a lower-triangular cell (rows = lattice vectors) from lengths/angles.

    ``a`` lies along x, ``b`` in the xy plane and ``c`` has a non-negative z
    component. ``cos_alpha`` is the cosine between b and c, ``cos_beta``
    between a and c, and ``cos_gamma`` between a and b.

    Square-root arguments are clamped at zero so that round-off on nearly flat
    cells gives a zero component instead of NaN. A truly degenerate cell
    (``sin_gamma == 0``) still yields non-finite values.
    """
    sin_gamma = jnp.sqrt(jnp.maximum(1.0 - cos_gamma * cos_gamma, 0.0))

    # Build the cell vectors
    va = a * jnp.array([1, 0, 0])
    vb = b * jnp.array([cos_gamma, sin_gamma, 0])
    cx = cos_beta
    cy = (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    cz_sqr = 1.0 - cx * cx - cy * cy
    cz = jnp.sqrt(jnp.maximum(cz_sqr, 0.0))
    vc = c * jnp.array([cx, cy, cz])

    return jnp.vstack((va, vb, vc))
def
get_barostat( thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data={},
):
    """Build the barostat step function for an MD simulation.

    Parameters
    ----------
    thermostat : callable
        Called as ``thermostat(vel, thermostat_state)`` and expected to return
        ``(vel, thermostat_state)``. It is invoked once per step from inside
        the barostat step.
    simulation_parameters : dict-like
        Read with ``.get`` for the ``@keyword[fennol_md]`` keywords below.
    dt : float
        Integration time step (internal units).
    system_data : dict
        Must contain ``"kT"`` and ``"pbc"``. ``"nat"`` is required by the
        Langevin barostat. ``"nbeads"`` is optional; when present, positions
        and velocities are stacked with the centroid at index 0.
    fprec :
        Not used in this function.
    rng_key : jax PRNG key, optional
        Required for the Langevin barostat.
    restart_data : dict, optional
        Only read, never mutated. ``"simulation_time_ps"`` shortens the
        ``start_barostat`` delay on restart.

    Returns
    -------
    (barostat, variable_cell, state)
        ``barostat(x, vel, system) -> (x, vel, system)``. ``variable_cell``
        is False only for ``NONE``. ``state`` is the initial barostat state
        dict (empty for ``NONE``).

    Raises
    ------
    ValueError
        If the barostat name is not recognized.
    """
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
    """@keyword[fennol_md] barostat
    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
    Default: NONE
    """

    kT = system_data.get("kT", None)
    assert kT is not None, "kT must be specified for NPT/NPH simulations"
    target_pressure = simulation_parameters.get("target_pressure")
    """@keyword[fennol_md] target_pressure
    Target pressure for NPT ensemble simulations.
    Required for barostat != NONE
    """
    if barostat_name != "NONE":
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    """@keyword[fennol_md] aniso_barostat
    Use anisotropic barostat allowing independent cell parameter scaling.
    Default: False
    """

    isotropic = not anisotropic

    pbc_data = system_data["pbc"]
    start_barostat = simulation_parameters.get("start_barostat", 0.0)
    """@keyword[fennol_md] start_barostat
    Time delay before starting barostat pressure coupling.
    Default: 0.0
    """
    # On restart, the time already simulated is subtracted from the delay.
    start_time = restart_data.get("simulation_time_ps", 0.0) / us.PS
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0 and barostat_name not in ["NONE"]:
        print(
            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
        )
    else:
        istart_barostat = 0

    if barostat_name in ["LGV", "LANGEVIN"]:
        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
        """@keyword[fennol_md] gamma_piston
        Piston friction coefficient for Langevin barostat.
        Default: 20.0 ps^-1
        """
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
        """@keyword[fennol_md] tau_piston
        Piston time constant for barostat coupling.
        Default: 200.0 fs
        """
        nat = system_data["nat"]
        masspiston = 3 * nat * kT * tau_piston**2
        print(
            f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2"
        )
        # Exact Ornstein-Uhlenbeck coefficients for the piston "O" step.
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            extvol = pbc_data["cell"]
            # Symmetric initial piston velocity matrix.
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            """@keyword[fennol_md] aniso_mask
            Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
            Default: [True, True, True, True, True, True]
            """
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # Expand the 6 flags (xx yy zz xy xz yz) into a symmetric 3x3 mask.
            aniso_mask = np.array(
                [
                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
                ],
                dtype=np.int32,
            )
        else:
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key
        state["istep"] = 0

        def _piston_propagators(vextvol, dt_bar, natoms):
            """Return ``(vextvol, scale_v, scale_x)`` for one half "A" step.

            ``scale_v`` multiplies the velocities and ``scale_x`` the
            positions/cell. In the anisotropic case the piston velocity is
            first projected on the allowed degrees of freedom, and the matrix
            exponentials are evaluated via ``eigh`` of the symmetric matrix
            ``0.5 * dt_bar * vextvol``.
            """
            if isotropic:
                vdt2 = 0.5 * dt_bar * vextvol
                scale_v = jnp.exp(-vdt2 * (1 + 1.0 / natoms))
                scale_x = jnp.exp(vdt2)
                return vextvol, scale_v, scale_x
            vextvol = aniso_mask * vextvol
            vdt2 = 0.5 * dt_bar * vextvol
            l, O = jnp.linalg.eigh(vdt2)
            lcorr = jnp.trace(vdt2) / (3 * natoms)
            scale_v = O @ jnp.diag(jnp.exp(-(l + lcorr))) @ O.T
            scale_x = O @ jnp.diag(jnp.exp(l)) @ O.T
            return vextvol, scale_v, scale_x

        def barostat(x, vel, system):
            """One B-A-O-A barostat step; returns updated ``(x, vel, system)``."""
            if nbeads is not None:
                # Index 0 is handled as the centroid (paired with system["ek_c"]).
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            natoms = x.shape[0]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))

            istep = barostat_state["istep"] + 1
            # Zero effective step for the cell until the barostat has started.
            dt_bar = dt * (istep >= istart_barostat)

            # apply B: kick the piston with the pressure mismatch.
            # The isotropic kinetic term 2*ek/(3*natoms*V) added to the pressure
            # tensor matches the 1/natoms correction in the velocity scaling.
            ek = system["ek_c"] if nbeads is not None else system["ek"]
            Pres = (
                system["pressure_tensor"]
                + ek * jnp.array(np.eye(3) * (2 / (3 * natoms))) / volume
            )
            if isotropic:
                dPres = jnp.trace(Pres) - 3 * target_pressure
            else:
                dPres = 0.5 * (Pres + Pres.T) - jnp.array(target_pressure * np.eye(3))

            vextvol = vextvol + ((dt_bar / masspiston) * volume) * dPres

            # apply A (first half): only velocities are scaled here.
            vextvol, scalev, scale1 = _piston_propagators(vextvol, dt_bar, natoms)
            vel = vel * scalev if isotropic else vel @ scalev

            # apply O: particle thermostat, then Langevin friction/noise on the piston.
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                noise = 0.5 * (noise + noise.T)

            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half): scale velocities, then positions and cell
            # with the combined scaling of both half steps.
            vextvol, scalev, scale2 = _piston_propagators(vextvol, dt_bar, natoms)
            if isotropic:
                vel = vel * scalev
                scale = scale1 * scale2
                x = x * scale
                extvol = extvol * scale**3
                cell = cell * scale
            else:
                scale = scale1 @ scale2
                extvol = extvol @ scale

                # Keep the cell lower triangular and rotate every vector
                # quantity (and the piston velocity) into the same frame.
                extvol, rotation_matrix = tril_cell_(extvol)
                cell = extvol
                vextvol = rotation_matrix.T @ vextvol @ rotation_matrix

                vel = vel @ (scalev @ rotation_matrix)
                x = x @ (scale @ rotation_matrix)

            if nbeads is not None:
                if not isotropic:
                    eigx = eigx @ rotation_matrix
                    eigv = eigv @ rotation_matrix
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            piston_temperature = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "istep": istep,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    **system,
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name in ["NONE"]:
        variable_cell = False

        def barostat(x, vel, system):
            """Fixed-cell step: only apply the thermostat."""
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state
def
tril_cell_(cell):
def tril_cell_(cell):
    """Rotate a cell (lattice vectors as rows) into lower-triangular form.

    Returns ``(cell_tril, rotation)``. ``cell_tril`` has the same lengths and
    angles as ``cell``, with ``a`` along x and ``b`` in the xy plane.
    ``rotation`` acts on row vectors so that ``cell @ rotation == cell_tril``;
    row-stored positions and velocities are brought into the new frame with
    ``coords @ rotation``. It is a proper rotation when ``cell`` is
    right-handed.
    """
    cell = jnp.asarray(cell, dtype=float).reshape(3, 3)
    a = jnp.linalg.norm(cell[0])
    b = jnp.linalg.norm(cell[1])
    c = jnp.linalg.norm(cell[2])
    cos_alpha = jnp.dot(cell[1], cell[2]) / (b * c)
    cos_beta = jnp.dot(cell[0], cell[2]) / (a * c)
    cos_gamma = jnp.dot(cell[0], cell[1]) / (a * b)
    cell_tril = cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma)
    # Row-vector convention: cell @ rotation = cell_tril, hence
    # rotation = inv(cell) @ cell_tril. (cell_tril @ inv(cell) is not a
    # rotation unless the matrices commute.) solve() avoids forming inv(cell).
    rotation = jnp.linalg.solve(cell, cell_tril)
    return cell_tril, rotation
def
cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma):
def cell_lengths_angles_to_tril(a, b, c, cos_alpha, cos_beta, cos_gamma):
    """Build a lower-triangular cell (rows = lattice vectors) from lengths/angles.

    ``a`` lies along x, ``b`` in the xy plane and ``c`` has a non-negative z
    component. ``cos_alpha`` is the cosine between b and c, ``cos_beta``
    between a and c, and ``cos_gamma`` between a and b.

    Square-root arguments are clamped at zero so that round-off on nearly flat
    cells gives a zero component instead of NaN. A truly degenerate cell
    (``sin_gamma == 0``) still yields non-finite values.
    """
    sin_gamma = jnp.sqrt(jnp.maximum(1.0 - cos_gamma * cos_gamma, 0.0))

    # Build the cell vectors
    va = a * jnp.array([1, 0, 0])
    vb = b * jnp.array([cos_gamma, sin_gamma, 0])
    cx = cos_beta
    cy = (cos_alpha - cos_beta * cos_gamma) / sin_gamma
    cz_sqr = 1.0 - cx * cx - cy * cy
    cz = jnp.sqrt(jnp.maximum(cz_sqr, 0.0))
    vc = c * jnp.array([cx, cy, cz])

    return jnp.vstack((va, vb, vc))