fennol.utils.spherical_harmonics

  1import numpy as np
  2import jax
  3import math
  4import jax.numpy as jnp
  5import sympy
  6from sympy.printing.pycode import pycode
  7from sympy.physics.wigner import clebsch_gordan
  8from functools import partial
  9import sys
 10
 11def CG_SU2(j1: int, j2: int, j3: int) -> np.array:
 12    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SU(2)`
 13    Returns
 14    -------
 15    `np.Array`
 16        tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)`
 17    """
 18    C = np.zeros((2 * j1 + 1, 2 * j2 + 1, 2 * j3 + 1))
 19    for m1 in range(-j1, j1 + 1):
 20        for m2 in range(-j2, j2 + 1):
 21            for m3 in range(-j3, j3 + 1):
 22                C[m1 + j1, m2 + j2, m3 + j3] = float(
 23                    clebsch_gordan(j1, j2, j3, m1, m2, m3)
 24                )
 25    return C
 26
 27
 28def change_basis_real_to_complex(l: int) -> np.array:
 29    r"""Change of basis matrix from real to complex spherical harmonics
 30    https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
 31    adapted from e3nn.o3._wigner
 32    """
 33    q = np.zeros((2 * l + 1, 2 * l + 1), dtype=complex)
 34    for m in range(-l, 0):
 35        q[l + m, l + abs(m)] = 1 / 2**0.5
 36        q[l + m, l - abs(m)] = -1j / 2**0.5
 37    q[l, l] = 1
 38    for m in range(1, l + 1):
 39        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
 40        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5
 41
 42    # factor of (-i)**l to make the Clebsch-Gordan coefficients real
 43    return q * (-1j) ** l
 44
 45
 46def CG_SO3(j1: int, j2: int, j3: int) -> np.array:
 47    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SO(3)`
 48    Returns
 49    -------
 50    `np.array`
 51        tensor :math:`C` of shape :math:`(2l_1+1, 2l_2+1, 2l_3+1)`
 52    """
 53    C = CG_SU2(j1, j2, j3)
 54    Q1 = change_basis_real_to_complex(j1)
 55    Q2 = change_basis_real_to_complex(j2)
 56    Q3 = change_basis_real_to_complex(j3)
 57    C = np.real(np.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, np.conj(Q3.T), C))
 58    return C / np.linalg.norm(C)
 59
 60
 61def generate_spherical_harmonics(
 62    lmax, normalize=False, print_code=False, jit=False, vmapped=False
 63):  # pragma: no cover
 64    r"""returns a function that computes spherical harmonic up to lmax
 65    (adapted from e3nn)
 66    """
 67
 68    def to_frac(x: float):
 69        from fractions import Fraction
 70
 71        s = 1 if x >= 0 else -1
 72        x = x**2
 73        x = Fraction(x).limit_denominator()
 74        x = s * sympy.sqrt(x)
 75        x = sympy.simplify(x)
 76        return x
 77
 78    if vmapped:
 79        fn_str = "def spherical_harmonics_(x,y,z):\n"
 80        fn_str += "  sh_0_0 = 1.\n"
 81    else:
 82        fn_str = "def spherical_harmonics_(vec):\n"
 83        if normalize:
 84            fn_str += "  vec = vec/jnp.linalg.norm(vec,axis=-1,keepdims=True)\n"
 85        fn_str += "  x,y,z = [jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)]\n"
 86        fn_str += "  sh_0_0 = jnp.ones_like(x)\n"
 87
 88    x_var, y_var, z_var = sympy.symbols("x y z")
 89    polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]
 90
 91    def sub_z1(p, names, polynormz):
 92        p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
 93        for n, c in zip(names, polynormz):
 94            p = p.subs(n, c)
 95        return p
 96
 97    poly_evalz = [sub_z1(p, [], []) for p in polynomials]
 98
 99    for l in range(1, lmax + 1):
100        sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))
101
102        for n, p in zip(sh_variables, polynomials):
103            fn_str += f"  {n} = {pycode(p.evalf())}\n"
104
105        if l == lmax:
106            break
107
108        polynomials = [
109            sum(
110                to_frac(c.item()) * v * sh
111                for cj, v in zip(cij, [x_var, y_var, z_var])
112                for c, sh in zip(cj, sh_variables)
113            )
114            for cij in CG_SO3(l + 1, 1, l)
115        ]
116
117        poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
118        norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
119        polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
120        poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]
121
122        polynomials = [sympy.simplify(p, full=True) for p in polynomials]
123
124    u = ",\n        ".join(
125        ", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(l + 1)
126    )
127    if vmapped:
128        fn_str += f"  return jnp.array([\n        {u}\n    ])\n"
129    else:
130        fn_str += f"  return jnp.stack([\n        {u}\n    ], axis=-1)\n"
131
132    if print_code:
133        print(fn_str)
134    if sys.version_info[0] == 2:
135        raise RuntimeError("Python 2 is not supported")
136    if sys.version_info[1] >= 13:
137        new_locals = {}
138        exec(fn_str,locals=new_locals)
139        sh = new_locals["spherical_harmonics_"]
140    else:
141        exec(fn_str)
142        sh = locals()["spherical_harmonics_"]
143    if jit:
144        sh = jax.jit(sh)
145    if not vmapped:
146        return sh
147
148    if normalize:
149
150        def spherical_harmonics(vec):
151            vec = vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)
152            x, y, z = [
153                jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
154            ]
155            return jax.vmap(sh)(x, y, z)
156
157    else:
158
159        def spherical_harmonics(vec):
160            x, y, z = [
161                jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
162            ]
163            return jax.vmap(sh)(x, y, z)
164
165    if jit:
166        spherical_harmonics = jax.jit(spherical_harmonics)
167    return spherical_harmonics
168
169
@partial(jax.jit, static_argnums=1)
def spherical_to_cartesian_tensor(Q, lmax):
    r"""Convert spherical multipole components to Cartesian components.

    Parameters
    ----------
    Q : array
        Spherical components on the last axis, ordered by degree:
        ``Q[..., 0]`` (l=0), ``Q[..., 1:4]`` (l=1) and ``Q[..., 4:9]`` (l=2).
    lmax : int
        Maximum degree to convert (static under ``jax.jit``); must be 0, 1 or 2.

    Returns
    -------
    array
        Along the last axis: ``[q]`` for ``lmax=0``, ``[q, mu]`` for ``lmax=1``,
        or ``[q, mu, Txx, Tyy, Tzz, Txy, Txz, Tyz]`` for ``lmax=2``. Here ``mu``
        is the l=1 block copied unchanged, and ``T`` is a traceless quadrupole
        whose m=0 component maps onto ``Tyy`` (y is the polar axis).

    Raises
    ------
    ValueError
        If ``lmax`` is not 0, 1 or 2.
    """
    if not 0 <= lmax <= 2:
        raise ValueError(f"lmax must be 0, 1 or 2, got {lmax}")

    q = Q[..., 0]
    if lmax == 0:
        return q[..., None]

    mu = Q[..., 1:4]
    if lmax == 1:
        return jnp.concatenate([q[..., None], mu], axis=-1)

    # l = 2 components in storage order 4..8.
    Q22s = Q[..., 4]
    Q21s = Q[..., 5]
    Q20 = Q[..., 6]
    Q21c = Q[..., 7]
    Q22c = Q[..., 8]
    half_sqrt3 = 0.5 * 3**0.5
    Tzz = -0.5 * Q20 + half_sqrt3 * Q22c
    Txx = -0.5 * Q20 - half_sqrt3 * Q22c
    Tyy = Q20
    Txz = half_sqrt3 * Q22s
    Tyz = half_sqrt3 * Q21c
    Txy = half_sqrt3 * Q21s

    T = jnp.stack([Txx, Tyy, Tzz, Txy, Txz, Tyz], axis=-1)
    return jnp.concatenate([q[..., None], mu, T], axis=-1)
def CG_SU2(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SU(2)`

    Parameters
    ----------
    j1, j2 : int
        Degrees of the two coupled representations.
    j3 : int
        Degree of the resulting representation.

    Returns
    -------
    `np.ndarray`
        tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)` with
        ``C[m1 + j1, m2 + j2, m3 + j3] = <j1 m1 j2 m2 | j3 m3>``.
    """
    C = np.zeros((2 * j1 + 1, 2 * j2 + 1, 2 * j3 + 1))
    for m1 in range(-j1, j1 + 1):
        for m2 in range(-j2, j2 + 1):
            # <j1 m1 j2 m2|j3 m3> vanishes unless m3 == m1 + m2, so only that
            # single entry per (m1, m2) pair needs the (slow) symbolic call;
            # every other entry keeps its zero initialization.
            m3 = m1 + m2
            if abs(m3) > j3:
                continue
            C[m1 + j1, m2 + j2, m3 + j3] = float(
                clebsch_gordan(j1, j2, j3, m1, m2, m3)
            )
    return C

Clebsch-Gordan coefficients for the direct product of two irreducible representations of \( SU(2) \).

Returns
  • np.ndarray: tensor \( C \) of shape \( (2j_1+1, 2j_2+1, 2j_3+1) \)
def change_basis_real_to_complex(l: int) -> np.ndarray:
    r"""Change of basis matrix from real to complex spherical harmonics
    https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    adapted from e3nn.o3._wigner

    Parameters
    ----------
    l : int
        Degree of the representation.

    Returns
    -------
    `np.ndarray`
        complex matrix of shape :math:`(2l+1, 2l+1)`; row and column index
        ``m + l`` corresponds to order ``m`` in ``-l..l``. The whole matrix is
        multiplied by the global phase :math:`(-i)^l`.
    """
    q = np.zeros((2 * l + 1, 2 * l + 1), dtype=complex)
    # Negative orders: combine the +|m| and -|m| entries.
    for m in range(-l, 0):
        q[l + m, l + abs(m)] = 1 / 2**0.5
        q[l + m, l - abs(m)] = -1j / 2**0.5
    # m = 0 is shared by both bases.
    q[l, l] = 1
    # Positive orders carry the Condon-Shortley sign (-1)**m.
    for m in range(1, l + 1):
        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5

    # factor of (-i)**l to make the Clebsch-Gordan coefficients real
    return q * (-1j) ** l

Change-of-basis matrix from real to complex spherical harmonics (see https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form); adapted from e3nn.o3._wigner.

def CG_SO3(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SO(3)`

    The :math:`SU(2)` coefficients are rotated into the real spherical-harmonic
    basis with :func:`change_basis_real_to_complex` and normalized to unit
    Frobenius norm.

    Parameters
    ----------
    j1, j2 : int
        Degrees of the two coupled representations.
    j3 : int
        Degree of the resulting representation.

    Returns
    -------
    `np.ndarray`
        tensor :math:`C` of shape :math:`(2l_1+1, 2l_2+1, 2l_3+1)`. When no
        coupling exists (all coefficients are zero, e.g. the triangle condition
        fails) the zero tensor is returned instead of NaNs.
    """
    C = CG_SU2(j1, j2, j3)
    Q1 = change_basis_real_to_complex(j1)
    Q2 = change_basis_real_to_complex(j2)
    Q3 = change_basis_real_to_complex(j3)
    # Transform each index from the complex to the real basis; the (-i)**l
    # phase in the change of basis is there to make the result real, so only
    # the real part is kept.
    C = np.real(np.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, np.conj(Q3.T), C))
    norm = np.linalg.norm(C)
    if norm == 0:
        # Forbidden coupling: avoid 0/0 (which produced an all-NaN tensor).
        return C
    return C / norm

Clebsch-Gordan coefficients for the direct product of two irreducible representations of \( SO(3) \), expressed in the real spherical-harmonic basis and normalized to unit Frobenius norm.

Returns
  • np.ndarray: tensor \( C \) of shape \( (2l_1+1, 2l_2+1, 2l_3+1) \)
def generate_spherical_harmonics(
    lmax, normalize=False, print_code=False, jit=False, vmapped=False
):  # pragma: no cover
    r"""Return a function that computes real spherical harmonics up to degree ``lmax``
    (adapted from e3nn).

    The polynomials are built symbolically with sympy. Degree ``l + 1`` is
    obtained by coupling the degree-``l`` harmonics with :math:`(x, y, z)`
    through ``CG_SO3(l + 1, 1, l)``. It is then rescaled so that the degree
    block evaluated at :math:`(x, y, z) = (0, 1, 0)` has Euclidean norm
    :math:`\sqrt{2l+1}`. The resulting expressions are written out as Python
    source and compiled with ``exec``.

    Parameters
    ----------
    lmax : int
        Maximum degree; must be non-negative.
    normalize : bool, optional
        Divide the input vectors by their norm before evaluation.
    print_code : bool, optional
        Print the generated source code.
    jit : bool, optional
        Wrap the generated kernel (and, when ``vmapped``, the outer wrapper)
        with ``jax.jit``.
    vmapped : bool, optional
        Generate a kernel acting on scalar ``x, y, z`` and map it with
        ``jax.vmap`` over the leading axis of the input. The input is
        presumably expected to have shape ``(N, 3)``.

    Returns
    -------
    callable
        ``f(vec)``, where the last axis of ``vec`` holds ``(x, y, z)``.
        Returns the :math:`(l_{max}+1)^2` harmonics ``sh_{l}_{m}``, ordered by
        ``l`` and then by ``m`` in ``0..2l``. When not vmapped they are stacked
        along the last axis; when vmapped, an ``(N, 3)`` input gives an
        ``(N, (lmax+1)**2)`` output.

    Raises
    ------
    ValueError
        If ``lmax`` is negative.
    """
    if lmax < 0:
        raise ValueError(f"lmax must be non-negative, got {lmax}")

    def to_frac(x: float):
        # Recover an exact sympy number from a float coefficient, assuming its
        # square is a rational number with a small denominator.
        from fractions import Fraction

        s = 1 if x >= 0 else -1
        x = x**2
        x = Fraction(x).limit_denominator()
        x = s * sympy.sqrt(x)
        x = sympy.simplify(x)
        return x

    # Generated source, one "\n"-terminated line per entry.
    if vmapped:
        src = ["def spherical_harmonics_(x,y,z):\n", "  sh_0_0 = 1.\n"]
    else:
        src = ["def spherical_harmonics_(vec):\n"]
        if normalize:
            src.append("  vec = vec/jnp.linalg.norm(vec,axis=-1,keepdims=True)\n")
        src.append(
            "  x,y,z = [jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)]\n"
        )
        src.append("  sh_0_0 = jnp.ones_like(x)\n")

    x_var, y_var, z_var = sympy.symbols("x y z")
    # Degree-1 harmonics: sqrt(3) * (x, y, z).
    polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]

    def sub_z1(p, names, polynormz):
        # Evaluate p at (x, y, z) = (0, 1, 0), also substituting the symbols of
        # the previous degree (names) by their values at that point (polynormz).
        p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
        for n, c in zip(names, polynormz):
            p = p.subs(n, c)
        return p

    poly_evalz = [sub_z1(p, [], []) for p in polynomials]

    for l in range(1, lmax + 1):
        sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))

        for n, p in zip(sh_variables, polynomials):
            src.append(f"  {n} = {pycode(p.evalf())}\n")

        if l == lmax:
            break

        # Couple degree l with the vector (x, y, z) to obtain degree l + 1.
        polynomials = [
            sum(
                to_frac(c.item()) * v * sh
                for cj, v in zip(cij, [x_var, y_var, z_var])
                for c, sh in zip(cj, sh_variables)
            )
            for cij in CG_SO3(l + 1, 1, l)
        ]

        # Rescale so the new block has norm sqrt(2(l + 1) + 1) at (0, 1, 0).
        poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
        norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
        polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
        poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]

        polynomials = [sympy.simplify(p, full=True) for p in polynomials]

    # Use lmax rather than the loop variable: for lmax == 0 the loop never
    # runs and `l` would be unbound.
    u = ",\n        ".join(
        ", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(lmax + 1)
    )
    if vmapped:
        src.append(f"  return jnp.array([\n        {u}\n    ])\n")
    else:
        src.append(f"  return jnp.stack([\n        {u}\n    ], axis=-1)\n")
    fn_str = "".join(src)

    if print_code:
        print(fn_str)
    # The source is generated internally (never from user input). Run it with
    # the module globals, which provide `jax` and `jnp`, and collect the
    # definition in a fresh namespace. This works on every Python 3 version,
    # unlike reading it back through locals().
    namespace = {}
    exec(fn_str, globals(), namespace)
    sh = namespace["spherical_harmonics_"]
    if jit:
        sh = jax.jit(sh)
    if not vmapped:
        return sh

    def spherical_harmonics(vec):
        if normalize:
            vec = vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)
        x, y, z = [
            jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
        ]
        return jax.vmap(sh)(x, y, z)

    if jit:
        spherical_harmonics = jax.jit(spherical_harmonics)
    return spherical_harmonics

Returns a function that computes the real spherical harmonics up to degree lmax (adapted from e3nn).

@partial(jax.jit, static_argnums=1)
def spherical_to_cartesian_tensor(Q, lmax):
    r"""Convert spherical multipole components to Cartesian components.

    Parameters
    ----------
    Q : array
        Spherical components on the last axis, ordered by degree:
        ``Q[..., 0]`` (l=0), ``Q[..., 1:4]`` (l=1) and ``Q[..., 4:9]`` (l=2).
    lmax : int
        Maximum degree to convert (static under ``jax.jit``); must be 0, 1 or 2.

    Returns
    -------
    array
        Along the last axis: ``[q]`` for ``lmax=0``, ``[q, mu]`` for ``lmax=1``,
        or ``[q, mu, Txx, Tyy, Tzz, Txy, Txz, Tyz]`` for ``lmax=2``. Here ``mu``
        is the l=1 block copied unchanged, and ``T`` is a traceless quadrupole
        whose m=0 component maps onto ``Tyy`` (y is the polar axis).

    Raises
    ------
    ValueError
        If ``lmax`` is not 0, 1 or 2.
    """
    if not 0 <= lmax <= 2:
        raise ValueError(f"lmax must be 0, 1 or 2, got {lmax}")

    q = Q[..., 0]
    if lmax == 0:
        return q[..., None]

    mu = Q[..., 1:4]
    if lmax == 1:
        return jnp.concatenate([q[..., None], mu], axis=-1)

    # l = 2 components in storage order 4..8.
    Q22s = Q[..., 4]
    Q21s = Q[..., 5]
    Q20 = Q[..., 6]
    Q21c = Q[..., 7]
    Q22c = Q[..., 8]
    half_sqrt3 = 0.5 * 3**0.5
    Tzz = -0.5 * Q20 + half_sqrt3 * Q22c
    Txx = -0.5 * Q20 - half_sqrt3 * Q22c
    Tyy = Q20
    Txz = half_sqrt3 * Q22s
    Tyz = half_sqrt3 * Q21c
    Txy = half_sqrt3 * Q21s

    T = jnp.stack([Txx, Tyy, Tzz, Txy, Txz, Tyz], axis=-1)
    return jnp.concatenate([q[..., None], mu, T], axis=-1)