from dataclasses import dataclass
import functools
from math import ceil
import tempfile
from typing import Tuple
import jinja2
import numpy as np
from numpy.typing import NDArray
import pyparsing as pp
from pysisyphus.elem_data import ATOMIC_NUMBERS, INV_ATOMIC_NUMBERS
from pysisyphus.Geometry import Geometry
from pysisyphus.helpers_pure import file_or_str
from pysisyphus.wrapper.jmol import view_cdd_cube
from pysisyphus.wavefunction import Wavefunction
# Default padding (same length unit as the coordinates) added on every side of
# the bounding box of the atoms when building grids.
DEFAULT_MARGIN = 3.0


def get_mins_maxs(coords3d, margin=DEFAULT_MARGIN):
    """Return the corners of the bounding box of ``coords3d``, padded by ``margin``.

    Parameters
    ----------
    coords3d
        Array of coordinates with one point per row; extrema are taken over axis 0.
    margin
        Padding subtracted from the minima and added to the maxima. A scalar, or
        anything that broadcasts against a single row.

    Returns
    -------
    mins, maxs
        Lower and upper corner of the padded box.
    """
    lower_corner = coords3d.min(axis=0)
    upper_corner = coords3d.max(axis=0)
    return lower_corner - margin, upper_corner + margin
@functools.singledispatch
def get_grid(coords3d: np.ndarray, num=10, margin=DEFAULT_MARGIN):
    """Return a regular grid with ``num`` points per axis around ``coords3d``.

    Parameters
    ----------
    coords3d
        Cartesian coordinates, one (x, y, z) row per point.
    num
        Number of grid points along each axis, both box edges included.
    margin
        Padding added on every side of the bounding box of ``coords3d``.

    Returns
    -------
    xyz
        Grid points, shape (num**3, 3); x varies slowest, z fastest.
    spacing
        Distance between neighbouring points along x, y and z.
    npoints
        The tuple ``(num, num, num)``.
    """
    lower, upper = get_mins_maxs(coords3d, margin)
    (x_lo, y_lo, z_lo), (x_hi, y_hi, z_hi) = lower, upper
    # A complex "step" tells mgrid to produce exactly ``num`` samples per axis,
    # endpoints included.
    n_samples = num * 1j
    grid_x, grid_y, grid_z = np.mgrid[
        x_lo:x_hi:n_samples,
        y_lo:y_hi:n_samples,
        z_lo:z_hi:n_samples,
    ]
    points = np.stack([component.flatten() for component in (grid_x, grid_y, grid_z)], axis=1)
    # ``num`` points enclose ``num - 1`` intervals.
    spacing = (upper - lower) / (num - 1)
    return points, spacing, (num, num, num)
@get_grid.register
def _(wf: Wavefunction, *args, **kwargs):
    """Wavefunction overload of ``get_grid``: build the grid around ``wf.coords3d``.

    Positional and keyword arguments (``num``, ``margin``) are forwarded
    unchanged, so ``get_grid(wf, 20)`` behaves like ``get_grid(wf.coords3d, 20)``.
    """
    return get_grid(wf.coords3d, *args, **kwargs)
def get_grid_with_spacing(coords3d, spacing=0.30, margin=DEFAULT_MARGIN):
    """Return a regular grid around ``coords3d`` with at most ``spacing`` between points.

    Parameters
    ----------
    coords3d
        Cartesian coordinates, one (x, y, z) row per point.
    spacing
        Requested (maximum) distance between neighbouring grid points.
    margin
        Padding added on every side of the bounding box of ``coords3d``.

    Returns
    -------
    xyz
        Grid points, shape (nx * ny * nz, 3); x varies slowest, z fastest.
    act_spacing
        Actual distance between neighbouring points along x, y and z
        (never larger than ``spacing``).
    npoints
        Tuple ``(nx, ny, nz)`` with the number of points per axis.

    Raises
    ------
    ValueError
        If ``spacing`` is not positive.
    """
    if spacing <= 0:
        raise ValueError(f"'spacing' must be positive, but got {spacing}!")
    mins, maxs = get_mins_maxs(coords3d, margin)
    extents = maxs - mins
    # n points span only n - 1 intervals. With ceil(d / spacing) points the actual
    # spacing d / (ceil(d / spacing) - 1) would always exceed the requested one,
    # hence the extra point. This also guarantees at least one point per axis.
    nx, ny, nz = (ceil(extent / spacing) + 1 for extent in extents)
    (minx, miny, minz), (maxx, maxy, maxz) = mins, maxs
    # Complex step: number of points per axis, endpoints included.
    X, Y, Z = np.mgrid[
        minx : maxx : nx * 1j,
        miny : maxy : ny * 1j,
        minz : maxz : nz * 1j,
    ]
    xyz = np.stack((X.flatten(), Y.flatten(), Z.flatten()), axis=1)
    # A zero-extent axis (e.g. a planar molecule with margin=0) holds a single
    # point; clamp the interval count to avoid dividing by zero there.
    act_spacing = extents / np.maximum(np.array((nx, ny, nz)) - 1, 1)
    return xyz, act_spacing, (nx, ny, nz)
@functools.singledispatch
def get_box_grid(coords3d: np.ndarray, num=100, margin=DEFAULT_MARGIN, edge_length=5):
    """Build a regular grid around ``coords3d`` whose points are stored box by box.

    The grid spans the padded bounding box of ``coords3d`` with ``num`` points
    per axis. Instead of the usual cube ordering (x slowest, z fastest), the
    points are grouped into boxes of ``edge_length`` points per axis. All points
    of one box are stored contiguously, and the boxes follow each other in C
    order.

    Parameters
    ----------
    coords3d
        Cartesian coordinates, one (x, y, z) row per point.
    num
        Number of grid points per axis. Either a scalar (same number along all
        axes) or a 3-sequence ``(numx, numy, numz)``.
    margin
        Padding added on every side of the bounding box of ``coords3d``.
    edge_length
        Number of points along each box edge. Either a scalar (cubic boxes) or
        a 3-sequence. Every entry of ``num`` must be a multiple of the matching
        edge length.

    Returns
    -------
    grid3d
        Grid points, shape (numx * numy * numz, 3), ordered box by box.
    spacing
        Distance between neighbouring points along x, y and z.
    npoints
        Tuple ``(numx, numy, numz)``.
    sort_inds
        For every row of ``grid3d``, its flat index in the usual C-ordered cube
        layout. ``flat[sort_inds] = values_on_grid3d`` restores cube order.
    """
    (minx, miny, minz), (maxx, maxy, maxz) = get_mins_maxs(coords3d, margin)
    # Scalar 'num': same number of points along every axis (cubic grid).
    try:
        numx, numy, numz = num
    except TypeError:
        numx = numy = numz = num
    # Scalar 'edge_length': same edge length along every axis (cubic boxes).
    try:
        edge_lengthx, edge_lengthy, edge_lengthz = edge_length
    except TypeError:
        edge_lengthx = edge_lengthy = edge_lengthz = edge_length
    nums = (numx, numy, numz)
    edge_lengths = (edge_lengthx, edge_lengthy, edge_lengthz)
    for n, e in zip(nums, edge_lengths):
        assert n % e == 0, f"'num={n}' must be a multiple of 'edge_length={e}'!"

    # Distance between neighbouring points; num points span num - 1 intervals.
    spacing = np.array(
        (
            (maxx - minx) / (numx - 1),
            (maxy - miny) / (numy - 1),
            (maxz - minz) / (numz - 1),
        )
    )
    ind_vec = np.array(edge_lengths)
    # Cartesian shift between neighbouring boxes.
    trans_vec = spacing * ind_vec
    boxes_per_axis = tuple(n // e for n, e in zip(nums, edge_lengths))

    # Integer offsets of the points inside one box, and of the boxes inside the
    # grid. np.indices(...).reshape(3, -1).T enumerates them in C order (z fastest).
    ind_box3d = np.indices(edge_lengths).reshape(3, -1).T
    ind_boxes3d = np.indices(boxes_per_axis).reshape(3, -1).T

    # Initial box, shifted to the grid origin.
    box3d = ind_box3d * spacing + np.array((minx, miny, minz))
    # Translate the initial box to every box position. The leading axis runs
    # over boxes, so all points of one box end up contiguous after reshaping.
    grid3d = (box3d[None, :, :] + (ind_boxes3d * trans_vec)[:, None, :]).reshape(-1, 3)

    # Grid index (ix, iy, iz) of every point, converted to its flat C-order index.
    inds3d = ((ind_boxes3d * ind_vec)[:, None, :] + ind_box3d[None, :, :]).reshape(-1, 3)
    ix, iy, iz = inds3d.T
    sort_inds = (ix * numy + iy) * numz + iz
    return grid3d, spacing, (numx, numy, numz), sort_inds
@get_box_grid.register
def _(wf: Wavefunction, *args, **kwargs):
    """Wavefunction overload of ``get_box_grid``: build the grid around ``wf.coords3d``.

    Positional and keyword arguments (``num``, ``margin``, ``edge_length``) are
    forwarded unchanged, matching the ndarray version's signature.
    """
    return get_box_grid(wf.coords3d, *args, **kwargs)
@dataclass
class Cube:
    """Volumetric data on a regular grid together with the atoms it belongs to.

    ``parse_cube`` rejects non-positive point counts ("Only Bohr are
    supported!"). Coordinates read from files are therefore presumably in
    Bohr. TODO: confirm for the other constructors.
    """

    # Element symbols, one per atom.
    atoms: Tuple
    # Atomic coordinates, one (x, y, z) row per atom.
    coords3d: NDArray
    # Position of the first grid point.
    origin: NDArray
    # Number of grid points along each of the three grid axes.
    npoints: NDArray
    # Voxel vectors, one row per grid axis.
    axes: NDArray
    # Volumetric data; reshaped to ``npoints`` when written.
    vol_data: NDArray
    comment1: str = "Generated by pysisyphus"
    comment2: str = ""

    @property
    def vol_element(self):
        """Volume of a single grid cell (voxel).

        Computed as |det(axes)|, the volume of the parallelepiped spanned by the
        three voxel vectors. For orthogonal axes this reduces to the product of
        the diagonal. Unlike that product, it is also correct for
        non-orthogonal axes and never negative.
        """
        return abs(np.linalg.det(self.axes))

    @staticmethod
    def from_file(fn):
        """Read a cube file via ``parse_cube``."""
        return parse_cube(fn)

    @staticmethod
    def from_wf_and_grid(wf, vol_data, origin, axes):
        """Create a cube from a wavefunction's atoms and 3d volumetric data.

        ``wf`` only needs ``atoms`` and ``coords3d`` attributes. The number of
        grid points is taken from ``vol_data.shape``.
        """
        assert vol_data.ndim == 3, "Volume data 'vol_data' must be 3d!"
        return Cube(
            atoms=wf.atoms,
            coords3d=wf.coords3d,
            origin=origin,
            npoints=vol_data.shape,
            axes=axes,
            vol_data=vol_data,
        )

    @staticmethod
    def from_cube_and_vol_data(cube, vol_data):
        """Return a new cube with the geometry/grid of ``cube`` but new ``vol_data``.

        Comments are not copied; the new cube uses the default comments.
        """
        return Cube(
            atoms=cube.atoms,
            coords3d=cube.coords3d,
            origin=cube.origin,
            npoints=vol_data.shape,
            axes=cube.axes,
            vol_data=vol_data,
        )

    def to_str(self):
        """Render this cube in cube-file format via ``cube_to_str``.

        NOTE: ``comment1``/``comment2`` are not passed on; ``cube_to_str``
        writes its own comment lines.
        """
        return cube_to_str(
            self.atoms,
            self.coords3d,
            self.vol_data.reshape(*self.npoints),
            self.origin,
            self.axes,
        )

    def write(self, fn):
        """Write this cube to the file ``fn``, overwriting it."""
        with open(fn, "w") as handle:
            handle.write(self.to_str())

    def view_cdd(self, **kwargs):
        """Write the cube to a temporary file and show it with ``view_cdd_cube``.

        NOTE(review): the temporary file is reopened by name while still open,
        which the Python docs say is not possible on Windows. Confirm whether
        Windows support matters here.
        """
        with tempfile.NamedTemporaryFile("w", suffix=".cube") as tmp_cube:
            self.write(tmp_cube.name)
            view_cdd_cube(tmp_cube.name, **kwargs)
@file_or_str(".cube", ".cub")
def parse_cube(text):
    """Parse cube-file contents into a :class:`Cube`.

    Expected layout:
    - two comment lines;
    - the atom count followed by the grid origin;
    - three axis lines (point count and voxel vector);
    - one line per atom (atomic number, charge, coordinates);
    - the volumetric data.

    The atom count must match the number of atom lines. All point counts must
    be positive. The charge column is parsed but discarded.

    Parameters
    ----------
    text
        Cube-file contents. Via ``file_or_str``, presumably also a path ending
        in .cube/.cub. TODO: confirm against helpers_pure.

    Returns
    -------
    Cube
    """
    int_ = pp.common.integer
    sci_real = pp.common.sci_real

    def get_line_word(*args):
        # Only newlines count as whitespace, so the word (which may contain
        # blanks and tabs) consumes the rest of the line.
        return pp.Word(*args).setWhitespaceChars("\n")

    def as_comment(value):
        # A named single-token match comes back from as_dict() as a plain str.
        # " ".join() on it would put a space between every character.
        return value if isinstance(value, str) else " ".join(value)

    comment = get_line_word(pp.printables + " \t")
    cart_vec = pp.Group(sci_real + sci_real + sci_real)
    axis = pp.Group(int_ + cart_vec)
    atom_line = pp.Group(
        int_.set_results_name("atomic_num")
        + sci_real.set_results_name("charge")
        + cart_vec.set_results_name("coords")
        + pp.LineEnd()
    )
    parser = (
        comment.set_results_name("comment1")
        + comment.set_results_name("comment2")
        + int_.set_results_name("atom_num")
        + cart_vec.set_results_name("origin")
        + axis.set_results_name("axis1")
        + axis.set_results_name("axis2")
        + axis.set_results_name("axis3")
        + pp.ZeroOrMore(atom_line).set_results_name("atom_lines")
        + pp.ZeroOrMore(sci_real).set_results_name("vol_data")
    )
    res = parser.parseString(text).as_dict()
    comment1 = as_comment(res["comment1"])
    comment2 = as_comment(res["comment2"])
    atom_num = res["atom_num"]
    # An empty ZeroOrMore match leaves no named result behind.
    atom_lines = res.get("atom_lines", [])
    assert atom_num == len(atom_lines)
    atoms = [INV_ATOMIC_NUMBERS[al["atomic_num"]].capitalize() for al in atom_lines]
    # reshape keeps a (0, 3) array when there are no atoms.
    coords3d = np.array([al["coords"] for al in atom_lines], dtype=float).reshape(-1, 3)
    origin = np.array(res["origin"])
    npoints, axes = zip(*[res[k] for k in ("axis1", "axis2", "axis3")])
    npoints = np.array(npoints, int)
    assert all(npoints > 0), "Only Bohr are supported!"
    axes = np.array(axes, float)
    vol_data = np.array(res["vol_data"]).reshape(*npoints)
    cube = Cube(
        atoms=atoms,
        coords3d=coords3d,
        origin=origin,
        npoints=npoints,
        axes=axes,
        vol_data=vol_data,
        comment1=comment1,
        comment2=comment2,
    )
    return cube
@file_or_str(".cube", ".cub")
def geom_from_cube(text, **kwargs):
    """Create a ``Geometry`` from the atoms and coordinates stored in a cube file.

    Extra keyword arguments are passed on to the ``Geometry`` constructor.
    """
    parsed = parse_cube(text)
    return Geometry(parsed.atoms, parsed.coords3d, **kwargs)
# Jinja2 template of a cube file, rendered by cube_to_str. Sections, in order:
# - two comment lines;
# - the atom count and the grid origin;
# - one line per grid axis (number of points and voxel vector);
# - one line per atom (atomic number, a fixed 0.000000 charge column and the
#   coordinates);
# - the preformatted volumetric data (grid_str).
CUBE_TPL = """{{ comment1 }}
{{ comment2 }}
{{ ifmt(atom_nums|length) }} {{ fmt(origin_x) }} {{ fmt(origin_y) }} {{ fmt(origin_z) }}
{% for n_points, (x, y, z) in zip(axes_npoints, axes) %}
{{ ifmt(n_points) }} {{ fmt(x) }} {{ fmt(y) }} {{ fmt(z) }}
{% endfor %}
{% for atomic_num, (x, y, z) in zip(atom_nums, coords3d) %}
{{ ifmt(atomic_num) }} 0.000000 {{ fmt(x) }} {{ fmt(y) }} {{ fmt(z) }}
{% endfor %}
{{ grid_str }}
"""
def cube_to_str(atoms, coords3d, vol_data, origin, axes, comment1=None, comment2=None):
    """Render atoms, grid and volumetric data as a cube-file string.

    Parameters
    ----------
    atoms
        Element symbols; converted to atomic numbers via ``ATOMIC_NUMBERS``.
    coords3d
        Atomic coordinates, one (x, y, z) row per atom.
    vol_data
        3d array of volumetric data; its shape gives the number of points
        along the three grid axes.
    origin
        Position (x, y, z) of the first grid point.
    axes
        Voxel vectors, one (x, y, z) row per grid axis.
    comment1, comment2
        First and second comment line. If None, they default to
        "Generated by pysisyphus" and "Total of <n> points".

    Returns
    -------
    str
        The rendered cube file. Data values are written in C order (z fastest),
        at most six per line, and every (x, y) row starts on a new line.
    """
    assert vol_data.ndim == 3
    env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
    env.globals.update(zip=zip)
    tpl = env.from_string(CUBE_TPL)

    org_x, org_y, org_z = origin
    axes_npoints = vol_data.shape
    nx, ny, nz = axes_npoints
    atom_nums = [ATOMIC_NUMBERS[atom.lower()] for atom in atoms]

    if comment1 is None:
        comment1 = "Generated by pysisyphus"
    if comment2 is None:
        comment2 = f"Total of {nx*ny*nz} points"

    # Chunk every (x, y) row into lines of at most 6 values. Unlike breaking
    # after every 6th value and again at the row end, this never emits an empty
    # line when nz is a multiple of 6. Collect lines and join once, instead of
    # repeated string concatenation.
    lines = []
    for row in vol_data.reshape(nx * ny, nz).tolist():
        for start in range(0, nz, 6):
            lines.append("".join(f"{val: >12.6e} " for val in row[start : start + 6]))
    grid_str = "".join(f"{line}\n" for line in lines)

    def fmt(num):
        return f"{num: >12.6f}"

    def ifmt(num):
        return f"{num: >5d}"

    rendered = tpl.render(
        fmt=fmt,
        ifmt=ifmt,
        origin_x=org_x,
        origin_y=org_y,
        origin_z=org_z,
        comment1=comment1,
        comment2=comment2,
        axes=axes,
        axes_npoints=axes_npoints,
        atom_nums=atom_nums,
        coords3d=coords3d,
        grid_str=grid_str,
    )
    return rendered