fennol.utils

  1from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
  2from .atomic_units import AtomicUnits, UnitSystem
  3from typing import Dict, Any,Sequence, Union
  4import jax
  5import jax.numpy as jnp
  6import numpy as np
  7from ase.geometry.cell import cellpar_to_cell
  8import numba
  9
 10def minmaxone(x, name=""):
 11    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)
 12
 13def minmaxone_jax(x, name=""):
 14    jax.debug.print(
 15        "{name}  {min}  {max}  {mean}",
 16        name=name,
 17        min=x.min(),
 18        max=x.max(),
 19        mean=(x**2).mean(),
 20    )
 21
 22def cell_vectors_to_lengths_angles(cell):
 23    cell = cell.reshape(3, 3)
 24    a = np.linalg.norm(cell[0])
 25    b = np.linalg.norm(cell[1])
 26    c = np.linalg.norm(cell[2])
 27    degree = 180.0 / np.pi
 28    alpha = np.arccos(np.dot(cell[1], cell[2]) / (b * c))
 29    beta = np.arccos(np.dot(cell[0], cell[2]) / (a * c))
 30    gamma = np.arccos(np.dot(cell[0], cell[1]) / (a * b))
 31    return np.array([a, b, c, alpha*degree, beta*degree, gamma*degree], dtype=cell.dtype)
 32
 33def cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
 34    return cellpar_to_cell(lengths_angles, ab_normal=ab_normal, a_direction=a_direction)
 35
 36def parse_cell(cell):
 37    if cell is None:
 38        return None
 39    cell = np.asarray(cell, dtype=float).flatten()
 40    assert cell.size in [1, 3, 6, 9], "Cell must be of size 1, 3, 6 or 9"
 41    if cell.size == 9:
 42        return cell.reshape(3, 3)
 43    
 44    return cell_lengths_angles_to_vectors(cell)
 45
 46def cell_is_triangular(cell, tol=1e-5):
 47    if cell is None:
 48        return False
 49    cell = np.asarray(cell, dtype=float).reshape(3, 3)
 50    return np.all(np.abs(cell - np.tril(cell)) < tol)
 51
 52def tril_cell(cell,reciprocal_cell=None):
 53    if cell is None:
 54        return None
 55    cell = np.asarray(cell, dtype=float).reshape(3, 3)
 56    if reciprocal_cell is None:
 57        reciprocal_cell = np.linalg.inv(cell)
 58    length_angles = cell_vectors_to_lengths_angles(cell)
 59    cell_tril = cell_lengths_angles_to_vectors(length_angles)
 60    rotation = reciprocal_cell @ cell_tril
 61    return cell_tril, rotation
 62
 63
 64def mask_filter_1d(mask, max_size, *values_fill):
 65    cumsum = jnp.cumsum(mask,dtype=jnp.int32)
 66    scatter_idx = jnp.where(mask, cumsum - 1, max_size)
 67    outputs = []
 68    for value, fill in values_fill:
 69        shape = list(value.shape)
 70        shape[0] = max_size
 71        output = (
 72            jnp.full(shape, fill, dtype=value.dtype)
 73            .at[scatter_idx]
 74            .set(value, mode="drop")
 75        )
 76        outputs.append(output)
 77    if cumsum.size == 0:
 78        return outputs, scatter_idx, 0
 79    return outputs, scatter_idx, cumsum[-1]
 80
 81
 82def deep_update(
 83    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
 84) -> Dict[Any, Any]:
 85    updated_mapping = mapping.copy()
 86    for updating_mapping in updating_mappings:
 87        for k, v in updating_mapping.items():
 88            if (
 89                k in updated_mapping
 90                and isinstance(updated_mapping[k], dict)
 91                and isinstance(v, dict)
 92            ):
 93                updated_mapping[k] = deep_update(updated_mapping[k], v)
 94            else:
 95                updated_mapping[k] = v
 96    return updated_mapping
 97
 98
 99class Counter:
100    def __init__(self, nseg, startsave=1):
101        self.i = 0
102        self.i_avg = 0
103        self.nseg = nseg
104        self.startsave = startsave
105
106    @property
107    def count(self):
108        return self.i
109
110    @property
111    def count_avg(self):
112        return self.i_avg
113
114    @property
115    def nsample(self):
116        return max(self.count_avg - self.startsave + 1, 1)
117
118    @property
119    def is_reset_step(self):
120        return self.count == 0
121
122    def reset_avg(self):
123        self.i_avg = 0
124
125    def reset_all(self):
126        self.i = 0
127        self.i_avg = 0
128
129    def increment(self):
130        self.i = self.i + 1
131        if self.i >= self.nseg:
132            self.i = 0
133            self.i_avg = self.i_avg + 1
134
135### TOPOLOGY DETECTION
@numba.njit
def _detect_bonds_pbc(radii,coordinates,cell):
    """Find candidate bonds between all atom pairs under periodic boundaries.

    Each pair displacement is converted to fractional coordinates with the
    inverse cell (rows of `cell` act as the lattice vectors), wrapped by
    subtracting its rounded value, and converted back before measuring the
    distance. A pair (i < j) is kept when
    ``0.4 < distance < radii[i] + radii[j] + 0.4``.

    Returns three parallel lists: first indices, second indices, distances.
    """
    inv_cell_t = np.linalg.inv(cell).T
    cell_t = cell.T
    n_atoms = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            delta = coordinates[i] - coordinates[j]
            frac = inv_cell_t @ delta
            # Wrap each fractional component by its nearest integer.
            frac -= np.round(frac)
            delta = cell_t @ frac
            dist = np.linalg.norm(delta)
            if dist > 0.4 and dist < radii[i] + radii[j] + 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1,bond2, distances
156
@numba.njit
def _detect_bonds(radii,coordinates):
    """Find candidate bonds between all atom pairs without periodic boundaries.

    A pair (i < j) is kept when its distance satisfies
    ``0.4 < distance < radii[i] + radii[j] + 0.4``.

    Returns three parallel lists: first indices, second indices, distances.
    """
    n_atoms = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            dist = np.linalg.norm(coordinates[i] - coordinates[j])
            if dist > 0.4 and dist < radii[i] + radii[j] + 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1,bond2, distances
172
def detect_topology(species,coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp

    Candidate bonds come from `_detect_bonds` (or `_detect_bonds_pbc` when
    `cell` is given). If any atom then exceeds its entry in
    `UFF_MAX_COORDINATION`, candidates are visited in a fixed order and
    dropped while one of their atoms is still over-coordinated. The result
    has shape [0,2] when no bond is found or none survives.
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION
    radii = (COV_RADII* AtomicUnits.ANG)[species]
    max_coord = UFF_MAX_COORDINATION[species]

    if cell is not None:
        bond1,bond2,distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1,bond2,distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    # Coordination number of every atom from the candidate bonds.
    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    # lexsort uses the LAST key as primary: bonds are ordered by the smaller
    # radius of the pair, then by decreasing distance / (sum of radii).
    visit_order = np.lexsort((-distances/req, rminbonds))

    true_bonds = []
    for i, j in bonds[visit_order]:
        if coord[i] <= max_coord[i] and coord[j] <= max_coord[j]:
            true_bonds.append((i, j))
        else:
            # Drop the bond and release one coordination slot on both atoms.
            coord[i] -= 1
            coord[j] -= 1

    # reshape keeps a 2-D [0,2] array when every bond was dropped; without it
    # the column indexing below raises IndexError on the empty 1-D array.
    true_bonds = np.array(true_bonds, dtype=np.int32).reshape(-1, 2)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    true_bonds = true_bonds[sorted_indices, :]

    return true_bonds
224
def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function computing the gradient of the total energy with respect to the keys in gradient_keys.

    Args:
        energy_function: callable taking a dict of inputs and returning
            ``(energy, out)``; `energy` is summed to a scalar before
            differentiation and `out` must be a dict.
        gradient_keys: keys of the data dict to differentiate with respect to.
            If ``"strain"`` is listed and absent from the data, an identity
            strain (one 3x3 matrix per entry of ``data["natoms"]``) is used.
        jit: wrap the returned function in `jax.jit`.

    Returns:
        A function ``energy_gradient(data) -> (de, out)`` where `de` maps each
        gradient key to its gradient and `out` is the auxiliary output of
        `energy_function` extended with ``"dEd_<key>"`` entries.
    """

    def energy_gradient(data):
        def _etot(inputs):
            if "strain" in inputs:
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                # Apply to each atom the strain matrix of its own system.
                coordinates = jax.vmap(jnp.matmul)(
                    coordinates, scaling[batch_index]
                )
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    cells = jax.vmap(jnp.matmul)(cells, scaling)
                    inputs["cells"] = cells
            if "cells" in inputs:
                # Differentiated or strained cells need a matching inverse.
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            # Non-differentiated entries of `data` must reach the energy
            # function too; differentiated/derived entries take precedence.
            energy, out = energy_function({**data, **inputs})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
        inputs = {k: data[k] for k in gradient_keys}
        # `_etot` has a single positional argument, so differentiate argument 0
        # (argnums=1 made every call fail).
        de, out = jax.grad(_etot, argnums=0, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    else:
        return energy_gradient
266
267
def read_tinker_interval(indices_interval: Sequence[Union[int,str]]) -> np.ndarray:
    """Expand a Tinker-style index list into sorted, unique, zero-based indices.

    Entries are converted with `int`. A positive entry is a single 1-based
    index. A negative entry ``-start`` must be followed by an entry ``end``
    and stands for the inclusive range ``start..end``.

    Args:
        indices_interval: sequence of integers or integer strings; it is not
            modified.

    Returns:
        int32 array of the selected indices, shifted to zero-based.

    Raises:
        ValueError: on a zero entry, a range start without an end, or an end
            that is not greater than its start.
    """
    tokens = iter([int(i) for i in indices_interval])
    indices = []
    for i in tokens:
        if i > 0:
            indices.append(i)
        elif i < 0:
            start = -i
            # Consume the range end; a trailing start used to surface as a
            # bare IndexError from list.pop.
            end = next(tokens, None)
            if end is None:
                raise ValueError("Syntax error in ligand indices. A negative start index must be followed by an end index.")
            if end <= start:
                raise ValueError("Syntax error in ligand indices. End index must be greater than start index.")
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indicing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing
def minmaxone(x, name=''):
11def minmaxone(x, name=""):
12    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)
def minmaxone_jax(x, name=''):
14def minmaxone_jax(x, name=""):
15    jax.debug.print(
16        "{name}  {min}  {max}  {mean}",
17        name=name,
18        min=x.min(),
19        max=x.max(),
20        mean=(x**2).mean(),
21    )
def cell_vectors_to_lengths_angles(cell):
def cell_vectors_to_lengths_angles(cell):
    """Convert cell vectors to ``[a, b, c, alpha, beta, gamma]``.

    Args:
        cell: array-like with nine values, reshaped to (3, 3); each row is
            treated as one cell vector.

    Returns:
        np.ndarray of six values: the norms of rows 0, 1, 2, followed by the
        angles in degrees between rows (1, 2), (0, 2) and (0, 1).
        The result keeps a floating input dtype; any other input dtype
        (e.g. integers) yields float64 so the values are not truncated.
    """
    cell = np.asarray(cell).reshape(3, 3)
    a, b, c = np.linalg.norm(cell, axis=1)
    cosines = [
        np.dot(cell[1], cell[2]) / (b * c),
        np.dot(cell[0], cell[2]) / (a * c),
        np.dot(cell[0], cell[1]) / (a * b),
    ]
    # Rounding can push a cosine marginally outside [-1, 1] (e.g. for
    # parallel vectors), which would make arccos return NaN.
    angles = np.degrees(np.arccos(np.clip(cosines, -1.0, 1.0)))
    if np.issubdtype(cell.dtype, np.floating):
        out_dtype = cell.dtype
    else:
        out_dtype = np.float64
    return np.array([a, b, c, *angles], dtype=out_dtype)
def cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
34def cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
35    return cellpar_to_cell(lengths_angles, ab_normal=ab_normal, a_direction=a_direction)
def parse_cell(cell):
def parse_cell(cell):
    """Normalize a cell specification into cell vectors.

    Args:
        cell: ``None``, or an array-like holding 1, 3, 6 or 9 numbers.
            Nine numbers are taken as the flattened (3, 3) cell matrix.
            Any other accepted size is passed as cell parameters to
            `cell_lengths_angles_to_vectors` (presumably lengths only for
            sizes 1 and 3, lengths plus angles for size 6 - see ASE's
            `cellpar_to_cell`).

    Returns:
        ``None`` if `cell` is ``None``, otherwise the cell vectors.

    Raises:
        ValueError: if `cell` does not hold 1, 3, 6 or 9 numbers.
    """
    if cell is None:
        return None
    cell = np.asarray(cell, dtype=float).flatten()
    if cell.size not in (1, 3, 6, 9):
        # Explicit raise: an `assert` would be stripped under `python -O`
        # and let malformed input through.
        raise ValueError("Cell must be of size 1, 3, 6 or 9")
    if cell.size == 9:
        return cell.reshape(3, 3)

    return cell_lengths_angles_to_vectors(cell)
def cell_is_triangular(cell, tol=1e-05):
47def cell_is_triangular(cell, tol=1e-5):
48    if cell is None:
49        return False
50    cell = np.asarray(cell, dtype=float).reshape(3, 3)
51    return np.all(np.abs(cell - np.tril(cell)) < tol)
def tril_cell(cell, reciprocal_cell=None):
53def tril_cell(cell,reciprocal_cell=None):
54    if cell is None:
55        return None
56    cell = np.asarray(cell, dtype=float).reshape(3, 3)
57    if reciprocal_cell is None:
58        reciprocal_cell = np.linalg.inv(cell)
59    length_angles = cell_vectors_to_lengths_angles(cell)
60    cell_tril = cell_lengths_angles_to_vectors(length_angles)
61    rotation = reciprocal_cell @ cell_tril
62    return cell_tril, rotation
def mask_filter_1d(mask, max_size, *values_fill):
65def mask_filter_1d(mask, max_size, *values_fill):
66    cumsum = jnp.cumsum(mask,dtype=jnp.int32)
67    scatter_idx = jnp.where(mask, cumsum - 1, max_size)
68    outputs = []
69    for value, fill in values_fill:
70        shape = list(value.shape)
71        shape[0] = max_size
72        output = (
73            jnp.full(shape, fill, dtype=value.dtype)
74            .at[scatter_idx]
75            .set(value, mode="drop")
76        )
77        outputs.append(output)
78    if cumsum.size == 0:
79        return outputs, scatter_idx, 0
80    return outputs, scatter_idx, cumsum[-1]
def deep_update( mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]) -> Dict[Any, Any]:
83def deep_update(
84    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
85) -> Dict[Any, Any]:
86    updated_mapping = mapping.copy()
87    for updating_mapping in updating_mappings:
88        for k, v in updating_mapping.items():
89            if (
90                k in updated_mapping
91                and isinstance(updated_mapping[k], dict)
92                and isinstance(v, dict)
93            ):
94                updated_mapping[k] = deep_update(updated_mapping[k], v)
95            else:
96                updated_mapping[k] = v
97    return updated_mapping
class Counter:
100class Counter:
101    def __init__(self, nseg, startsave=1):
102        self.i = 0
103        self.i_avg = 0
104        self.nseg = nseg
105        self.startsave = startsave
106
107    @property
108    def count(self):
109        return self.i
110
111    @property
112    def count_avg(self):
113        return self.i_avg
114
115    @property
116    def nsample(self):
117        return max(self.count_avg - self.startsave + 1, 1)
118
119    @property
120    def is_reset_step(self):
121        return self.count == 0
122
123    def reset_avg(self):
124        self.i_avg = 0
125
126    def reset_all(self):
127        self.i = 0
128        self.i_avg = 0
129
130    def increment(self):
131        self.i = self.i + 1
132        if self.i >= self.nseg:
133            self.i = 0
134            self.i_avg = self.i_avg + 1
Counter(nseg, startsave=1)
101    def __init__(self, nseg, startsave=1):
102        self.i = 0
103        self.i_avg = 0
104        self.nseg = nseg
105        self.startsave = startsave
i
i_avg
nseg
startsave
count
107    @property
108    def count(self):
109        return self.i
count_avg
111    @property
112    def count_avg(self):
113        return self.i_avg
nsample
115    @property
116    def nsample(self):
117        return max(self.count_avg - self.startsave + 1, 1)
is_reset_step
119    @property
120    def is_reset_step(self):
121        return self.count == 0
def reset_avg(self):
123    def reset_avg(self):
124        self.i_avg = 0
def reset_all(self):
126    def reset_all(self):
127        self.i = 0
128        self.i_avg = 0
def increment(self):
130    def increment(self):
131        self.i = self.i + 1
132        if self.i >= self.nseg:
133            self.i = 0
134            self.i_avg = self.i_avg + 1
def detect_topology(species, coordinates, cell=None):
def detect_topology(species,coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp

    Candidate bonds come from `_detect_bonds` (or `_detect_bonds_pbc` when
    `cell` is given). If any atom then exceeds its entry in
    `UFF_MAX_COORDINATION`, candidates are visited in a fixed order and
    dropped while one of their atoms is still over-coordinated. The result
    has shape [0,2] when no bond is found or none survives.
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION
    radii = (COV_RADII* AtomicUnits.ANG)[species]
    max_coord = UFF_MAX_COORDINATION[species]

    if cell is not None:
        bond1,bond2,distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1,bond2,distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    # Coordination number of every atom from the candidate bonds.
    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    # lexsort uses the LAST key as primary: bonds are ordered by the smaller
    # radius of the pair, then by decreasing distance / (sum of radii).
    visit_order = np.lexsort((-distances/req, rminbonds))

    true_bonds = []
    for i, j in bonds[visit_order]:
        if coord[i] <= max_coord[i] and coord[j] <= max_coord[j]:
            true_bonds.append((i, j))
        else:
            # Drop the bond and release one coordination slot on both atoms.
            coord[i] -= 1
            coord[j] -= 1

    # reshape keeps a 2-D [0,2] array when every bond was dropped; without it
    # the column indexing below raises IndexError on the empty 1-D array.
    true_bonds = np.array(true_bonds, dtype=np.int32).reshape(-1, 2)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    true_bonds = true_bonds[sorted_indices, :]

    return true_bonds

Detects the topology of a system based on species and coordinates. Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond. Inspired by OpenBabel's ConnectTheDots in mol.cpp

def get_energy_gradient_function(energy_function, gradient_keys: Sequence[str], jit: bool = True):
def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function computing the gradient of the total energy with respect to the keys in gradient_keys.

    Args:
        energy_function: callable taking a dict of inputs and returning
            ``(energy, out)``; `energy` is summed to a scalar before
            differentiation and `out` must be a dict.
        gradient_keys: keys of the data dict to differentiate with respect to.
            If ``"strain"`` is listed and absent from the data, an identity
            strain (one 3x3 matrix per entry of ``data["natoms"]``) is used.
        jit: wrap the returned function in `jax.jit`.

    Returns:
        A function ``energy_gradient(data) -> (de, out)`` where `de` maps each
        gradient key to its gradient and `out` is the auxiliary output of
        `energy_function` extended with ``"dEd_<key>"`` entries.
    """

    def energy_gradient(data):
        def _etot(inputs):
            if "strain" in inputs:
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                # Apply to each atom the strain matrix of its own system.
                coordinates = jax.vmap(jnp.matmul)(
                    coordinates, scaling[batch_index]
                )
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    cells = jax.vmap(jnp.matmul)(cells, scaling)
                    inputs["cells"] = cells
            if "cells" in inputs:
                # Differentiated or strained cells need a matching inverse.
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            # Non-differentiated entries of `data` must reach the energy
            # function too; differentiated/derived entries take precedence.
            energy, out = energy_function({**data, **inputs})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
        inputs = {k: data[k] for k in gradient_keys}
        # `_etot` has a single positional argument, so differentiate argument 0
        # (argnums=1 made every call fail).
        de, out = jax.grad(_etot, argnums=0, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    else:
        return energy_gradient

Return a function that computes the gradient of the total energy with respect to the keys in gradient_keys, returned together with the auxiliary outputs of the energy function

def read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> numpy.ndarray:
def read_tinker_interval(indices_interval: Sequence[Union[int,str]]) -> np.ndarray:
    """Expand a Tinker-style index list into sorted, unique, zero-based indices.

    Entries are converted with `int`. A positive entry is a single 1-based
    index. A negative entry ``-start`` must be followed by an entry ``end``
    and stands for the inclusive range ``start..end``.

    Args:
        indices_interval: sequence of integers or integer strings; it is not
            modified.

    Returns:
        int32 array of the selected indices, shifted to zero-based.

    Raises:
        ValueError: on a zero entry, a range start without an end, or an end
            that is not greater than its start.
    """
    tokens = iter([int(i) for i in indices_interval])
    indices = []
    for i in tokens:
        if i > 0:
            indices.append(i)
        elif i < 0:
            start = -i
            # Consume the range end; a trailing start used to surface as a
            # bare IndexError from list.pop.
            end = next(tokens, None)
            if end is None:
                raise ValueError("Syntax error in ligand indices. A negative start index must be followed by an end index.")
            if end <= start:
                raise ValueError("Syntax error in ligand indices. End index must be greater than start index.")
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indicing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing