Source code for pysisyphus.io.cube

from dataclasses import dataclass
import functools
from math import ceil
import tempfile
from typing import Tuple

import jinja2
import numpy as np
from numpy.typing import NDArray
import pyparsing as pp

from pysisyphus.elem_data import ATOMIC_NUMBERS, INV_ATOMIC_NUMBERS
from pysisyphus.Geometry import Geometry
from pysisyphus.helpers_pure import file_or_str
from pysisyphus.wrapper.jmol import view_cdd_cube
from pysisyphus.wavefunction import Wavefunction


# Padding added on every side of the atoms' bounding box when building grids.
# It is expressed in the same length units as the coordinates passed in.
DEFAULT_MARGIN: float = 3.0


def get_mins_maxs(coords3d, margin=DEFAULT_MARGIN):
    """Return the padded axis-aligned bounding box of ``coords3d``.

    Parameters
    ----------
    coords3d
        Array of Cartesian coordinates, one row per point.
    margin
        Padding subtracted from the column minima and added to the maxima.

    Returns
    -------
    tuple
        ``(mins, maxs)``: the per-column minima minus ``margin`` and the
        per-column maxima plus ``margin``.
    """
    lower = coords3d.min(axis=0)
    upper = coords3d.max(axis=0)
    return lower - margin, upper + margin
@functools.singledispatch
def get_grid(coords3d: np.ndarray, num=10, margin=DEFAULT_MARGIN):
    """Build a regular grid with ``num`` points along each Cartesian axis.

    The grid spans the bounding box of ``coords3d`` enlarged by ``margin`` on
    every side (see ``get_mins_maxs``); both faces of the box are grid points.

    Parameters
    ----------
    coords3d
        Cartesian coordinates, one row per atom.
    num
        Number of grid points per axis.
    margin
        Padding added around the bounding box.

    Returns
    -------
    xyz
        Grid points, shape ``(num**3, 3)``; z varies fastest, x slowest.
    spacing
        Distance between neighbouring points along x, y and z.
    npoints
        The tuple ``(num, num, num)``.
    """
    lower, upper = get_mins_maxs(coords3d, margin)
    (x_lo, y_lo, z_lo), (x_hi, y_hi, z_hi) = lower, upper
    # An imaginary step tells mgrid to create that many points, end included.
    n_step = num * 1j
    mesh = np.mgrid[x_lo:x_hi:n_step, y_lo:y_hi:n_step, z_lo:z_hi:n_step]
    xyz = np.stack([component.flatten() for component in mesh], axis=1)
    extents = np.array((x_hi - x_lo, y_hi - y_lo, z_hi - z_lo))
    spacing = extents / (num - 1)
    return xyz, spacing, (num, num, num)
@get_grid.register
def _(wf: Wavefunction, *args, **kwargs):
    """Dispatch ``get_grid`` for a ``Wavefunction`` by using its ``coords3d``.

    Positional and keyword arguments (``num``, ``margin``) are forwarded
    unchanged, so ``get_grid(wf, 20)`` behaves like
    ``get_grid(wf.coords3d, 20)``.
    """
    return get_grid(wf.coords3d, *args, **kwargs)
def get_grid_with_spacing(coords3d, spacing=0.30, margin=DEFAULT_MARGIN):
    """Build a regular grid whose point count is derived from a target spacing.

    Along each axis, ``ceil(extent / spacing)`` points are placed between the
    faces of the padded bounding box (both faces included). Because that gives
    one interval fewer than points, the actual spacing ``extent / (n - 1)`` is
    somewhat larger than the requested one; it is returned explicitly.

    Parameters
    ----------
    coords3d
        Cartesian coordinates, one row per atom.
    spacing
        Requested distance between neighbouring grid points.
    margin
        Padding added around the bounding box.

    Returns
    -------
    xyz
        Grid points, shape ``(nx * ny * nz, 3)``; z varies fastest.
    act_spacing
        Actual spacing along x, y and z.
    npoints
        ``(nx, ny, nz)``.
    """
    lower, upper = get_mins_maxs(coords3d, margin)
    (x_lo, y_lo, z_lo), (x_hi, y_hi, z_hi) = lower, upper
    extents = (x_hi - x_lo, y_hi - y_lo, z_hi - z_lo)
    nx, ny, nz = (ceil(extent / spacing) for extent in extents)
    mesh = np.mgrid[
        x_lo : x_hi : nx * 1j,
        y_lo : y_hi : ny * 1j,
        z_lo : z_hi : nz * 1j,
    ]
    xyz = np.stack([component.flatten() for component in mesh], axis=1)
    act_spacing = np.array(
        [extent / (count - 1) for extent, count in zip(extents, (nx, ny, nz))]
    )
    return xyz, act_spacing, (nx, ny, nz)
@functools.singledispatch
def get_box_grid(coords3d: np.ndarray, num=100, margin=DEFAULT_MARGIN, edge_length=5):
    """Build a grid whose points are stored box by box.

    The padded bounding box of ``coords3d`` is sampled with ``num`` points per
    axis (a scalar for a cubic grid or a 3-sequence). The grid is divided into
    boxes of ``edge_length`` points per axis (scalar or 3-sequence), and all
    points belonging to one box are contiguous in the returned array. Boxes
    follow each other with x slowest and z fastest, as do points inside a box.

    Returns
    -------
    grid3d
        Grid points, shape ``(numx * numy * numz, 3)``, in box order.
    spacing
        Distance between neighbouring points along x, y and z.
    npoints
        ``(numx, numy, numz)``.
    sort_inds
        ``sort_inds[i]`` is the flat index of ``grid3d[i]`` in the usual cube
        ordering (x slowest, z fastest).

    Raises
    ------
    AssertionError
        If a point count is not a multiple of the corresponding edge length.
    """
    (minx, miny, minz), (maxx, maxy, maxz) = get_mins_maxs(coords3d, margin)

    def per_axis(value):
        # A 3-sequence gives one value per axis; a scalar is used for all three.
        try:
            vx, vy, vz = value
        except TypeError:
            vx = vy = vz = value
        return vx, vy, vz

    numx, numy, numz = per_axis(num)
    lenx, leny, lenz = per_axis(edge_length)
    for n, e in zip((numx, numy, numz), (lenx, leny, lenz)):
        assert n % e == 0, f"'num={n}' must be a multiple of 'edge_length={e}'!"

    # Distance between neighbouring points; both box faces are grid points.
    spacing_x = (maxx - minx) / (numx - 1)
    spacing_y = (maxy - miny) / (numy - 1)
    spacing_z = (maxz - minz) / (numz - 1)
    spacing = np.array((spacing_x, spacing_y, spacing_z))
    # Cartesian and integer-index offsets between neighbouring boxes.
    box_step = np.array((spacing_x * lenx, spacing_y * leny, spacing_z * lenz))
    box_ind_step = np.array((lenx, leny, lenz))

    # Integer (x, y, z) indices of the points inside one box, z fastest.
    local_inds = np.array(
        [(x, y, z) for x in range(lenx) for y in range(leny) for z in range(lenz)],
        dtype=int,
    ).reshape(-1, 3)
    # Coordinates of the first box, anchored at the grid origin.
    box3d = local_inds * spacing + np.array((minx, miny, minz))

    # Integer (x, y, z) position of every box, z fastest.
    box_pos = np.array(
        [
            (x, y, z)
            for x in range(numx // lenx)
            for y in range(numy // leny)
            for z in range(numz // lenz)
        ],
        dtype=int,
    ).reshape(-1, 3)

    # Translate the template box to every box position.
    grid3d = (box3d[None, :, :] + (box_pos * box_step)[:, None, :]).reshape(-1, 3)
    # Global integer indices of every point, mapped to C-order flat indices.
    global_inds = (
        local_inds[None, :, :] + (box_pos * box_ind_step)[:, None, :]
    ).reshape(-1, 3)
    ix, iy, iz = global_inds.T
    sort_inds = numy * numz * ix + numz * iy + iz
    return grid3d, spacing, (numx, numy, numz), sort_inds
@get_box_grid.register
def _(wf: Wavefunction, *args, **kwargs):
    """Dispatch ``get_box_grid`` for a ``Wavefunction`` via its ``coords3d``.

    Positional and keyword arguments (``num``, ``margin``, ``edge_length``)
    are forwarded unchanged, so ``get_box_grid(wf, 50)`` behaves like
    ``get_box_grid(wf.coords3d, 50)``.
    """
    return get_box_grid(wf.coords3d, *args, **kwargs)
@dataclass
class Cube:
    """Contents of a cube file: atoms, grid geometry and volumetric data."""

    # Element symbols, one per atom.
    atoms: Tuple
    # Atomic Cartesian coordinates, one row per atom.
    coords3d: NDArray
    # Grid origin as written to the cube header.
    origin: NDArray
    # Number of grid points along each of the three grid axes.
    npoints: NDArray
    # Voxel step vectors, one row per grid axis.
    axes: NDArray
    # Volumetric values; reshaped to ``npoints`` before writing.
    vol_data: NDArray
    comment1: str = "Generated by pysisyphus"
    comment2: str = ""

    @property
    def vol_element(self):
        """Volume of a single voxel.

        Computed as the absolute scalar triple product ``|a . (b x c)|`` of
        the three voxel vectors (rows of ``axes``). Unlike the product of the
        diagonal, this is also correct for non-orthogonal grids; for diagonal
        ``axes`` it reduces to the product of the diagonal entries.
        """
        a, b, c = np.asarray(self.axes, dtype=float)
        return abs(np.dot(a, np.cross(b, c)))

    @staticmethod
    def from_file(fn):
        """Read a cube from ``fn`` using ``parse_cube``."""
        return parse_cube(fn)

    @classmethod
    def from_wf_and_grid(cls, wf, vol_data, origin, axes):
        """Create a cube from a wavefunction and 3d volume data on a grid.

        Atoms and coordinates are taken from ``wf``; ``npoints`` is set to
        ``vol_data.shape``.

        Raises
        ------
        AssertionError
            If ``vol_data`` is not three-dimensional.
        """
        assert vol_data.ndim == 3, "Volume data 'vol_data' must be 3d!"
        return cls(
            atoms=wf.atoms,
            coords3d=wf.coords3d,
            origin=origin,
            npoints=vol_data.shape,
            axes=axes,
            vol_data=vol_data,
        )

    @classmethod
    def from_cube_and_vol_data(cls, cube, vol_data):
        """Create a cube with the geometry of ``cube`` but new volume data.

        ``npoints`` is taken from ``vol_data.shape``; comments are not copied.
        """
        return cls(
            atoms=cube.atoms,
            coords3d=cube.coords3d,
            origin=cube.origin,
            npoints=vol_data.shape,
            axes=cube.axes,
            vol_data=vol_data,
        )

    def to_str(self):
        """Render the cube with ``cube_to_str``.

        ``comment1``/``comment2`` of this instance are not passed on.
        """
        return cube_to_str(
            self.atoms,
            self.coords3d,
            self.vol_data.reshape(*self.npoints),
            self.origin,
            self.axes,
        )

    def write(self, fn):
        """Write the rendered cube to the file ``fn`` (overwriting it)."""
        with open(fn, "w") as handle:
            handle.write(self.to_str())

    def view_cdd(self, **kwargs):
        """Write the cube to a temporary ``.cube`` file and pass its path to
        ``view_cdd_cube`` together with ``kwargs``.

        The temporary file is removed when the ``with`` block exits, so
        ``view_cdd_cube`` presumably has to finish reading it before it
        returns - TODO confirm. NOTE(review): re-opening a
        ``NamedTemporaryFile`` by name may fail on Windows.
        """
        with tempfile.NamedTemporaryFile("w", suffix=".cube") as tmp_cube:
            self.write(tmp_cube.name)
            view_cdd_cube(tmp_cube.name, **kwargs)
def _comment_text(token):
    """Return a parsed comment line as a stripped string.

    A results name on a single ``pp.Word`` yields a plain ``str``; calling
    ``" ".join`` on it would put a blank between every character, so strings
    are returned directly. Token sequences are joined with single blanks.
    """
    if isinstance(token, str):
        return token.strip()
    return " ".join(token).strip()


@file_or_str(".cube", ".cub")
def parse_cube(text):
    """Parse the contents of a cube file into a ``Cube``.

    ``file_or_str`` presumably lets ``text`` be either a path with one of the
    given extensions or the file contents - TODO confirm.

    Expected layout: two comment lines; the number of atoms and the origin;
    three axis lines (point count followed by the voxel vector); one line per
    atom (atomic number, charge, x, y, z); then the volumetric data.

    Raises
    ------
    AssertionError
        If the number of atom lines differs from the header count, or if any
        axis point count is not positive (the message states only Bohr is
        supported).
    """
    int_ = pp.common.integer
    sci_real = pp.common.sci_real

    def get_line_word(*args):
        # Only newlines count as whitespace, so blanks and tabs become part of
        # the word and one token spans the rest of the line.
        return pp.Word(*args).setWhitespaceChars("\n")

    comment = get_line_word(pp.printables + " \t")
    cart_vec = pp.Group(sci_real + sci_real + sci_real)
    # Point count along a grid axis followed by its voxel vector.
    axis = pp.Group(int_ + cart_vec)
    # The atomic number must not run straight into ".", "e" or "E". Otherwise
    # the leading digit of a volume value such as "1.234560e-01" could start a
    # bogus atom line ("1" + ".234560e-01" + three more values + line end)
    # whenever the first data values happen to line up with that layout.
    atom_line = pp.Group(
        int_.set_results_name("atomic_num")
        + ~pp.Char(".eE").leave_whitespace()
        + sci_real.set_results_name("charge")
        + cart_vec.set_results_name("coords")
        + pp.LineEnd()
    )
    parser = (
        comment.set_results_name("comment1")
        + comment.set_results_name("comment2")
        + int_.set_results_name("atom_num")
        + cart_vec.set_results_name("origin")
        + axis.set_results_name("axis1")
        + axis.set_results_name("axis2")
        + axis.set_results_name("axis3")
        + pp.ZeroOrMore(atom_line).set_results_name("atom_lines")
        + pp.ZeroOrMore(sci_real).set_results_name("vol_data")
    )
    res = parser.parseString(text).as_dict()

    comment1 = _comment_text(res["comment1"])
    comment2 = _comment_text(res["comment2"])
    atom_num = res["atom_num"]
    # With zero atom lines the results name is presumably absent; default to [].
    atom_lines = res.get("atom_lines", [])
    assert atom_num == len(atom_lines)
    atoms = [INV_ATOMIC_NUMBERS[al["atomic_num"]].capitalize() for al in atom_lines]
    coords3d = np.array([al["coords"] for al in atom_lines], dtype=float).reshape(-1, 3)
    origin = np.array(res["origin"])
    npoints, axes = zip(*[res[k] for k in ("axis1", "axis2", "axis3")])
    npoints = np.array(npoints, int)
    assert all(npoints > 0), "Only Bohr are supported!"
    axes = np.array(axes, float)
    vol_data = np.array(res["vol_data"]).reshape(*npoints)

    cube = Cube(
        atoms=atoms,
        coords3d=coords3d,
        origin=origin,
        npoints=npoints,
        axes=axes,
        vol_data=vol_data,
        comment1=comment1,
        comment2=comment2,
    )
    return cube


@file_or_str(".cube", ".cub")
def geom_from_cube(text, **kwargs):
    """Build a ``Geometry`` from the atoms and coordinates of a cube file.

    ``kwargs`` are forwarded to the ``Geometry`` constructor.
    """
    cube = parse_cube(text)
    geom = Geometry(cube.atoms, cube.coords3d, **kwargs)
    return geom


# Jinja2 template of a cube file. ``cube_to_str`` renders it with
# trim_blocks/lstrip_blocks, so block tags do not leave empty lines behind.
# The volumetric data is passed in already formatted as ``grid_str``.
CUBE_TPL = """{{ comment1 }}
{{ comment2 }}
{{ ifmt(atom_nums|length) }} {{ fmt(origin_x) }} {{ fmt(origin_y) }} {{ fmt(origin_z) }}
{% for n_points, (x, y, z) in zip(axes_npoints, axes) %}
{{ ifmt(n_points) }} {{ fmt(x) }} {{ fmt(y) }} {{ fmt(z) }}
{% endfor %}
{% for atomic_num, (x, y, z) in zip(atom_nums, coords3d) %}
{{ ifmt(atomic_num) }} 0.000000 {{ fmt(x) }} {{ fmt(y) }} {{ fmt(z) }}
{% endfor %}
{{ grid_str }}
"""
def cube_to_str(
    atoms,
    coords3d,
    vol_data,
    origin,
    axes,
    comment1="Generated by pysisyphus",
    comment2=None,
):
    """Render a cube file as a string.

    Parameters
    ----------
    atoms
        Element symbols, looked up case-insensitively in ``ATOMIC_NUMBERS``.
    coords3d
        Atomic coordinates, one ``(x, y, z)`` row per atom.
    vol_data
        3d array of volumetric values; its shape gives the number of points
        along each grid axis.
    origin
        ``(x, y, z)`` of the grid origin.
    axes
        Three voxel vectors, one row per grid axis.
    comment1, comment2
        The two header comment lines (must not contain newlines).
        ``comment2=None`` writes ``"Total of N points"``.

    Returns
    -------
    str
        Cube file contents. Values are written z-fastest, six per line, and
        every (x, y) row of z values starts on a new line.

    Raises
    ------
    AssertionError
        If ``vol_data`` is not three-dimensional.
    """
    assert vol_data.ndim == 3
    env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
    env.globals.update(zip=zip)
    tpl = env.from_string(CUBE_TPL)

    org_x, org_y, org_z = origin
    axes_npoints = vol_data.shape
    nx, ny, nz = axes_npoints
    atom_nums = [ATOMIC_NUMBERS[atom.lower()] for atom in atoms]
    if comment2 is None:
        comment2 = f"Total of {nx*ny*nz} points"

    # Six values per line; each z-row ends with exactly one newline, so no
    # empty line appears when nz is a multiple of six. Lines are collected and
    # joined once instead of growing a string value by value.
    values_per_line = 6
    lines = [
        "".join(f"{val: >12.6e} " for val in row[start : start + values_per_line])
        for row in vol_data.reshape(nx * ny, nz)
        for start in range(0, nz, values_per_line)
    ]
    grid_str = "".join(f"{line}\n" for line in lines)

    def fmt(num):
        return f"{num: >12.6f}"

    def ifmt(num):
        return f"{num: >5d}"

    rendered = tpl.render(
        fmt=fmt,
        ifmt=ifmt,
        origin_x=org_x,
        origin_y=org_y,
        origin_z=org_z,
        comment1=comment1,
        comment2=comment2,
        axes=axes,
        axes_npoints=axes_npoints,
        atom_nums=atom_nums,
        coords3d=coords3d,
        grid_str=grid_str,
    )
    return rendered