"""Functions for Plotting the dF states.
To assess the quality of the free energy estimation, the dF between adjacent
lambda states can be plotted.
The code for producing the dF states plot is modified based on
`Alchemical Analysis <https://github.com/MobleyLab/alchemical-analysis>`_.
"""
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties as FP
from matplotlib.figure import Figure
from typing import Any
from ..estimators import TI, BAR, MBAR
from ..postprocessors.units import get_unit_converter
[docs]
def plot_dF_state(
    estimators: Any,
    labels: None | list[str] = None,
    colors: None | list[str] = None,
    units: None | str = None,
    orientation: str = "portrait",
    nb: int = 10,
) -> Figure:
    """Plot the free energy difference (dF) between adjacent lambda states.

    For every estimator the super-diagonal of ``delta_f_`` (the dF between
    state ``i`` and state ``i + 1``) is drawn as a bar, with the matching
    entry of ``d_delta_f_`` as the error bar. Bars belonging to different
    estimators are drawn side by side for each pair of states.

    Parameters
    ----------
    estimators : :class:`~alchemlyb.estimators` or list
        One or more :class:`~alchemlyb.estimators`, where the
        dF values will be taken from. For more than one estimators
        with more than one alchemical transformation, a list of list format
        is used; the estimators within one inner list are concatenated into
        a single series of bars.
    labels : List
        list of labels for labelling different estimators. If `None`, the
        labels ``"TI"``, ``"BAR"`` or ``"MBAR"`` are derived from the type of
        the first estimator of every series.
    colors : List
        list of colors for plotting different estimators. If `None`, a
        default color is derived from the type of the first estimator of
        every series.
    units : str
        The unit of the estimate. The default is `None`, which is to use the
        unit in the input. Setting this will change the output unit.
    orientation : string
        The orientation of the figure. Can be `portrait` or `landscape`
    nb : int
        Maximum number of dF states in one row in the `portrait` mode

    Returns
    -------
    matplotlib.figure.Figure
        A Figure with the dF states drawn.

    Raises
    ------
    ValueError
        If `orientation` is neither ``"landscape"`` nor ``"portrait"``, if
        fewer `colors` than series are given, or if the number of `labels`
        differs from the number of series.

    Note
    ----
    The code is taken and modified from
    `Alchemical Analysis <https://github.com/MobleyLab/alchemical-analysis>`_.

    Default colors and labels only exist for ``TI``, ``BAR`` and ``MBAR``
    instances; a series whose first estimator is of another type contributes
    no default entry, so pass `colors` and `labels` explicitly in that case.


    .. versionchanged:: 1.0.0
        If no units is given, the `units` in the input will be used.

    .. versionchanged:: 0.5.0
        The `units` will be used to change the underlying data instead of only
        changing the figure legend.

    .. versionadded:: 0.4.0

    """
    estimators = _as_estimator_groups(estimators)
    if units is None:
        units = estimators[0][0].delta_f_.attrs["energy_unit"]

    # Get the dF: one flat list of adjacent-state values per series.
    convert = get_unit_converter(units)
    dF_list = []
    error_list = []
    for group in estimators:
        dF = []
        error = []
        for estimator in group:
            # Convert each matrix once instead of once per state pair.
            delta_f = convert(estimator.delta_f_)
            d_delta_f = convert(estimator.d_delta_f_)
            for i in range(len(estimator.delta_f_) - 1):
                dF.append(delta_f.iloc[i, i + 1])
                error.append(d_delta_f.iloc[i, i + 1])
        dF_list.append(dF)
        error_list.append(error)
    max_length = max((len(dF) for dF in dF_list), default=0)

    # Validate every argument before a figure is created, so that a
    # ValueError does not leave an orphaned figure registered in pyplot.
    if orientation not in ("landscape", "portrait"):
        raise ValueError(
            f"Not recognising {orientation}, only supports 'landscape' or 'portrait'."
        )

    # The first estimator of each series decides its default color/label.
    kinds = [_estimator_kind(group[0]) for group in estimators]

    # Sort out the colors
    if colors is None:
        colors_dict = {
            "TI": "#C45AEC",
            "TI-CUBIC": "#33CC33",
            "DEXP": "#F87431",
            "IEXP": "#FF3030",
            "GINS": "#EAC117",
            "GDEL": "#347235",
            "BAR": "#6698FF",
            "UBAR": "#817339",
            "RBAR": "#C11B17",
            "MBAR": "#F9B7FF",
        }
        colors = [colors_dict[kind] for kind in kinds if kind is not None]
    elif len(colors) < len(estimators):
        raise ValueError(
            f"Number of colors ({len(colors)}) should be larger than "
            f"the number of data ({len(estimators)})"
        )

    # Sort out the labels
    if labels is None:
        labels = [kind for kind in kinds if kind is not None]
    elif len(labels) != len(estimators):
        raise ValueError(
            f"Length of labels ({len(labels)}) should be the same as "
            f"the number of data ({len(estimators)})"
        )

    # Create the figure; xs holds the state indices drawn on each axes.
    if orientation == "landscape":
        # The figure grows with the number of states but is never narrower than 8.
        fig, ax = plt.subplots(figsize=(max(max_length, 8), 6))
        axs = [ax]
        xs = [np.arange(max_length)]
    elif max_length < nb:
        fig, ax = plt.subplots(figsize=(8, 6))
        axs = [ax]
        xs = [np.arange(max_length)]
    else:
        # Too many states for a single row: split them over several rows.
        xs = np.array_split(np.arange(max_length), max_length // nb + 1)
        fig, axs = plt.subplots(nrows=len(xs), figsize=(8, 6))
    # Length of the longest row; shorter rows are padded to this width so the
    # bars line up across rows. Bound on every path (only portrait reads it).
    mnb = max(len(chunk) for chunk in xs)

    # Plot the figure
    n_series = len(estimators)
    width = 1.0 / (n_series + 1)
    elw = 30 * width
    lw = (0.1 if orientation == "landscape" else 0.05) * elw
    # Offset that centres a tick/annotation under a group of bars.
    centre = 0.5 * width * n_series
    for x, ax in zip(xs, axs):
        lines: list[Any] = []
        for i, (dF, error) in enumerate(zip(dF_list, error_list)):
            bars = ax.bar(
                x + i * width,
                [dF[j] for j in x],
                width,
                color=colors[i],
                yerr=[error[j] for j in x],
                lw=lw,
                error_kw=dict(elinewidth=elw, ecolor="black", capsize=0.5 * elw),
            )
            lines.append(bars[0])

        # Keep only the left spine.
        ax.yaxis.set_ticks_position("left")
        for side in ("right", "top", "bottom"):
            ax.spines[side].set_color("none")

        # The explicit Axes API is used below rather than plt.xticks/plt.yticks:
        # the pyplot helpers act on the *current* axes, which is the last
        # subplot, so the earlier rows of a multi-row figure were never styled.
        if orientation == "landscape":
            ax.tick_params(axis="y", labelsize=8)
            ax.set_xlim(x[0] - width, x[-1] + len(lines) * width)
            ax.set_xticks(x + centre)
            ax.set_xticklabels([f"{i}--{i + 1}" for i in x], fontsize=8)
        else:
            ax.tick_params(axis="y", labelsize=10)
            ax.xaxis.set_ticks([])
            for state in x:
                ax.annotate(
                    r"$\mathrm{%d-%d}$" % (state, state + 1),
                    xy=(state + centre, 0),
                    xycoords=("data", "axes fraction"),
                    xytext=(0, -2),
                    size=10,
                    textcoords="offset points",
                    va="top",
                    ha="center",
                )
            ax.set_xlim(x[0] - width, x[-1] + len(lines) * width + (mnb - len(x)))

    # Title, axis labels and legend go on the last axes.
    ax = axs[-1]
    for tick in ax.get_xticklines():
        tick.set_visible(False)
    if orientation == "landscape":
        leg = ax.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10), fancybox=True)
        ax.set_title("The free energy change breakdown", fontsize=12)
        ax.set_xlabel("States", fontsize=12, color="#151B54")
        ax.set_ylabel(r"$\Delta G$ ({})".format(units), fontsize=12, color="#151B54")
    else:
        leg = ax.legend(
            lines,
            labels,
            loc=0,
            ncol=2,
            prop=FP(size=8),
            title=r"$\Delta G$ ({})".format(units) + r"$\mathit{vs.}$ lambda pair",
            fancybox=True,
        )
    leg.get_frame().set_alpha(0.5)
    return fig


def _as_estimator_groups(estimators: Any) -> list[Any]:
    """Normalise a single estimator, a list, or a list of lists to a list of lists."""
    try:
        len(estimators)
    except TypeError:
        # A bare estimator has no length: wrap it.
        estimators = [estimators]
    groups = []
    for item in estimators:
        try:
            len(item)
        except TypeError:
            groups.append([item])
        else:
            groups.append(item)
    return groups


def _estimator_kind(estimator: Any) -> str | None:
    """Return ``"TI"``, ``"BAR"`` or ``"MBAR"`` for a recognised estimator, else `None`."""
    for kind, cls in (("TI", TI), ("BAR", BAR), ("MBAR", MBAR)):
        if isinstance(estimator, cls):
            return kind
    return None