"""Driver to run a diabatization from a YAML input file (pysisyphus.diabatization.driver_yaml)."""

import argparse
import itertools as it
import sys

import matplotlib.pyplot as plt
import numpy as np
import yaml

from pysisyphus.constants import AU2EV
from pysisyphus.diabatization import logger
from pysisyphus.diabatization.multipole import dq_diabatization


def make_array(nstates, components, lines):
    """Build a symmetric (components, nstates, nstates) array from sparse lines.

    Every line has the form ``[from_, to_, *props]``, where ``from_`` and
    ``to_`` are state indices and ``props`` holds ``components`` values. The
    values are stored at ``[:, from_, to_]`` and mirrored to
    ``[:, to_, from_]``, so every unordered state pair (including the diagonal
    pairs ``(i, i)``) must be given exactly once.

    Parameters
    ----------
    nstates
        Number of states, i.e., size of the two trailing axes.
    components
        Number of values per state pair (e.g. 3 for dipole moments).
    lines
        Sequence of ``[from_, to_, *props]`` entries. An empty sequence or
        ``None`` means the property is not available.

    Returns
    -------
    arr
        Array of shape (components, nstates, nstates), or ``None`` when
        ``lines`` is empty or ``None``.

    Raises
    ------
    AssertionError
        On state indices outside ``[0, nstates)``, a wrong number of values,
        duplicated state pairs, or missing diagonal/off-diagonal elements.
    """
    # len() instead of truthiness so numpy arrays are accepted as well.
    if lines is None or len(lines) == 0:
        return None

    arr = np.zeros((components, nstates, nstates))
    # Unordered pairs seen so far; (i, j) and (j, i) address the same elements.
    # Tracking them (instead of plain counters) prevents a duplicated pair from
    # masking a missing one.
    seen = set()
    for line in lines:
        from_, to_, *props = line
        assert (0 <= from_ < nstates) and (
            0 <= to_ < nstates
        ), f"State indices must lie in [0, {nstates}) but got '{line}'!"
        assert len(props) == components, (
            f"Expected {components} value(s) after the two state indices "
            f"but got '{line}'!"
        )
        pair = (min(from_, to_), max(from_, to_))
        assert pair not in seen, f"Duplicate entry for states {pair}: '{line}'!"
        seen.add(pair)
        arr[:, from_, to_] = arr[:, to_, from_] = props

    # Indices are range-checked and unique, so counting suffices for completeness.
    ndiag = sum(1 for i, j in seen if i == j)
    noff_diag = len(seen) - ndiag
    assert ndiag == nstates, "Some diagonal elements are missing!"
    assert (
        noff_diag == nstates * (nstates - 1) // 2
    ), "Some off-diagonal elements are missing!"
    return arr
def dq_diabatization_from_run_dict(run_dict):
    """Run a DQ diabatization with the inputs collected in ``run_dict``.

    Required keys are ``adiabatic_energies`` (one value per state) and
    ``dipoles`` (lines for ``make_array`` with 3 values per state pair).
    ``quadrupoles`` and ``epots`` (1 value per state pair) are optional.
    ``alpha`` and ``beta`` are forwarded to ``dq_diabatization`` as keyword
    arguments only when present.

    Returns whatever ``dq_diabatization`` returns.
    """
    adia_ens = np.array(run_dict["adiabatic_energies"], dtype=float)
    nstates = adia_ens.size

    # Dipole moments are mandatory; a missing key raises KeyError here.
    dip_moms = make_array(nstates, 3, run_dict["dipoles"])
    # Quadrupole moments and electrostatic potentials are optional;
    # make_array returns None for an empty list of lines.
    quad_moms, epots = (
        make_array(nstates, 1, run_dict.get(key, list()))
        for key in ("quadrupoles", "epots")
    )
    # Only pass the optional parameters that were actually given.
    optional = {key: run_dict[key] for key in ("alpha", "beta") if key in run_dict}
    return dq_diabatization(
        adia_ens, dip_moms, quad_moms=quad_moms, epots=epots, **optional
    )
def diabatize_path(adia_ens, dip_moms, tr_quad_moms=None, epots=None, **kwargs):
    """Lazily diabatize every point along a path.

    All sequences are indexed by path point and must have equal length.
    ``tr_quad_moms`` and ``epots`` are optional; when omitted, ``None`` is
    passed for every point. Extra keyword arguments are forwarded unchanged
    to ``dq_diabatization``.

    Yields
    ------
    The result of ``dq_diabatization`` for each point, in path order.
    """
    npoints = len(adia_ens)
    if tr_quad_moms is None:
        tr_quad_moms = [None] * npoints
    if epots is None:
        epots = [None] * npoints
    assert npoints == len(dip_moms) == len(tr_quad_moms) == len(epots)

    # Each point is (adiabatic energies, dipoles, quadrupoles, potentials),
    # handed to dq_diabatization positionally.
    for point in zip(adia_ens, dip_moms, tr_quad_moms, epots):
        yield dq_diabatization(*point, **kwargs)
def plot_dia_res(dia_res, show=False):
    """Plot adiabatic & diabatic energies and diabatic couplings along a path.

    Parameters
    ----------
    dia_res
        Iterable of diabatization results, one per path step, e.g. the
        generator returned by ``diabatize_path``. Every item must provide
        ``nstates``, ``adia_ens``, ``dia_ens`` and ``couplings``, the latter
        indexable by state pairs ``(i, j)`` with ``i < j``. The number of
        states is taken from the first item and assumed for all others.
    show
        If True, call ``plt.show()`` before returning.

    Returns
    -------
    fig, (ax0, ax1)
        The figure, the energy axis (top) and the coupling axis (bottom).
    """
    # Materialize first, so single-pass iterables (generators) work, too.
    dia_res = list(dia_res)
    nsteps = len(dia_res)
    nstates = dia_res[0].nstates
    adia_ens = np.zeros((nsteps, nstates))
    dia_ens = np.zeros((nsteps, nstates))
    keys = list(it.combinations(range(nstates), 2))
    couplings = np.zeros((nsteps, len(keys)))
    for i, dr in enumerate(dia_res):
        adia_ens[i] = dr.adia_ens
        dia_ens[i] = dr.dia_ens
        for j, key in enumerate(keys):
            couplings[i, j] = dr.couplings[key]

    # Both energy sets are shifted by the same reference (the lowest adiabatic
    # energy along the whole path) so they remain directly comparable; then
    # everything is converted from atomic units to eV.
    adia_min = adia_ens.min()
    adia_ens = (adia_ens - adia_min) * AU2EV
    dia_ens = (dia_ens - adia_min) * AU2EV
    couplings *= AU2EV

    fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True)
    # Triple braces keep multi-digit state indices inside the TeX subscript.
    for i, state in enumerate(adia_ens.T):
        ax0.plot(state, "o--", label=f"$V_{{{i}}}$")
    for i, state in enumerate(dia_ens.T):
        ax0.plot(state, "x-", label=f"$U_{{{i}}}$")
    ax0.legend()
    ax0.set_ylabel(r"$\Delta E$ / eV")

    # Couplings
    for key, cpls in zip(keys, couplings.T):
        from_to = "".join(str(state) for state in key)
        ax1.plot(cpls, "o-", label=f"$|D_{{{from_to}}}|$")
    ax1.axhline(0.0, ls="--", c="k")
    ax1.legend()
    # With sharex=True only the bottom axis shows x tick labels, so the shared
    # x label belongs there.
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Coupling / eV")

    fig.tight_layout()
    if show:
        plt.show()
    return fig, (ax0, ax1)
def parse_args(args=None):
    """Parse the command line of the YAML diabatization driver.

    Parameters
    ----------
    args
        List of argument strings (without the program name). When None,
        argparse falls back to ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Namespace with the single attribute ``yaml``, the path of the YAML
        input file.
    """
    parser = argparse.ArgumentParser(
        description="Diabatize adiabatic states using inputs from a YAML file."
    )
    parser.add_argument("yaml", help="Path to the YAML input file.")
    return parser.parse_args(args)
def run(args=None):
    """Command line entry point: diabatize using a YAML input file.

    The YAML file is parsed with PyYAML's SafeLoader and must contain a
    mapping. The optional keys ``adiabatic_labels`` and ``unit`` (default
    "eV") are removed from it and only used for rendering the report; the
    remaining mapping is handed to ``dq_diabatization_from_run_dict``. The
    rendered report is emitted via ``logger.info``.

    Parameters
    ----------
    args
        Command line arguments without the program name. Defaults to
        ``sys.argv[1:]``.

    Raises
    ------
    ValueError
        If the YAML file does not contain a mapping at the top level
        (e.g. when it is empty).
    """
    # Returns None on purpose: console-script wrappers call sys.exit(run()),
    # and any other return value would turn into a non-zero exit status.
    if args is None:
        args = sys.argv[1:]
    yaml_fn = parse_args(args).yaml
    # Decode explicitly instead of relying on the platform's default encoding.
    with open(yaml_fn, encoding="utf-8") as handle:
        # SafeLoader only constructs plain Python types, never arbitrary objects.
        run_dict = yaml.load(handle, Loader=yaml.SafeLoader)
    if not isinstance(run_dict, dict):
        raise ValueError(
            f"Expected a mapping at the top level of '{yaml_fn}', "
            f"but got '{type(run_dict).__name__}'!"
        )

    adia_labels = run_dict.pop("adiabatic_labels", None)
    unit = run_dict.pop("unit", "eV")
    dia_res = dq_diabatization_from_run_dict(run_dict)
    report = dia_res.render_report(adia_labels=adia_labels, unit=unit)
    logger.info(report)
if __name__ == "__main__":
    # Allow executing this module directly as a script.
    run()