Source code for pysisyphus.wavefunction.backend_numba

import importlib

import numba
from numba import i8, f8
from numba.experimental import jitclass
from numba.core import types
from numba.core.extending import overload_method
from numba.experimental import structref
import numpy as np


# Field specification for the NumbaShell jitclass, as (name, numba type) pairs.
# The ordered dict literal keeps the declaration order; .items() turns it back
# into the list of pairs that jitclass consumes.
shell_spec = list(
    {
        "L": i8,  # angular momentum
        "center": f8[:],  # shell center
        "coeffs": f8[:],  # one coefficient per primitive
        "exps": f8[:],  # one exponent per primitive
        "index": i8,  # first basis-function index of this shell
        "size": i8,  # number of basis functions in this shell
    }.items()
)


@jitclass(shell_spec)
class NumbaShell:
    """Numba jitclass holding the data of one shell.

    The attributes are exactly those declared in ``shell_spec``.
    """

    def __init__(self, L, center, coeffs, exps, index, size):
        self.L = L
        self.index = index
        self.size = size
        self.center = center
        self.coeffs = coeffs
        self.exps = exps

    def as_tuple(self):
        """Return (L, center, coeffs, exps, index, size)."""
        fields = (self.L, self.center, self.coeffs, self.exps, self.index, self.size)
        return fields

    @property
    def cart_size(self):
        """Number of Cartesian functions for angular momentum L."""
        L = self.L
        return (L + 1) * (L + 2) // 2

    @property
    def sph_size(self):
        """Number of spherical functions for angular momentum L."""
        return self.L * 2 + 1
# Integral key -> (module name below pysisyphus.wavefunction.ints_numba,
# number of integral components; 0 is used for scalar integrals).
_NUMBA_MODULES = dict(
    int1e_ovlp=("ovlp3d", 0),
    int1e_r=("dipole3d", 3),
    int1e_rr=("quadrupole3d", 6),
    int1e_drr=("diag_quadrupole3d", 3),
    int1e_kin=("kinetic3d", 0),
)
# Filled on demand, because importing a single integral module costs about 2 s.
_func_data = {}
def get_func_data(key):
    """Return ``(func_dict, components)`` for the integral ``key``.

    The backing module ``pysisyphus.wavefunction.ints_numba.<module>`` is
    imported lazily and cached in ``_func_data``, because each import still
    costs roughly 2 s of compile time.

    Parameters
    ----------
    key
        Integral identifier; must be a key of ``_NUMBA_MODULES``.

    Returns
    -------
    tuple
        ``(func_dict, components)``, where ``func_dict`` is the result of the
        module's ``get_func_dict()`` and ``components`` is taken from
        ``_NUMBA_MODULES``.

    Raises
    ------
    KeyError
        If ``key`` is unknown. The message lists the available keys.
    """
    if key in _func_data:
        return _func_data[key]

    if key not in _NUMBA_MODULES:
        raise KeyError(
            f"Unknown integral key {key!r}; available keys are "
            f"{sorted(_NUMBA_MODULES)}."
        )
    module_name, components = _NUMBA_MODULES[key]
    full_module_name = f"pysisyphus.wavefunction.ints_numba.{module_name}"
    module = importlib.import_module(full_module_name)
    func_dict = module.get_func_dict()
    _func_data[key] = (func_dict, components)
    return _func_data[key]
# Default value of the extra center R forwarded to the integral kernels.
_R0 = np.array([0.0, 0.0, 0.0])
@numba.jit(parallel=True, nopython=True, cache=True)
def get_2c_ints_cart(
    shells_a,
    shells_b,
    func_dict,
    components,
    symmetric=True,
    R=_R0,
):
    """Assemble two-center integrals over Cartesian basis functions.

    Parameters
    ----------
    shells_a, shells_b
        Typed lists of shells exposing the attributes ``L``, ``center``,
        ``coeffs``, ``exps``, ``index`` and ``size``. Both NumbaShell and
        ShellStruct provide these. With ``symmetric=True``, ``shells_b`` must
        hold the same shells as ``shells_a``.
    func_dict
        Maps ``(La, Lb)`` to a primitive-pair kernel that is called as
        ``func(ax, da, A, bx, db, B, R, result)``. It is assumed to accumulate
        into ``result`` (NOTE(review): confirm against the ints_numba
        modules).
    components
        Number of integral components. A value of 0 (scalar integrals) is
        stored with a leading axis of length 1.
    symmetric
        If True, only shell pairs with ``j >= i`` are computed and the
        off-diagonal blocks are mirrored.
    R
        Extra center that is forwarded unchanged to every kernel call.

    Returns
    -------
    ndarray
        Array of shape ``(max(components, 1), tot_size_a, tot_size_b)``.
    """
    tot_size_a = 0
    for shell in shells_a:
        tot_size_a += shell.size
    tot_size_b = 0
    for shell in shells_b:
        tot_size_b += shell.size

    # Scalar integrals (components == 0) still need one slot. Otherwise the
    # allocation below would be empty and the reshape further down would fail.
    ncomponents = max(components, 1)
    integrals = np.zeros((ncomponents, tot_size_a, tot_size_b))

    nshells_a = len(shells_a)
    nshells_b = len(shells_b)

    # Loop over contracted shells in shells_a (in parallel).
    for i in numba.prange(nshells_a):
        shell_a = shells_a[i]
        La = shell_a.L
        A = shell_a.center
        das = shell_a.coeffs
        axs = shell_a.exps
        indexa = shell_a.index
        sizea = shell_a.size
        slicea = slice(indexa, indexa + sizea)

        # Loop over contracted shells in shells_b.
        for j in range(nshells_b):
            # In symmetric mode the lower triangle is filled by mirroring.
            if symmetric and j < i:
                continue
            shell_b = shells_b[j]
            Lb = shell_b.L
            B = shell_b.center
            dbs = shell_b.coeffs
            bxs = shell_b.exps
            indexb = shell_b.index
            sizeb = shell_b.size
            sliceb = slice(indexb, indexb + sizeb)

            result = np.zeros(ncomponents * sizea * sizeb)
            # Pick the kernel matching the angular momenta of both shells.
            func = func_dict[(La, Lb)]
            # Loop over primitive pairs; the kernel adds into result.
            for ka in range(len(das)):
                ax = axs[ka]
                da = das[ka]
                for kb in range(len(dbs)):
                    func(ax, da, A, bxs[kb], dbs[kb], B, R, result)

            block = result.reshape(ncomponents, sizea, sizeb)
            integrals[:, slicea, sliceb] = block
            if symmetric and i != j:
                for ka in range(sizea):
                    for kb in range(sizeb):
                        integrals[:, indexb + kb, indexa + ka] = block[:, ka, kb]
    return integrals
def to_numba_shells(shells):
    """Convert shells into a ``numba.typed.List`` of NumbaShell objects.

    Each input shell must provide ``as_tuple()``, which returns the six
    NumbaShell constructor arguments in order.
    """
    converted = numba.typed.List()
    for src in shells:
        fields = src.as_tuple()
        converted.append(NumbaShell(*fields))
    return converted
# Disabled variant of to_numba_shells that builds NumbaShell objects directly
# from plain tuples inside jitted code. It is kept only as a bare string
# literal, so it is never compiled or called.
"""
@numba.jit(nopython=True, cache=True)
def to_numba_shells_from_tuples(shells):
    numba_shells = numba.typed.List()
    for shell in shells:
        numba_shell = NumbaShell(*shell)
        numba_shells.append(numba_shell)
    return numba_shells
"""
@structref.register
class ShellStructType(types.StructRef):
    """Numba StructRef type backing ShellStruct."""

    def preprocess_fields(self, fields):
        # Hook invoked by the StructRef type constructor. Literal types are
        # stripped so that the struct stores only plain field types.
        cleaned = [
            (field_name, types.unliteral(field_type))
            for field_name, field_type in fields
        ]
        return tuple(cleaned)
# Concrete ShellStructType instance with the same fields as the ShellStruct
# proxy. i8/f8 are the same numba type objects as numba.i8/numba.f8.
SST = ShellStructType(
    [
        ("L", i8),
        ("center", f8[:]),
        ("center_ind", i8),
        ("coeffs", f8[:]),
        ("exps", f8[:]),
        ("index", i8),
        ("size", i8),
    ]
)
class ShellStruct(structref.StructRefProxy):
    """Python-side proxy for a ShellStructType struct."""

    def __new__(cls, L, center, center_ind, coeffs, exps, index, size):
        # Field order must match the list passed to structref.define_proxy.
        fields = (L, center, center_ind, coeffs, exps, index, size)
        return structref.StructRefProxy.__new__(cls, *fields)
@overload_method(ShellStructType, "cart_size")
def ol_cart_size(self):
    """Provide ``cart_size()`` for ShellStruct in jitted code."""

    def impl(self):
        # Number of Cartesian functions for angular momentum L.
        L = self.L
        return (L + 1) * (L + 2) // 2

    return impl
@overload_method(ShellStructType, "sph_size")
def ol_sph_size(self):
    """Provide ``sph_size()`` for ShellStruct in jitted code."""

    def impl(self):
        # Number of spherical functions for angular momentum L.
        return self.L * 2 + 1

    return impl
@overload_method(ShellStructType, "as_tuple")
def ol_as_tuple(self):
    """Provide ``as_tuple()`` for ShellStruct in jitted code.

    The tuple is (L, center, center_ind, coeffs, exps, index, size).
    """

    def impl(self):
        fields = (
            self.L,
            self.center,
            self.center_ind,
            self.coeffs,
            self.exps,
            self.index,
            self.size,
        )
        return fields

    return impl
# Wire up the ShellStruct constructor and expose the listed fields as
# attributes on the Python proxy.
structref.define_proxy(
    ShellStruct,
    ShellStructType,
    "L center center_ind coeffs exps index size".split(),
)
def to_numba_shellstructs(shells):
    """Convert shells into a ``numba.typed.List`` of ShellStruct objects.

    Each input shell must provide ``as_tuple()``, which returns
    (L, center, coeffs, exps, index, size), and a ``center_ind`` attribute.
    """
    structs = numba.typed.List()
    for src in shells:
        L, center, coeffs, exps, index, size = src.as_tuple()
        structs.append(
            ShellStruct(L, center, src.center_ind, coeffs, exps, index, size)
        )
    return structs
def get_1el_ints_cart(
    shells_a,
    func_dict,
    components,
    shells_b=None,
    R=None,
    **kwargs,
):
    """Compute one-electron integrals between two sets of shells.

    Parameters
    ----------
    shells_a
        Typed list of shells for the bra side.
    func_dict, components
        Kernel table and number of components, as returned by
        ``get_func_data``. They are forwarded to ``get_2c_ints_cart``.
    shells_b
        Typed list of shells for the ket side. If None, ``shells_a`` is used
        and the symmetric code path is taken.
    R
        Extra center passed to the kernels. If None, a fresh zero vector is
        created for this call. Any array-like value is converted to a float64
        array.
    **kwargs
        Accepted for call-site compatibility and ignored.

    Returns
    -------
    ndarray
        The integral array produced by ``get_2c_ints_cart``.
    """
    symmetric = shells_b is None
    if symmetric:
        shells_b = shells_a
    # Avoid a module-level mutable default: build a fresh origin per call.
    R = np.zeros(3) if R is None else np.asarray(R, dtype=np.float64)
    return get_2c_ints_cart(
        shells_a,
        shells_b,
        func_dict,
        components=components,
        symmetric=symmetric,
        R=R,
    )