fennol.utils
from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
from .atomic_units import AtomicUnits, UnitSystem
from typing import Dict, Any, Sequence, Union
import jax
import jax.numpy as jnp
import numpy as np
from ase.geometry.cell import cellpar_to_cell
import numba


def minmaxone(x, name=""):
    """Print ``name``, the min, the max and the root-mean-square of ``x``."""
    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)


def minmaxone_jax(x, name=""):
    """jit-compatible counterpart of ``minmaxone`` (uses ``jax.debug.print``).

    Prints ``name``, the min, the max and the root-mean-square of ``x``.
    """
    jax.debug.print(
        "{name} {min} {max} {mean}",
        name=name,
        min=x.min(),
        max=x.max(),
        # root-mean-square, consistent with minmaxone
        mean=(x**2).mean() ** 0.5,
    )


def cell_vectors_to_lengths_angles(cell):
    """Convert a cell matrix into lattice parameters.

    Args:
        cell: array-like with 9 values, reshaped to (3, 3); row ``k`` is the
            ``k``-th lattice vector (a, b, c).

    Returns:
        np.ndarray ``[a, b, c, alpha, beta, gamma]`` with the vector lengths
        and the angles in degrees: alpha = angle(b, c), beta = angle(a, c),
        gamma = angle(a, b). Floating-point input keeps its dtype; any other
        dtype is promoted to float64 so lengths/angles are not truncated.
    """
    cell = np.asarray(cell)
    if not np.issubdtype(cell.dtype, np.floating):
        cell = cell.astype(float)
    cell = cell.reshape(3, 3)
    a = np.linalg.norm(cell[0])
    b = np.linalg.norm(cell[1])
    c = np.linalg.norm(cell[2])
    degree = 180.0 / np.pi

    def _angle(u, v, norm_u, norm_v):
        # Rounding can push |cos| marginally above 1, where arccos gives NaN.
        cosine = np.clip(np.dot(u, v) / (norm_u * norm_v), -1.0, 1.0)
        return np.arccos(cosine) * degree

    alpha = _angle(cell[1], cell[2], b, c)
    beta = _angle(cell[0], cell[2], a, c)
    gamma = _angle(cell[0], cell[1], a, b)
    return np.array([a, b, c, alpha, beta, gamma], dtype=cell.dtype)


def cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
    """Build a cell matrix from lattice parameters via ASE ``cellpar_to_cell``."""
    return cellpar_to_cell(lengths_angles, ab_normal=ab_normal, a_direction=a_direction)


def parse_cell(cell):
    """Normalize a cell specification to a (3, 3) float matrix.

    Accepts ``None`` (returned unchanged), 9 values (reshaped to 3x3) or
    1/3/6 lattice parameters (forwarded to ``cell_lengths_angles_to_vectors``).
    """
    if cell is None:
        return None
    cell = np.asarray(cell, dtype=float).flatten()
    assert cell.size in [1, 3, 6, 9], "Cell must be of size 1, 3, 6 or 9"
    if cell.size == 9:
        return cell.reshape(3, 3)

    return cell_lengths_angles_to_vectors(cell)


def cell_is_triangular(cell, tol=1e-5):
    """Return True if every entry above the diagonal of ``cell`` is below ``tol`` in magnitude."""
    if cell is None:
        return False
    cell = np.asarray(cell, dtype=float).reshape(3, 3)
    return np.all(np.abs(cell - np.tril(cell)) < tol)


def tril_cell(cell, reciprocal_cell=None):
    """Rebuild ``cell`` from its lattice parameters and return ``(cell_tril, rotation)``.

    ``rotation = reciprocal_cell @ cell_tril`` (``reciprocal_cell`` defaults
    to ``inv(cell)``). Returns ``None`` when ``cell`` is None.
    """
    if cell is None:
        return None
    cell = np.asarray(cell, dtype=float).reshape(3, 3)
    if reciprocal_cell is None:
        reciprocal_cell = np.linalg.inv(cell)
    length_angles = cell_vectors_to_lengths_angles(cell)
    cell_tril = cell_lengths_angles_to_vectors(length_angles)
    rotation = reciprocal_cell @ cell_tril
    return cell_tril, rotation


def mask_filter_1d(mask, max_size, *values_fill):
    """Compact the ``mask``-selected entries of each value into buffers of length ``max_size``.

    Returns ``(outputs, scatter_idx, count)``; masked-out entries are scattered
    to index ``max_size`` and dropped.
    """
    cumsum = jnp.cumsum(mask, dtype=jnp.int32)
    scatter_idx = jnp.where(mask, cumsum - 1, max_size)
    outputs = []
    for value, fill in values_fill:
        shape = list(value.shape)
        shape[0] = max_size
        output = (
            jnp.full(shape, fill, dtype=value.dtype)
            .at[scatter_idx]
            .set(value, mode="drop")
        )
        outputs.append(output)
    if cumsum.size == 0:
        return outputs, scatter_idx, 0
    return outputs, scatter_idx, cumsum[-1]


def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively updated with each of ``updating_mappings``."""
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping


class Counter:
    """Inner counter cycling over ``nseg`` steps plus an outer counter of completed cycles."""

    def __init__(self, nseg, startsave=1):
        self.i = 0
        self.i_avg = 0
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        return self.i

    @property
    def count_avg(self):
        return self.i_avg

    @property
    def nsample(self):
        return max(self.count_avg - self.startsave + 1, 1)

    @property
    def is_reset_step(self):
        return self.count == 0

    def reset_avg(self):
        self.i_avg = 0

    def reset_all(self):
        self.i = 0
        self.i_avg = 0

    def increment(self):
        self.i = self.i + 1
        if self.i >= self.nseg:
            self.i = 0
            self.i_avg = self.i_avg + 1


### TOPLOGY DETECTION
@numba.njit
def _detect_bonds_pbc(radii, coordinates, cell):
    """Candidate bonds (i < j) using minimum-image distances in a periodic cell.

    A pair is kept when ``0.4 < dist < radii[i] + radii[j] + 0.4``.
    """
    reciprocal_cell = np.linalg.inv(cell).T
    cell = cell.T
    nat = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(nat):
        for j in range(i + 1, nat):
            vec = coordinates[i] - coordinates[j]
            # fractional displacement wrapped to the nearest image
            vecpbc = reciprocal_cell @ vec
            vecpbc -= np.round(vecpbc)
            vec = cell @ vecpbc
            dist = np.linalg.norm(vec)
            if dist < radii[i] + radii[j] + 0.4 and dist > 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1, bond2, distances


@numba.njit
def _detect_bonds(radii, coordinates):
    """Candidate bonds (i < j) with ``0.4 < dist < radii[i] + radii[j] + 0.4`` (no PBC)."""
    nat = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(nat):
        for j in range(i + 1, nat):
            vec = coordinates[i] - coordinates[j]
            dist = np.linalg.norm(vec)
            if dist < radii[i] + radii[j] + 0.4 and dist > 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1, bond2, distances


def detect_topology(species, coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp

    ``cell`` may be any specification accepted by ``parse_cell``; when given,
    minimum-image distances are used. Coordinates are compared against
    covalent radii scaled by ``AtomicUnits.ANG`` (assumed to share units with
    ``coordinates`` — TODO confirm).
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION

    radii = np.asarray((COV_RADII * AtomicUnits.ANG)[species], dtype=np.float64)
    max_coord = UFF_MAX_COORDINATION[species]
    # numba kernels need plain float64 arrays (no lists or mixed dtypes)
    coordinates = np.asarray(coordinates, dtype=np.float64)
    cell = parse_cell(cell)

    if cell is not None:
        bond1, bond2, distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1, bond2, distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    # Visit bonds by smallest radius first, then longest relative length first,
    # dropping those that touch an over-coordinated atom.
    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    sorted_indices = np.lexsort((-distances / req, rminbonds))
    bonds = bonds[sorted_indices, :]

    true_bonds = []
    for i, j in bonds:
        if coord[i] <= max_coord[i] and coord[j] <= max_coord[j]:
            true_bonds.append((i, j))
        else:
            coord[i] -= 1
            coord[j] -= 1

    # reshape keeps the (0, 2) shape when every candidate bond was rejected
    true_bonds = np.array(true_bonds, dtype=np.int32).reshape(-1, 2)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    return true_bonds[sorted_indices, :]


def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

    ``energy_function`` receives the full data dict (with differentiated
    entries substituted) and must return ``(energy, out)``; ``energy`` is
    summed before differentiation. The returned ``energy_gradient(data)``
    gives ``(de, out)`` where ``de[k]`` is dE/d(data[k]) and ``out`` is the
    auxiliary output extended with ``"dEd_" + k`` entries.
    """

    def energy_gradient(data):
        def _etot(inputs):
            # ``inputs`` holds the differentiated entries; ``data`` the rest.
            if "strain" in inputs:
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                coordinates = jax.vmap(jnp.matmul)(coordinates, scaling[batch_index])
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    inputs = {**inputs, "cells": jax.vmap(jnp.matmul)(cells, scaling)}
            if "cells" in inputs:
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            energy, out = energy_function({**data, **inputs})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            nbatch = data["natoms"].shape[0]
            data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(nbatch, axis=0))}
        inputs = {k: data[k] for k in gradient_keys}
        de, out = jax.grad(_etot, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    return energy_gradient


def read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> np.ndarray:
    """Parse a Tinker-style list of 1-based indices into sorted unique 0-based indices.

    A positive entry selects one index; a negative entry ``-s`` followed by
    ``e`` selects the inclusive range ``s..e``. Raises ValueError on syntax errors.
    """
    values = iter([int(i) for i in indices_interval])
    indices = []
    for i in values:
        if i > 0:
            indices.append(i)
        elif i < 0:
            start = -i
            end = next(values, None)
            if end is None:
                raise ValueError(
                    f"Syntax error in ligand indices. Range starting at {i} has no end index."
                )
            if end <= start:
                raise ValueError(
                    "Syntax error in ligand indices. End index must be greater than start index."
                )
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indicing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing
def
minmaxone(x, name=''):
def
minmaxone_jax(x, name=''):
def
cell_vectors_to_lengths_angles(cell):
def cell_vectors_to_lengths_angles(cell):
    """Convert a cell matrix into lattice parameters.

    Args:
        cell: array-like with 9 values, reshaped to (3, 3); row ``k`` is the
            ``k``-th lattice vector (a, b, c).

    Returns:
        np.ndarray ``[a, b, c, alpha, beta, gamma]`` with the vector lengths
        and the angles in degrees: alpha = angle(b, c), beta = angle(a, c),
        gamma = angle(a, b). Floating-point input keeps its dtype; any other
        dtype is promoted to float64 so lengths/angles are not truncated.
    """
    cell = np.asarray(cell)
    if not np.issubdtype(cell.dtype, np.floating):
        cell = cell.astype(float)
    cell = cell.reshape(3, 3)
    a = np.linalg.norm(cell[0])
    b = np.linalg.norm(cell[1])
    c = np.linalg.norm(cell[2])
    degree = 180.0 / np.pi

    def _angle(u, v, norm_u, norm_v):
        # Rounding can push |cos| marginally above 1, where arccos gives NaN.
        cosine = np.clip(np.dot(u, v) / (norm_u * norm_v), -1.0, 1.0)
        return np.arccos(cosine) * degree

    alpha = _angle(cell[1], cell[2], b, c)
    beta = _angle(cell[0], cell[2], a, c)
    gamma = _angle(cell[0], cell[1], a, b)
    return np.array([a, b, c, alpha, beta, gamma], dtype=cell.dtype)
def
cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
def
parse_cell(cell):
def
cell_is_triangular(cell, tol=1e-05):
def
tril_cell(cell, reciprocal_cell=None):
def tril_cell(cell, reciprocal_cell=None):
    """Rebuild ``cell`` from its own lattice parameters and relate the two frames.

    The rebuilt cell (``cell_tril``) comes from
    ``cell_lengths_angles_to_vectors`` applied to the lengths and angles of
    ``cell``. It is expected to be lower-triangular, as the name suggests.

    Args:
        cell: array-like with 9 values (reshaped to 3x3), or ``None``.
        reciprocal_cell: optional precomputed inverse of ``cell``; computed
            with ``np.linalg.inv`` when omitted.

    Returns:
        ``None`` if ``cell`` is ``None``, otherwise ``(cell_tril, rotation)``
        with ``rotation = reciprocal_cell @ cell_tril``.
    """
    if cell is None:
        return None
    matrix = np.asarray(cell, dtype=float).reshape(3, 3)
    inverse = np.linalg.inv(matrix) if reciprocal_cell is None else reciprocal_cell
    triangular = cell_lengths_angles_to_vectors(cell_vectors_to_lengths_angles(matrix))
    return triangular, inverse @ triangular
def
mask_filter_1d(mask, max_size, *values_fill):
def mask_filter_1d(mask, max_size, *values_fill):
    """Pack the entries selected by ``mask`` into the front of fixed-size buffers.

    Args:
        mask: boolean array (flattened by ``jnp.cumsum``), presumably 1D as the
            name says.
        max_size: leading dimension of every output buffer.
        *values_fill: ``(value, fill)`` pairs. Each ``value`` is scattered along
            its first axis, and ``fill`` initialises the unused slots.

    Returns:
        ``(outputs, scatter_idx, count)``:
        - ``outputs``: the list of packed buffers.
        - ``scatter_idx``: the destination slot of each input entry. It is
          ``max_size`` for unselected entries. That index and any slot
          ``>= max_size`` are dropped.
        - ``count``: the running selection count at the last entry, or ``0``
          for empty input.
    """
    rank = jnp.cumsum(mask, dtype=jnp.int32)
    destination = jnp.where(mask, rank - 1, max_size)

    def _scatter(value, fill):
        # same trailing shape as ``value``, leading axis resized to max_size
        buffer_shape = list(value.shape)
        buffer_shape[0] = max_size
        blank = jnp.full(buffer_shape, fill, dtype=value.dtype)
        return blank.at[destination].set(value, mode="drop")

    packed = [_scatter(value, fill) for value, fill in values_fill]
    selected = 0 if rank.size == 0 else rank[-1]
    return packed, destination, selected
def
deep_update( mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]) -> Dict[Any, Any]:
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Merge ``updating_mappings`` into a copy of ``mapping``, recursing into nested dicts.

    ``mapping`` itself is not modified. For each key, when both the current
    value and the new value are dicts, they are merged recursively. Otherwise
    the new value replaces the old one. Later mappings take precedence.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            current = merged.get(key)
            both_dicts = isinstance(current, dict) and isinstance(new_value, dict)
            merged[key] = deep_update(current, new_value) if both_dicts else new_value
    return merged
class
Counter:
class Counter:
    """Two-level step counter.

    The inner counter (``count``) runs from 0 to ``nseg - 1``. Each time it
    wraps back to 0, the outer counter (``count_avg``) is incremented.

    Args:
        nseg: number of inner steps per outer step.
        startsave: outer step from which ``nsample`` starts counting.
    """

    def __init__(self, nseg, startsave=1):
        self.i = 0  # inner counter
        self.i_avg = 0  # outer counter, bumped whenever ``i`` wraps
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        """Current value of the inner counter."""
        return self.i

    @property
    def count_avg(self):
        """Current value of the outer counter."""
        return self.i_avg

    @property
    def nsample(self):
        """Outer steps counted from ``startsave`` onward, never less than 1."""
        counted = self.count_avg - self.startsave + 1
        return max(counted, 1)

    @property
    def is_reset_step(self):
        """True when the inner counter sits at 0."""
        return self.count == 0

    def reset_avg(self):
        """Zero the outer counter only."""
        self.i_avg = 0

    def reset_all(self):
        """Zero both counters."""
        self.i = self.i_avg = 0

    def increment(self):
        """Advance the inner counter, rolling over into the outer one after ``nseg`` steps."""
        self.i += 1
        if self.i < self.nseg:
            return
        self.i = 0
        self.i_avg += 1
def
detect_topology(species, coordinates, cell=None):
def detect_topology(species, coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp

    Args:
        species: per-atom indices into the periodic-table arrays.
        coordinates: per-atom positions, converted to float64 for the numba
            kernels. Compared against covalent radii scaled by
            ``AtomicUnits.ANG``; assumed to be in matching units (TODO confirm).
        cell: optional cell in any form accepted by ``parse_cell`` (9 values,
            or 1/3/6 lattice parameters). When given, minimum-image distances
            are used.

    Returns:
        int32 array of shape (nbonds, 2), rows sorted lexicographically.
        It has shape (0, 2) when no bond survives.
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION

    radii = np.asarray((COV_RADII * AtomicUnits.ANG)[species], dtype=np.float64)
    max_coord = UFF_MAX_COORDINATION[species]
    # numba kernels need plain float64 arrays (no lists or mixed dtypes)
    coordinates = np.asarray(coordinates, dtype=np.float64)
    cell = parse_cell(cell)

    if cell is not None:
        bond1, bond2, distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1, bond2, distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    # Sort bonds by smallest radius first, then by longest relative length
    # (dist / sum of radii) first. Any bond touching an atom that is still
    # over-coordinated is dropped.
    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    sorted_indices = np.lexsort((-distances / req, rminbonds))
    bonds = bonds[sorted_indices, :]

    true_bonds = []
    for i, j in bonds:
        if coord[i] <= max_coord[i] and coord[j] <= max_coord[j]:
            true_bonds.append((i, j))
        else:
            coord[i] -= 1
            coord[j] -= 1

    # reshape keeps the (0, 2) shape when every candidate bond was rejected
    true_bonds = np.array(true_bonds, dtype=np.int32).reshape(-1, 2)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    return true_bonds[sorted_indices, :]
Detects the topology of a system based on species and coordinates. Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond. Inspired by OpenBabel's ConnectTheDots in mol.cpp
def
get_energy_gradient_function(energy_function, gradient_keys: Sequence[str], jit: bool = True):
def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

    Args:
        energy_function: callable taking the data dict and returning
            ``(energy, out)``. It receives the full ``data`` with the
            differentiated entries substituted. ``energy`` is summed before
            differentiation.
        gradient_keys: keys of ``data`` to differentiate with respect to. If
            ``"strain"`` is requested and absent from ``data``, identity strain
            matrices are created, one per entry of ``data["natoms"]``. They are
            applied to ``coordinates`` (per atom via ``data["batch_index"]``)
            and to ``cells`` when present.
        jit: wrap the returned function with ``jax.jit``.

    Returns:
        ``energy_gradient(data) -> (de, out)``. ``de[k]`` is dE/d(data[k]).
        ``out`` is the auxiliary output of ``energy_function``, extended with
        ``"dEd_" + k`` entries.
    """

    def energy_gradient(data):
        def _etot(inputs):
            # ``inputs`` holds the differentiated entries; ``data`` (closed over)
            # provides everything else.
            if "strain" in inputs:
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                coordinates = jax.vmap(jnp.matmul)(coordinates, scaling[batch_index])
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    inputs = {**inputs, "cells": jax.vmap(jnp.matmul)(cells, scaling)}
            if "cells" in inputs:
                # keep reciprocal cells consistent with the (strained) cells
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            energy, out = energy_function({**data, **inputs})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            nbatch = data["natoms"].shape[0]
            data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(nbatch, axis=0))}
        inputs = {k: data[k] for k in gradient_keys}
        # _etot takes a single positional argument, so differentiate argnum 0
        de, out = jax.grad(_etot, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    return energy_gradient
Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys
def
read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> numpy.ndarray:
def read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> np.ndarray:
    """Parse a Tinker-style list of 1-based indices into 0-based indices.

    Every entry is first converted with ``int``. A positive entry selects that
    single index. A negative entry ``-s`` opens an inclusive range closed by
    the next entry ``e``, so ``-s e`` selects ``s, s+1, ..., e``.

    Returns:
        Sorted, de-duplicated int32 array of zero-based indices.

    Raises:
        ValueError: on an entry equal to 0, on a range without an end index,
            or on a range whose end is not greater than its start.
    """
    values = iter([int(i) for i in indices_interval])
    indices = []
    for i in values:
        if i > 0:
            indices.append(i)
        elif i < 0:
            start = -i
            # consume the entry that closes the range
            end = next(values, None)
            if end is None:
                raise ValueError(
                    f"Syntax error in ligand indices. Range starting at {i} has no end index."
                )
            if end <= start:
                raise ValueError(
                    "Syntax error in ligand indices. End index must be greater than start index."
                )
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indicing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing