Source code for pysisyphus.plotters.AnimPlot

import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

from pysisyphus.helpers import get_coords_diffs
from pysisyphus.cos.GrowingChainOfStates import GrowingChainOfStates


class AnimPlot:
    """Animate the cycles of an optimization on top of a 2D potential.

    The upper axes show the contour lines of the potential together with
    the image positions of every optimization cycle. Unless the optimized
    geometry is a ``GrowingChainOfStates``, the forces and tangents are
    drawn as quiver plots, too. Optionally a second axes below shows the
    energies of the images along the path.

    Typical usage: construct the object, call :meth:`animate` and then
    either show the figure or export it with :meth:`as_html5`.
    """

    def __init__(
        self,
        calculator,
        optimizer,
        xlim=None,
        ylim=None,
        levels=None,
        num=100,
        figsize=(8, 8),
        interval=250,
        energy_profile=True,
        colorbar=True,
        save=None,
        title=True,
        tight_layout=False,
    ):
        """Set up the figure, draw the potential and create all artists.

        Parameters
        ----------
        calculator
            Object providing ``get_energy(atoms, coords)`` whose result is
            indexed with ``"energy"``. Its optional attributes ``xlim``,
            ``ylim`` and ``levels`` serve as defaults for the arguments of
            the same name.
        optimizer
            Finished optimizer providing ``geometry``, ``coords``,
            ``forces``, ``energies``, ``tangents`` and ``cur_cycle``.
        xlim, ylim
            ``(min, max)`` of the plotted region. Falls back to the
            calculator attribute and finally to ``(-1, 1)``.
        levels
            ``(start, stop, num)`` passed on to ``np.linspace`` to generate
            the contour levels. Falls back to the min/max/size of
            ``calculator.levels`` and finally to ``(-150, 5, 30)``.
        num
            Number of grid points per axis used to sample the potential.
        figsize
            Figure size handed to ``plt.subplots``.
        interval
            Delay between frames, handed to ``FuncAnimation``.
        energy_profile
            Whether to add a second axes with the energies along the path.
        colorbar
            Whether to add a colorbar for the contour levels.
        save
            If truthy, every frame is saved once as ``step{frame}.png``.
        title
            Whether to show the current cycle as figure title.
        tight_layout
            Whether to call ``plt.tight_layout()`` for every frame.
        """
        self.calculator = calculator
        self.optimizer = optimizer
        self.interval = interval

        if xlim is None:
            xlim = getattr(calculator, "xlim", (-1, 1))
        if ylim is None:
            ylim = getattr(calculator, "ylim", (-1, 1))
        if levels is None:
            try:
                lvls = calculator.levels
                levels = (lvls.min(), lvls.max(), lvls.size)
            except AttributeError:
                levels = (-150, 5, 30)

        self.energy_profile = energy_profile
        self.colorbar = colorbar
        self.save = save
        self.title = title
        self.tight_layout = tight_layout

        self.growing = isinstance(self.optimizer.geometry, GrowingChainOfStates)
        # One (-1, 3) array per optimization cycle.
        self.coords = [c.reshape(-1, 3) for c in self.optimizer.coords]
        self.forces = [f.reshape((-1, 3)) for f in self.optimizer.forces]
        self.energies = self.optimizer.energies
        self.tangents = [t.reshape((-1, 3)) for t in self.optimizer.tangents]

        # ax: the contour plot
        # ax1: energy along the path
        if self.energy_profile:
            self.fig, (self.ax, self.ax1) = plt.subplots(
                2, figsize=figsize, gridspec_kw={"height_ratios": [3, 1]}
            )
        else:
            self.fig, self.ax = plt.subplots(figsize=figsize)

        # True while the animation is running, i.e. the next SPACE press
        # will pause it. See on_keypress().
        self.pause = True
        self.fig.canvas.mpl_connect("key_press_event", self.on_keypress)

        # Calculate the potential on a (num x num) grid in the z = 0 plane.
        x = np.linspace(*xlim, num)
        y = np.linspace(*ylim, num)
        X, Y = np.meshgrid(x, y)
        Z = np.full_like(X, 0)
        fake_atoms = ("H",)
        pot_coords = np.stack((X, Y, Z))
        pot = self.calculator.get_energy(fake_atoms, pot_coords)["energy"]

        # Draw the contourlines of the potential
        levels = np.linspace(*levels)
        contours = self.ax.contour(X, Y, pot, levels)
        self.ax.set_xlabel("x")
        self.ax.set_ylabel("y")

        if self.colorbar:
            # Make room on the right and put the colorbar into its own axes.
            self.fig.subplots_adjust(right=0.8)
            cbar_ax = self.fig.add_axes([0.85, 0.15, 0.05, 0.7])
            self.fig.colorbar(contours, cax=cbar_ax)

        # Data of the first cycle, used to initialize the artists.
        images_x = self.coords[0][:, 0]
        images_y = self.coords[0][:, 1]
        forces_x = self.forces[0][:, 0]
        forces_y = self.forces[0][:, 1]
        tangents_x = self.tangents[0][:, 0]
        tangents_y = self.tangents[0][:, 1]
        energies = self.energies[0]

        # Create artists, so we can update their data later
        # Image positions
        (self.images,) = self.ax.plot(images_x, images_y, "ro", ls="-")

        # The quivers are only created (and later updated in func()) when
        # the chain of states is not growing.
        if not self.growing:
            # Total forces
            self.total_forces_quiv = self.ax.quiver(
                images_x, images_y, forces_x, forces_y
            )

            # Tangents
            self.tangent_quiv = self.ax.quiver(
                images_x, images_y, tangents_x, tangents_y, color="b"
            )

        # Energy along the path
        if self.energy_profile:
            (self.energies_plot,) = self.ax1.plot(
                get_coords_diffs(self.coords[0]), energies, "ro", ls="-"
            )
            self.ax1.set_xlabel("q(x, y)")
            self.ax1.set_ylabel("f(x, y)")

    def func(self, frame):
        """Update all artists with the data of cycle ``frame``.

        Used as the per-frame callback of ``FuncAnimation``. When
        ``self.save`` is truthy the frame is additionally written to
        ``step{frame}.png``, unless that file already exists.
        """
        if self.title:
            self.fig.suptitle(f"Cycle {frame}")

        images_x = self.coords[frame][:, 0]
        images_y = self.coords[frame][:, 1]
        self.images.set_xdata(images_x)
        self.images.set_ydata(images_y)

        if not self.growing:
            # Moving a quiver requires new offsets (arrow origins) as well
            # as new U/V components.
            # https://stackoverflow.com/questions/19329039
            # https://stackoverflow.com/questions/17758942
            offsets = np.stack((images_x, images_y), axis=-1)

            # Update total forces quiver
            forces_x = self.forces[frame][:, 0]
            forces_y = self.forces[frame][:, 1]
            self.total_forces_quiv.set_offsets(offsets)
            self.total_forces_quiv.set_UVC(forces_x, forces_y)

            # Update tangent quiver
            tangents_x = self.tangents[frame][:, 0]
            tangents_y = self.tangents[frame][:, 1]
            self.tangent_quiv.set_offsets(offsets)
            self.tangent_quiv.set_UVC(tangents_x, tangents_y)

        if self.energy_profile:
            coords_diffs = get_coords_diffs(self.coords[frame])
            energies = self.energies[frame]
            self.energies_plot.set_xdata(coords_diffs)
            self.energies_plot.set_ydata(energies)
            # Rescale the axes, as the data limits may change every cycle.
            self.ax1.relim()
            self.ax1.autoscale_view()

        if self.tight_layout:
            plt.tight_layout()

        if self.save:
            frame_fn = f"step{frame}.png"
            # Frames repeat when the animation loops; save each only once.
            if not os.path.exists(frame_fn):
                self.fig.savefig(frame_fn)

    def animate(self):
        """Create the animation over all cycles done by the optimizer.

        The resulting ``FuncAnimation`` is stored in ``self.animation``.
        """
        cycles = range(self.optimizer.cur_cycle)
        self.animation = animation.FuncAnimation(
            self.fig, self.func, frames=cycles, interval=self.interval
        )

    def as_html5(self, out_fn):
        """Write the animation as HTML5 video markup to ``out_fn``.

        ``animate()`` must have been called before. Note that this sets
        the global ``animation.bitrate`` rcParam to 5000.
        """
        mpl.rcParams["animation.bitrate"] = 5000
        html5 = self.animation.to_html5_video()
        with open(out_fn, "w", encoding="utf-8") as handle:
            handle.write(html5)

    def on_keypress(self, event):
        """Pause on SPACE press.

        Every SPACE press toggles between stopping and restarting the
        animation. Key presses are ignored as long as no animation has
        been created with ``animate()``.
        """
        # https://stackoverflow.com/questions/41557578
        if event.key != " ":
            return

        # The callback is connected in __init__, but the animation only
        # exists after animate() was called.
        anim = getattr(self, "animation", None)
        if anim is None:
            return

        if self.pause:
            anim.event_source.stop()
        else:
            anim.event_source.start()
        self.pause = not self.pause