Source code for alchemlyb.visualisation.ti_dhdl

"""Functions for Plotting the dhdl for the TI estimator.

To assess the quality of the TI estimator, the dhdl from lambda state 0
to lambda state 1 can be plotted to assess if the change in dhdl is sampled
thoroughly.

The code for producing the dhdl plot is modified based on
`Alchemical Analysis <https://github.com/MobleyLab/alchemical-analysis>`_.

"""

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties as FP
from matplotlib.axes import Axes
from typing import Any

from ..postprocessors.units import get_unit_converter


def plot_ti_dhdl(
    dhdl_data: Any,
    labels: None | list[str] = None,
    colors: None | list[str] = None,
    units: None | str = None,
    ax: None | Axes = None,
) -> Axes:
    """Plot the dhdl of TI.

    Every alchemical transformation is drawn as a filled area in its own
    unit-wide slot along the x axis (transformation ``k`` occupies
    ``[k, k + 1]``), with alternating opacity between neighbouring lambda
    windows.

    Parameters
    ----------
    dhdl_data : :class:`~alchemlyb.estimators.TI` or list
        One or more :class:`~alchemlyb.estimators.TI` estimator, where the
        dhdl value will be taken from.
    labels : List
        list of labels for labelling all the alchemical transformations.
        Must have exactly one entry per transformation.
    colors : List
        list of colors for plotting all the alchemical transformations.
        Must have at least one entry per transformation.
        Default: ['r', 'g', '#7F38EC', '#9F000F', 'b', 'y'] (reused
        cyclically when there are more than six transformations).
    units : str
        The unit of the estimate. The default is `None`, which is to use the
        unit in the input. Setting this will change the output unit.
    ax : matplotlib.axes.Axes
        Matplotlib axes object where the plot will be drawn on. If
        ``ax=None``, a new axes will be generated.

    Returns
    -------
    matplotlib.axes.Axes
        An axes with the TI dhdl drawn.

    Raises
    ------
    ValueError
        If `labels` does not have one entry per transformation, or if
        `colors` has fewer entries than there are transformations.

    Note
    ----
    The code is taken and modified from
    `Alchemical Analysis <https://github.com/MobleyLab/alchemical-analysis>`_.


    .. versionchanged:: 1.0.0
        If no units is given, the `units` in the input will be used.

    .. versionchanged:: 0.5.0
        The `units` will be used to change the underlying data instead of only
        changing the figure legend.

    .. versionadded:: 0.4.0
    """
    # Normalise the input into a flat list. ``separate_dhdl`` is used so that
    # the plotting code below always works on a uniform list of series
    # objects, each of which contains only one lambda.
    if isinstance(dhdl_data, list):
        dhdl_list = []
        for estimator in dhdl_data:
            dhdl_list.extend(estimator.separate_dhdl())
    else:
        dhdl_list = dhdl_data.separate_dhdl()

    # Convert unit; fall back to the unit recorded on the input data.
    if units is None:
        units = dhdl_list[0].attrs["energy_unit"]
    convert = get_unit_converter(units)
    dhdl_list = [convert(dhdl) for dhdl in dhdl_list]

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    ax.spines["bottom"].set_position("zero")
    ax.spines["top"].set_color("none")
    ax.spines["right"].set_color("none")
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")

    # Keep the spines above the filled areas and the grid (zorder 12).
    for spine in ax.spines.values():
        spine.set_zorder(12.2)

    # Make level names
    if labels is None:
        # Assume that each dhdl has only one column.
        lv_names2 = [dhdl.name.capitalize() for dhdl in dhdl_list]
    elif len(labels) == len(dhdl_list):
        lv_names2 = labels
    else:  # pragma: no cover
        raise ValueError(
            "Length of labels ({}) should be the same as the number of data ({})".format(
                len(labels), len(dhdl_list)
            )
        )

    if colors is None:
        colors = ["r", "g", "#7F38EC", "#9F000F", "b", "y"]
    elif len(colors) < len(dhdl_list):  # pragma: no cover
        # Report the number of *colors*; `labels` may legitimately be None.
        raise ValueError(
            "Number of colors ({}) should be larger than the number of data ({})".format(
                len(colors),
                len(dhdl_list),
            )
        )

    # Get the real data out.
    # ``xs`` holds the plotted x position of every lambda window (lambda plus
    # the slot offset of its transformation); ``lambdas`` holds the matching
    # un-shifted lambda values, used for the annotations.
    xs = [0]
    lambdas = [0]
    ndx = 0
    min_y, max_y = 0, 0
    for dhdl in dhdl_list:
        x = dhdl.index.values
        y = dhdl.values.ravel()

        min_y = min(y.min(), min_y)
        max_y = max(y.max(), max_y)

        # Modulo is a no-op for validated user palettes; it lets the default
        # palette cope with more than six transformations.
        color = colors[ndx % len(colors)]

        for i in range(len(x) - 1):
            # Alternate the opacity so neighbouring windows are distinguishable.
            ax.fill_between(
                x[i : i + 2] + ndx,
                0,
                y[i : i + 2],
                color=color,
                alpha=1.0 if i % 2 == 0 else 0.5,
            )

        # Off-screen line (x <= 0 lies outside the final x limits) whose only
        # purpose is to provide a legend entry for this transformation.
        xlegend = [-100 * wnum for wnum in range(len(lv_names2))]
        ax.plot(
            xlegend,
            [0 * wnum for wnum in xlegend],
            ls="-",
            color=color,
            label=lv_names2[ndx],
        )
        xs += (x + ndx).tolist()[1:]
        lambdas += x.tolist()[1:]
        ndx += 1

    # Make sure the tick labels are not overcrowded.
    xs_arr = np.array(xs)
    # dl_mat[i, j] == xs[j] - xs[i]
    dl_mat = xs_arr - xs_arr[:, np.newaxis]
    ri = range(len(xs_arr))

    def _tick_indices(r: range, z: list[int]) -> list[int]:
        """Collect indices of ticks that are far enough apart to be labelled."""
        primo = r[0]
        min_dl = ndx * 0.02 * 2 ** (primo > 10)
        if dl_mat[primo].max() < min_dl:
            return z
        for i in r:  # pragma: no cover
            for j in range(len(xs_arr)):
                if dl_mat[i, j] > min_dl:
                    z.append(j)
                    return _tick_indices(ri[j:], z)
        # Nothing further is strictly beyond the minimum spacing.
        return z

    # Computed once (fresh accumulator) instead of once per tick.
    labelled = set(_tick_indices(ri, [0]))
    xt = [i if i in labelled else "" for i in ri]

    # Set the ticks on the target axes rather than on pyplot's current axes,
    # which may be a different one when the caller supplied ``ax``.
    ax.set_xticks(xs_arr[1:])
    ax.set_xticklabels(xt[1:], fontsize=10)
    ax.yaxis.label.set_size(10)  # type: ignore[attr-defined]

    # Remove the abscissa ticks and set up the axes limits.
    for tick in ax.get_xticklines():
        tick.set_visible(False)
    ax.set_xlim(0, ndx)
    min_y *= 1.01  # type: ignore[assignment]
    max_y *= 1.01  # type: ignore[assignment]

    # Modified so that the x label won't conflict with the lambda label
    min_y -= (max_y - min_y) * 0.1  # type: ignore[assignment]

    ax.set_ylim(min_y, max_y)

    # Annotate the labelled ticks with the true lambda value of the window,
    # independent of which transformation slot it is drawn in.
    for pos, lam, tick_label in zip(xs_arr[1:], lambdas[1:], xt[1:]):
        ax.annotate(
            ("{:.2f}".format(lam) if tick_label != "" else ""),
            xy=(pos, 0),
            size=10,
            rotation=90,
            va="bottom",
            ha="center",
            color="#151B54",
        )
    if ndx > 1:
        lenticks = len(ax.get_ymajorticklabels()) - 1
        if min_y < 0:
            lenticks -= 1
        if lenticks < 5:  # pragma: no cover
            from matplotlib.ticker import AutoMinorLocator as AML

            ax.yaxis.set_minor_locator(AML())
    ax.grid(which="both", color="w", lw=0.25, axis="y", zorder=12)
    ax.set_ylabel(
        r"$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$"
        + "({})".format(units),
        fontsize=20,
        color="#151B54",
    )
    ax.annotate(
        r"$\mathit{\lambda}$",
        xy=(0, 0),
        xytext=(0.5, -0.05),
        size=18,
        textcoords="axes fraction",
        va="top",
        ha="center",
        color="#151B54",
    )

    lege = ax.legend(prop=FP(size=14), frameon=False, loc=1)
    # Thicken the legend lines so the colours are easy to read.
    for legend_handle in lege.legend_handles:
        legend_handle.set_linewidth(10)  # type: ignore[union-attr]
    return ax