import argparse
from collections.abc import Sequence
import sys
import warnings
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from pysisyphus.constants import AU2NU, AU2EV
from pysisyphus.diabatization import logger
EV2NU = AU2NU / AU2EV
[docs]
def state_graph_from_en_mat(
en_mat: np.ndarray, state_inds: Sequence[int], thresh_eV: float = 0.0
) -> nx.Graph:
"""Graph representation of energy matrix with state couplings as edges.
Parameters
----------
en_mat
Quadratic energy matrix containing electronic energies in eV.
state_inds
Sequence of positive integers containing state labels.
thresh_eV
Positive floating point number that can be used to filter the
couplings. If a coupling is below 'thresh_eV', no edge is created.
Defaults to 0.0 eV, so by default all couplings are included.
Returns
-------
G
networkx.Graph representation of the energy matrix with states as
nodes and couplings as edges. The node have an 'energy' attribute,
containing the state energy in eV and the edges have an "weight"
attribute, containing the electronic coupling in eV.
"""
nrows, ncols = en_mat.shape
assert nrows == ncols
nstates = len(state_inds)
assert nstates == nrows
# Matrix must be symmetric
np.testing.assert_allclose(en_mat, en_mat.T, atol=1e-12)
# Create nodes from states
state_ens = np.diag(en_mat)
G = nx.Graph()
for i, state_ind in enumerate(state_inds):
energy = state_ens[i]
G.add_node(state_ind, energy=energy)
# Add couplingss as edges
for i in range(nstates):
state_i = state_inds[i]
for j in range(i + 1, nstates):
state_j = state_inds[j]
coupling = abs(en_mat[i, j])
if coupling > thresh_eV:
coupling_meV = coupling * 1e3
G.add_edge(state_i, state_j, weight=coupling_meV)
return G
[docs]
def map_array_to_interval(
arr: np.ndarray, min_new: float, max_new: float, thresh=1e-12
) -> np.ndarray:
"""Map array from interval [array.min(), array.max()] to [min_new, max_new].
Parameter
---------
arr
1d array containing floating point numbers.
min_new
Lower bound of the new interval.
max_new
Upper bound of the new interval.
thresh
When 'max_new - min_new' falls between this threshold an exception is
raised.
Returns
-------
mapped
1d array w/ original shapped mapped onto the new interval.
"""
spread_new = abs(max_new - min_new)
if len(arr) == 1:
mean = max(spread_new / 2.0, min_new)
mapped = np.array((mean,))
return mapped
min_ = arr.min()
max_ = arr.max()
spread_org = max_ - min_
if (thresh is not None) and (spread_new < thresh):
raise Exception(
f"'max_new - min_new = {spread_new: >10.4e}' < {thresh=: >10.4e}! Either set "
f"thresh to None or decrease thresh below {spread_new: >10.4e}"
)
scale = spread_org / spread_new
if scale <= 1e-14:
warnings.warn(
f"Obtained very small scaling factor {scale=: >10.4e}! "
"Results may be unreliable/wrong!"
)
mapped = min_new + (arr - min_) / scale
return mapped
[docs]
def draw_state_graph(G: nx.Graph) -> plt.Figure:
"""Draw state graph."""
fig, ax = plt.subplots(figsize=(16, 8))
# Spring layout takes "weight" atributes of edges into account.
pos = nx.spring_layout(G)
nx.draw(
G,
pos=pos,
ax=ax,
node_color="white",
edgecolors="purple",
node_size=3000,
linewidths=2,
)
# Choose edge widths according to coupling strength.
#
# Loop over all edges to determine the minimum and the maximum edge weight.
min_weight = float("inf")
max_weight = -1
widths = list()
for edge in G.edges:
weight = G.get_edge_data(*edge)["weight"]
min_weight = min(min_weight, weight)
max_weight = min(max_weight, weight)
widths.append(weight)
# Map edge weights onto new interval [1, 12], so the highest coupling has
# width 12 and the lowest coupling has width 1.
# TODO: what to do when couplings are degenerate?
widths = map_array_to_interval(np.array(widths), 1.0, 13.0)
nx.draw_networkx_edges(G, pos, width=widths)
# Draw node labels. Render one node per state w/ its energy.
labels = nx.get_node_attributes(G, "energy")
labels = {k: f"State {k}\n{v: >6.3} eV" for k, v in labels.items()}
nx.draw_networkx_labels(G, pos, labels=labels, font_size=10)
# Draw couplings as edge labels in meV
edge_labels = nx.get_edge_attributes(G, "weight")
edge_labels = {k: f"{v:.1f}" for k, v in edge_labels.items()}
label_pos = 0.6
nx.draw_networkx_edge_labels(G, pos, edge_labels, label_pos=label_pos)
ax.set_title(f"Couplings in meV; labels at position {label_pos}.")
return fig
[docs]
def parse_args(args):
parser = argparse.ArgumentParser()
parser.add_argument("fn", help="Path to pysisyphus diabatiatzion result file.")
parser.add_argument("--out-fn", default=None)
parser.add_argument("--state-inds", nargs="+", type=int)
return parser.parse_args(args)
[docs]
def run():
args = parse_args(sys.argv[1:])
fn = args.fn
out_fn = args.out_fn
# TODO: read state_inds from npz file
state_inds = args.state_inds
logger.info(f"Rendering {fn}")
data = np.load(fn)
adia_ens = data["adia_ens"]
try:
state_inds = data["states"]
except KeyError:
# If state_inds was not set via --state-inds we enumerate them by ourselves,
# starting from 0.
if not state_inds:
state_inds = list(range(len(adia_ens)))
logger.info(f"Using {state_inds=}")
with np.printoptions(precision=4, formatter={"float": lambda f: f"{f: >8.4f}"}):
logger.info(f"Adiabatic energies: {adia_ens} eV")
adia_mat = np.diag(adia_ens)
U = data["U"]
dia_mat = U.T @ adia_mat @ U
G = state_graph_from_en_mat(dia_mat, state_inds=state_inds)
fig = draw_state_graph(G)
fig.suptitle(fn)
fig.tight_layout()
if out_fn is not None:
fig.savefig(out_fn)
logger.info(f"Saved figure to '{out_fn}'.")
plt.show()
if __name__ == "__main__":
run()