import os
import warnings
from os.path import join
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import joblib
from loguru import logger
from matplotlib.axes import Axes
from typing import Any, Callable
from .base import WorkflowBase
from .. import concat
from ..convergence import forward_backward_convergence
from ..estimators import MBAR, BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS
from ..parsing import gmx, amber, parquet
from ..postprocessors.units import get_unit_converter
from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk
from ..visualisation import (
plot_mbar_overlap_matrix,
plot_ti_dhdl,
plot_dF_state,
plot_convergence,
)
[docs]
class ABFE(WorkflowBase):
"""Workflow for absolute and relative binding free energy calculations.
This workflow provides functionality similar to the ``alchemical-analysis.py`` script.
It loads multiple input files from alchemical free energy calculations and computes the
free energies between different alchemical windows using different estimators. It
produces plots to aid in the assessment of convergence.
Parameters
----------
T : float
Temperature in K.
units : str
The unit used for printing and plotting results. {'kcal/mol', 'kJ/mol',
'kT'}. Default: 'kT'.
software : str
The software used for generating input (case-insensitive). {'GROMACS', 'AMBER', 'PARQUET'}.
This option chooses the appropriate parser for the input file.
dir : str
Directory in which data files are stored. Default: os.path.curdir.
prefix : str
Prefix for datafile sets. This argument accepts regular expressions and
the input files are searched using
``Path(dir).glob("**/" + prefix + "*" + suffix)``. Default: 'dhdl'.
suffix : str
Suffix for datafile sets. Default: 'xvg'.
outdirectory : str
Directory in which the output files produced by this script will be
stored. Default: os.path.curdir.
Attributes
----------
logger : Logger
The logging object.
file_list : list
The list of filenames sorted by the lambda state.
.. versionadded:: 1.0.0
.. versionchanged:: 2.0.1
The `dir` argument expects a real directory without wildcards and wildcards will no longer
work as expected. Use `prefix` to specify wildcard-based patterns to search under `dir`.
.. versionchanged:: 2.1.0
The serialised dataframe could be read via software='PARQUET'.
"""
[docs]
def __init__(
self,
T: float,
units: str = "kT",
software: str = "GROMACS",
dir: str = os.path.curdir,
prefix: str = "dhdl",
suffix: str = "xvg",
outdirectory: str = os.path.curdir,
) -> None:
super().__init__(units, software, T, outdirectory)
logger.info("Initialise Alchemlyb ABFE Workflow")
self.update_units(units)
logger.info(
f"Finding files with prefix: {prefix}, suffix: "
f"{suffix} under directory {dir} produced by "
f"{software}"
)
reg_exp = "**/" + prefix + "*" + suffix
if "*" in dir:
warnings.warn(
f"A real directory is expected in `dir`={dir}, wildcard expressions should be supplied to `prefix`."
)
if not Path(dir).is_dir():
raise ValueError(f"The input directory `dir`={dir} is not a directory.")
self.file_list = list(map(str, Path(dir).glob(reg_exp)))
if len(self.file_list) == 0:
raise ValueError(f"No file has been matched to {reg_exp}.")
logger.info(f"Found {len(self.file_list)} {suffix} files.")
logger.info("Unsorted file list: \n {}", "\n".join(self.file_list))
if software == "GROMACS":
logger.info(f"Using {software} parser to read the data.")
self._extract_u_nk = gmx.extract_u_nk
self._extract_dHdl = gmx.extract_dHdl
elif software == "AMBER":
self._extract_u_nk = amber.extract_u_nk
self._extract_dHdl = amber.extract_dHdl
elif software == "PARQUET":
self._extract_u_nk = parquet.extract_u_nk
self._extract_dHdl = parquet.extract_dHdl
else:
raise NotImplementedError(f"{software} parser not found.")
[docs]
def read(
self, read_u_nk: bool = True, read_dHdl: bool = True, n_jobs: int = 1
) -> None:
"""Read the u_nk and dHdL data from the
:attr:`~alchemlyb.workflows.ABFE.file_list`
Parameters
----------
read_u_nk : bool
Whether to read the u_nk.
read_dHdl : bool
Whether to read the dHdl.
n_jobs : int
Number of parallel workers to use for reading the data.
(-1 means using all the threads)
Attributes
----------
u_nk_list : list
A list of :class:`pandas.DataFrame` of u_nk.
dHdl_list : list
A list of :class:`pandas.DataFrame` of dHdl.
"""
self.u_nk_sample_list = None # type: ignore[assignment]
self.dHdl_sample_list = None # type: ignore[assignment]
if read_u_nk:
def extract_u_nk(
_extract_u_nk: Callable, file: str, T: float
) -> pd.DataFrame:
try:
u_nk = _extract_u_nk(file, T)
logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}")
return u_nk # type: ignore[no-any-return]
except Exception as exc:
msg = f"Error reading u_nk from {file}."
logger.error(msg)
raise OSError(msg) from exc
u_nk_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(extract_u_nk)(self._extract_u_nk, file, self.T)
for file in self.file_list
)
else:
u_nk_list = []
if read_dHdl:
def extract_dHdl(
_extract_dHdl: Callable, file: str, T: float
) -> pd.DataFrame:
try:
dhdl = _extract_dHdl(file, T)
logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}")
return dhdl # type: ignore[no-any-return]
except Exception as exc:
msg = f"Error reading dHdl from {file}."
logger.error(msg)
raise OSError(msg) from exc
dHdl_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(extract_dHdl)(self._extract_dHdl, file, self.T)
for file in self.file_list
)
else:
dHdl_list = []
# Sort the files according to the state
if read_u_nk:
logger.info("Sort files according to the u_nk.")
column_names = u_nk_list[0].columns.values.tolist()
index_list = sorted(
range(len(self.file_list)),
key=lambda x: column_names.index(
u_nk_list[x].reset_index("time").index.values[0]
),
)
elif read_dHdl:
logger.info("Sort files according to the dHdl.")
index_list = sorted(
range(len(self.file_list)),
key=lambda x: dHdl_list[x].reset_index("time").index.values[0],
)
else:
self.u_nk_list = []
self.dHdl_list = []
return
self.file_list = [self.file_list[i] for i in index_list]
logger.info("Sorted file list: \n{}", "\n".join(self.file_list))
if read_u_nk:
self.u_nk_list = [u_nk_list[i] for i in index_list]
else:
self.u_nk_list = []
if read_dHdl:
self.dHdl_list = [dHdl_list[i] for i in index_list]
else:
self.dHdl_list = []
[docs]
def run(
self,
skiptime: float = 0,
uncorr: str = "dE",
threshold: int = 50,
estimators: tuple[str, ...] = ("MBAR", "BAR", "TI"),
overlap: str = "O_MBAR.pdf",
breakdown: bool = True,
forwrev: None | int = None,
n_jobs: int = 1,
*args: Any,
**kwargs: Any,
) -> None:
"""The method for running the automatic analysis.
Parameters
----------
skiptime : float
Discard data prior to this specified time as 'equilibration' data.
Units are specified by the corresponding MD Engine. Default: 0.
uncorr : str
The observable to be used for the autocorrelation analysis; 'dE'.
threshold : int
Proceed with correlated samples if the number of uncorrelated samples is
found to be less than this number. If 0 is given, the time series
analysis will not be performed at all. Default: 50.
estimators : str or list of str
A list of the estimators to estimate the free energy with. Default:
`('MBAR', 'BAR', 'TI')`.
overlap : str
The filename for the plot of overlap matrix. Default: 'O_MBAR.pdf'.
breakdown : bool
Plot the free energy differences evaluated for each pair of adjacent
states for all methods, including the dH/dlambda curve for TI. Default:
``True``.
forwrev : int
Plot the free energy change as a function of time in both directions,
with the specified number of points in the time plot. The number of time
points (an integer) must be provided. Specify as ``None`` will not do
the convergence analysis. Default: None. By default, 'MBAR'
estimator will be used for convergence analysis, as it is
usually the fastest converging method. If the dataset does not
contain u_nk, please run
meth:`~alchemlyb.workflows.ABFE.check_convergence` manually
with estimator='TI'.
n_jobs : int
Number of parallel workers to use for reading and decorrelating the data.
(-1 means using all the threads)
Attributes
----------
summary : Dataframe
The summary of the free energy estimate.
convergence : DataFrame
The summary of the convergence results. See
:func:`~alchemlyb.convergence.forward_backward_convergence` for
further explanation.
"""
use_FEP = False
use_TI = False
if estimators is not None:
if isinstance(estimators, str): # type: ignore[unreachable]
estimators = [ # type: ignore[unreachable]
estimators
]
for estimator in estimators:
if estimator in FEP_ESTIMATORS:
use_FEP = True
elif estimator in TI_ESTIMATORS:
use_TI = True
else:
msg = (
f"Estimator {estimator} is not supported. Choose one from "
f"{FEP_ESTIMATORS + TI_ESTIMATORS}."
)
logger.error(msg)
raise ValueError(msg)
self.read(read_u_nk=use_FEP, read_dHdl=use_TI, n_jobs=n_jobs)
if uncorr is not None:
self.preprocess(
skiptime=skiptime, uncorr=uncorr, threshold=threshold, n_jobs=n_jobs
)
if estimators is not None:
self.estimate(estimators) # type: ignore[arg-type]
self.generate_result()
if overlap is not None and use_FEP:
ax = self.plot_overlap_matrix(overlap)
plt.close(ax.figure) # type: ignore[union-attr,arg-type]
if breakdown:
if use_TI:
ax = self.plot_ti_dhdl()
plt.close(ax.figure) # type: ignore[union-attr,arg-type]
fig = self.plot_dF_state()
plt.close(fig)
fig = self.plot_dF_state(
dF_state="dF_state_long.pdf", orientation="landscape"
)
plt.close(fig)
if forwrev:
ax = self.check_convergence(forwrev, estimator="MBAR", dF_t="dF_t.pdf")
plt.close(ax.figure) # type: ignore[union-attr,arg-type]
[docs]
def update_units(self, units: None | str = None) -> None:
"""Update the unit.
Parameters
----------
units : {'kcal/mol', 'kJ/mol', 'kT'}
The unit used for printing and plotting results.
"""
if units is not None:
logger.info(f"Set unit to {units}.")
self.units = units or None
[docs]
def preprocess(
self,
skiptime: float = 0,
uncorr: str = "dE",
threshold: int = 50,
n_jobs: int = 1,
) -> None:
"""Preprocess the data by removing the equilibration time and
decorrelate the date.
Parameters
----------
skiptime : float
Discard data prior to this specified time as 'equilibration' data.
Units are specified by the corresponding MD Engine. Default: 0.
uncorr : str
The observable to be used for the autocorrelation analysis; 'dE'.
threshold : int
Proceed with correlated samples if the number of uncorrelated
samples is found to be less than this number. If 0 is given, the
time series analysis will not be performed at all. Default: 50.
n_jobs : int
Number of parallel workers to use for decorrelating the data.
(-1 means using all the threads)
Attributes
----------
u_nk_sample_list : list
The list of u_nk after decorrelation.
dHdl_sample_list : list
The list of dHdl after decorrelation.
"""
logger.info(
f"Start preprocessing with skiptime of {skiptime} "
f"uncorrelation method of {uncorr} and threshold of "
f"{threshold}"
)
if len(self.u_nk_list) > 0:
logger.info(f"Processing the u_nk data set with skiptime of {skiptime}.")
def _decorrelate_u_nk(
u_nk: pd.DataFrame, skiptime: float, threshold: int, index: int
) -> pd.DataFrame:
u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime]
subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True)
if len(subsample) < threshold:
logger.warning(
f"Number of u_nk {len(subsample)} "
f"for state {index} is less than the "
f"threshold {threshold}."
)
logger.info(f"Take all the u_nk for state {index}.")
subsample = u_nk
else:
logger.info(
f"Take {len(subsample)} uncorrelated u_nk for state {index}."
)
return subsample
self.u_nk_sample_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index)
for index, u_nk in enumerate(self.u_nk_list)
)
else:
logger.info("No u_nk data being subsampled")
if len(self.dHdl_list) > 0:
def _decorrelate_dhdl(
dHdl: pd.DataFrame, skiptime: float, threshold: int, index: int
) -> pd.DataFrame:
dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime]
subsample = decorrelate_dhdl(dHdl, remove_burnin=True)
if len(subsample) < threshold:
logger.warning(
f"Number of dHdl {len(subsample)} for "
f"state {index} is less than the "
f"threshold {threshold}."
)
logger.info(f"Take all the dHdl for state {index}.")
subsample = dHdl
else:
logger.info(
f"Take {len(subsample)} uncorrelated dHdl for state {index}."
)
return subsample
self.dHdl_sample_list = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index)
for index, dHdl in enumerate(self.dHdl_list)
)
else:
logger.info("No dHdl data being subsampled")
[docs]
def estimate(
self,
estimators: tuple[str] = ("MBAR", "BAR", "TI"), # type: ignore[assignment]
**kwargs: Any,
) -> None:
"""Estimate the free energy using the selected estimator.
Parameters
----------
estimators : str or list of str
A list of the estimators to estimate the free energy with. Default:
['TI', 'BAR', 'MBAR'].
kwargs : dict
Keyword arguments to be passed to the estimator.
Attributes
----------
estimator : dict
The dictionary of estimators. The keys are in ['TI', 'BAR',
'MBAR']. Note that the estimators are in their original form where
no unit conversion has been attempted.
.. versionchanged:: 2.1.0
DeprecationWarning for using analytic error for MBAR estimator; from 2.6.0
onwards, the `MBAR bootstrap error`_ will be used instead.
.. _`MBAR bootstrap error`:
https://pymbar.readthedocs.io/en/stable/pymbar.mbar.html#pymbar.mbar.MBAR.bootstrap_error
"""
# Make estimators into a tuple
if isinstance(estimators, str):
estimators = (estimators,) # type: ignore[unreachable]
for estimator in estimators:
if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS):
msg = f"Estimator {estimator} is not available in {FEP_ESTIMATORS + TI_ESTIMATORS}."
logger.error(msg)
raise ValueError(msg)
logger.info(f"Start running estimator: {','.join(estimators)}.")
self.estimator = {}
# Use unprocessed data if preprocess is not performed.
if "TI" in estimators:
if self.dHdl_sample_list is not None:
dHdl = concat(self.dHdl_sample_list)
else:
dHdl = concat(self.dHdl_list) # type: ignore[unreachable]
logger.warning("dHdl has not been preprocessed.")
logger.info(f"A total {len(dHdl)} lines of dHdl is used.")
if "BAR" in estimators or "MBAR" in estimators:
if self.u_nk_sample_list is not None:
u_nk = concat(self.u_nk_sample_list)
else:
u_nk = concat(self.u_nk_list) # type: ignore[unreachable]
logger.warning("u_nk has not been preprocessed.")
logger.info(f"A total {len(u_nk)} lines of u_nk is used.")
for estimator in estimators:
if estimator == "MBAR":
logger.info("Run MBAR estimator.")
warnings.warn(
"From 2.6.0, n_bootstraps=50 will be the default for estimating MBAR error.",
DeprecationWarning,
)
self.estimator[estimator] = MBAR(**kwargs).fit(u_nk) # type: ignore[arg-type]
elif estimator == "BAR":
logger.info("Run BAR estimator.")
self.estimator[estimator] = BAR(**kwargs).fit(u_nk) # type: ignore[arg-type]
elif estimator == "TI":
logger.info("Run TI estimator.")
self.estimator[estimator] = TI(**kwargs).fit(dHdl) # type: ignore[arg-type]
[docs]
def generate_result(self) -> pd.DataFrame:
"""Summarise the result into a dataframe.
Returns
-------
DataFrame
The DataFrame with convergence data. ::
MBAR MBAR_Error BAR BAR_Error TI TI_Error
States 0 -- 1 0.065967 0.001293 0.066544 0.001661 0.066663 0.001675
1 -- 2 0.089774 0.001398 0.089303 0.002101 0.089566 0.002144
2 -- 3 0.132036 0.001638 0.132687 0.002990 0.133292 0.003055
3 -- 4 0.116494 0.001213 0.116348 0.002691 0.116845 0.002750
4 -- 5 0.105251 0.000980 0.106344 0.002337 0.106603 0.002362
5 -- 6 0.349320 0.002781 0.343399 0.006839 0.350568 0.007393
6 -- 7 0.402346 0.002767 0.391368 0.006641 0.395754 0.006961
7 -- 8 0.322284 0.002058 0.319395 0.005333 0.321542 0.005434
8 -- 9 0.434999 0.002683 0.425680 0.006823 0.430251 0.007155
9 -- 10 0.355672 0.002219 0.350564 0.005472 0.352745 0.005591
10 -- 11 3.574227 0.008744 3.513595 0.018711 3.514790 0.018078
11 -- 12 2.896685 0.009905 2.821760 0.017844 2.823210 0.018088
12 -- 13 2.223769 0.011229 2.188885 0.018438 2.189784 0.018478
13 -- 14 1.520978 0.012526 1.493598 0.019155 1.490070 0.019288
14 -- 15 0.911279 0.009527 0.894878 0.015023 0.896010 0.015140
15 -- 16 0.892365 0.010558 0.886706 0.015260 0.884698 0.015392
16 -- 17 1.737971 0.025315 1.720643 0.031416 1.741028 0.030624
17 -- 18 1.790706 0.025560 1.788112 0.029435 1.801695 0.029244
18 -- 19 1.998635 0.023340 2.007404 0.027447 2.019213 0.027096
19 -- 20 2.263475 0.020286 2.265322 0.025023 2.282040 0.024566
20 -- 21 2.565680 0.016695 2.561324 0.023611 2.552977 0.023753
21 -- 22 1.384094 0.007553 1.385837 0.011672 1.381999 0.011991
22 -- 23 1.428567 0.007504 1.422689 0.012524 1.416010 0.013012
23 -- 24 1.440581 0.008059 1.412517 0.013125 1.408267 0.013539
24 -- 25 1.411329 0.009022 1.419167 0.013356 1.411446 0.013795
25 -- 26 1.340320 0.010167 1.360679 0.015213 1.356953 0.015260
26 -- 27 1.243745 0.011239 1.245873 0.015711 1.248959 0.015762
27 -- 28 1.128429 0.012859 1.124554 0.016999 1.121892 0.016962
28 -- 29 1.010313 0.016442 1.005444 0.017692 1.019747 0.017257
Stages coul 10.215658 0.033903 10.017838 0.037086 10.017854 0.048744
vdw 22.547489 0.098699 22.501150 0.077284 22.542936 0.106723
bonded 2.374144 0.014995 2.341631 0.014988 2.363828 0.021078
TOTAL 35.137291 0.103580 34.860619 0.087022 34.924618 0.119206
Attributes
----------
summary : Dataframe
The summary of the free energy estimate.
"""
# Write estimate
logger.info("Summarise the estimate into a dataframe.")
# Make the header name
logger.info("Generate the row names.")
estimator_names = list(self.estimator.keys())
num_states = len(self.estimator[estimator_names[0]].states_) # type: ignore[arg-type]
data_dict: dict[str, list] = {"name": [], "state": []}
for i in range(num_states - 1):
data_dict["name"].append(str(i) + " -- " + str(i + 1))
data_dict["state"].append("States")
try:
u_nk = self.u_nk_list[0]
stages = u_nk.reset_index("time").index.names
logger.info("use the stage name from u_nk")
except Exception:
dHdl = self.dHdl_list[0]
stages = dHdl.reset_index("time").index.names
logger.info("use the stage name from dHdl")
for stage in stages:
data_dict["name"].append(stage.split("-")[0]) # type: ignore[union-attr]
data_dict["state"].append("Stages")
data_dict["name"].append("TOTAL")
data_dict["state"].append("Stages")
col_names = []
for estimator_name, estimator in self.estimator.items():
logger.info(f"Read the results from estimator {estimator_name}")
# Do the unit conversion
delta_f_ = estimator.delta_f_
d_delta_f_ = estimator.d_delta_f_
# Write the estimator header
col_names.append(estimator_name)
col_names.append(estimator_name + "_Error")
data_dict[estimator_name] = []
data_dict[estimator_name + "_Error"] = []
for index in range(1, num_states):
data_dict[estimator_name].append(delta_f_.iloc[index - 1, index]) # type: ignore[union-attr]
data_dict[estimator_name + "_Error"].append(
d_delta_f_.iloc[index - 1, index] # type: ignore[union-attr]
)
logger.info(f"Generate the staged result from estimator {estimator_name}")
for index, stage in enumerate(stages):
if len(stages) == 1:
start = 0
end = len(estimator.states_) - 1 # type: ignore[arg-type]
else:
# Get the start and the end of the state
lambda_min = min([state[index] for state in estimator.states_]) # type: ignore[union-attr]
lambda_max = max([state[index] for state in estimator.states_]) # type: ignore[union-attr]
if lambda_min == lambda_max:
# Deal with the case where a certain lambda is used but
# not perturbed
start = 0
end = 0
else:
states = [state[index] for state in estimator.states_] # type: ignore[union-attr]
start = list(reversed(states)).index(lambda_min)
start = num_states - start - 1
end = states.index(lambda_max)
logger.info(f"Stage {stage} is from state {start} to state {end}.")
# This assumes that the indexes are sorted as the
# preprocessing should sort the index of the df.
result = delta_f_.iloc[start, end] # type: ignore[union-attr]
if estimator_name != "BAR":
error = d_delta_f_.iloc[start, end] # type: ignore[union-attr]
else:
error = np.sqrt(
sum(
[
d_delta_f_.iloc[i, i + 1] ** 2 # type: ignore[operator,union-attr,misc]
for i in range(start, end)
]
)
)
data_dict[estimator_name].append(result)
data_dict[estimator_name + "_Error"].append(error)
# Total result
# This assumes that the indexes are sorted as the
# preprocessing should sort the index of the df.
result = delta_f_.iloc[0, -1] # type: ignore[union-attr]
if estimator_name != "BAR":
error = d_delta_f_.iloc[0, -1] # type: ignore[union-attr]
else:
error = np.sqrt(
sum([d_delta_f_.iloc[i, i + 1] ** 2 for i in range(num_states - 1)]) # type: ignore[operator,union-attr,misc]
)
data_dict[estimator_name].append(result)
data_dict[estimator_name + "_Error"].append(error)
summary = pd.DataFrame.from_dict(data_dict)
summary = summary.set_index(["state", "name"])
# Make sure that the columns are in the right order
summary = summary[col_names]
# Remove the name of the index column to make it prettier
summary.index.names = [None, None]
summary.attrs = estimator.delta_f_.attrs # type: ignore[union-attr]
converter = get_unit_converter(self.units) # type: ignore[arg-type]
summary = converter(summary)
self.summary = summary
logger.info(f"Write results:\n{summary.to_string()}")
return summary # type: ignore[no-any-return]
[docs]
def plot_overlap_matrix(
self, overlap: str = "O_MBAR.pdf", ax: None | Axes = None
) -> None | Axes:
"""Plot the overlap matrix for MBAR estimator using
:func:`~alchemlyb.visualisation.plot_mbar_overlap_matrix`.
Parameters
----------
overlap : str
The filename for the plot of overlap matrix. Default: 'O_MBAR.pdf'
ax : matplotlib.axes.Axes
Matplotlib axes object where the plot will be drawn on. If
``ax=None``, a new axes will be generated.
Returns
-------
matplotlib.axes.Axes
An axes with the overlap matrix drawn.
"""
logger.info("Plot overlap matrix.")
if "MBAR" in self.estimator:
ax = plot_mbar_overlap_matrix(self.estimator["MBAR"].overlap_matrix, ax=ax)
ax.figure.savefig(join(self.out, overlap)) # type: ignore[union-attr]
logger.info(f"Plot overlap matrix to {self.out} under {overlap}.")
return ax
else:
logger.warning("MBAR estimator not found. Overlap matrix not plotted.")
return None
[docs]
def plot_ti_dhdl(
self,
dhdl_TI: str = "dhdl_TI.pdf",
labels: None | list[str] = None,
colors: None | list[str] = None,
ax: None | Axes = None,
) -> None | Axes:
"""Plot the dHdl for TI estimator using
:func:`~alchemlyb.visualisation.plot_ti_dhdl`.
Parameters
----------
dhdl_TI : str
The filename for the plot of TI dHdl. Default: 'dhdl_TI.pdf'
labels : List
list of labels for labelling all the alchemical transformations.
colors : List
list of colors for plotting all the alchemical transformations.
Default: ['r', 'g', '#7F38EC', '#9F000F', 'b', 'y']
ax : matplotlib.axes.Axes
Matplotlib axes object where the plot will be drawn on. If ``ax=None``,
a new axes will be generated.
Returns
-------
matplotlib.axes.Axes
An axes with the TI dhdl drawn.
"""
logger.info("Plot TI dHdl.")
if "TI" in self.estimator:
ax = plot_ti_dhdl(
self.estimator["TI"],
units=self.units,
labels=labels,
colors=colors,
ax=ax,
)
ax.figure.savefig(join(self.out, dhdl_TI)) # type: ignore[union-attr]
logger.info(f"Plot TI dHdl to {dhdl_TI} under {self.out}.")
return ax
else:
raise ValueError("No TI data available in estimators.")
[docs]
def plot_dF_state(
self,
dF_state: str = "dF_state.pdf",
labels: None | list[str] = None,
colors: None | list[str] = None,
orientation: str = "portrait",
nb: int = 10,
) -> Any:
"""Plot the dF states using
:func:`~alchemlyb.visualisation.plot_dF_state`.
Parameters
----------
dF_state : str
The filename for the plot of dF states. Default: 'dF_state.pdf'
labels : List
list of labels for labelling different estimators.
colors : List
list of colors for plotting different estimators.
orientation : string
The orientation of the figure. Can be `portrait` or `landscape`
nb : int
Maximum number of dF states in one row in the `portrait` mode
Returns
-------
matplotlib.figure.Figure
An Figure with the dF states drawn.
"""
logger.info("Plot dF states.")
fig = plot_dF_state(
self.estimator.values(),
labels=labels,
colors=colors,
units=self.units,
orientation=orientation,
nb=nb,
)
fig.savefig(join(self.out, dF_state))
logger.info(f"Plot dF state to {dF_state} under {self.out}.")
return fig
[docs]
def check_convergence( # type: ignore[override]
self,
forwrev: int,
estimator: str = "MBAR",
dF_t: str = "dF_t.pdf",
ax: None | Axes = None,
**kwargs: Any,
) -> None | Axes:
"""Compute the forward and backward convergence using
:func:`~alchemlyb.convergence.forward_backward_convergence`and
plot with
:func:`~alchemlyb.visualisation.plot_convergence`.
Parameters
----------
forwrev : int
Plot the free energy change as a function of time in both
directions, with the specified number of points in the time plot.
The number of time points (an integer) must be provided.
estimator : {'TI', 'BAR', 'MBAR'}
The estimator used for convergence analysis. Default: 'MBAR'
dF_t : str
The filename for the plot of convergence. Default: 'dF_t.pdf'
ax : matplotlib.axes.Axes
Matplotlib axes object where the plot will be drawn on. If ``ax=None``,
a new axes will be generated.
kwargs : dict
Keyword arguments to be passed to the estimator.
Attributes
----------
convergence : DataFrame
Returns
-------
matplotlib.axes.Axes
An axes with the convergence drawn.
"""
logger.info("Start convergence analysis.")
logger.info("Checking data availability.")
if estimator in FEP_ESTIMATORS:
if self.u_nk_sample_list is not None:
u_nk_list = self.u_nk_sample_list
logger.info("Subsampled u_nk is available.")
else:
if self.u_nk_list is not None: # type: ignore[unreachable]
u_nk_list = self.u_nk_list
logger.info(
"Subsampled u_nk not available, use original data instead."
)
else:
msg = (
f"u_nk is needed for the f{estimator} estimator. "
f"If the dataset only has dHdl, "
f"run ABFE.check_convergence(estimator='TI') to "
f"use a TI estimator."
)
logger.error(msg)
raise ValueError(msg)
convergence = forward_backward_convergence(
u_nk_list, estimator=estimator, num=forwrev, **kwargs
)
elif estimator in TI_ESTIMATORS:
logger.warning("No valid FEP estimator or dataset found. Fallback to TI.")
if self.dHdl_sample_list is not None:
dHdl_list = self.dHdl_sample_list
logger.info("Subsampled dHdl is available.")
else:
if self.dHdl_list is not None: # type: ignore[unreachable]
dHdl_list = self.dHdl_list
logger.info(
"Subsampled dHdl not available, use original data instead."
)
else:
logger.error(f"dHdl is needed for the f{estimator} estimator.")
raise ValueError(f"dHdl is needed for the f{estimator} estimator.")
convergence = forward_backward_convergence(
dHdl_list, estimator=estimator, num=forwrev, **kwargs
)
else:
msg = (
f"Estimator {estimator} is not supported. Choose one from "
f"{FEP_ESTIMATORS + TI_ESTIMATORS}."
)
logger.error(msg)
raise ValueError(msg)
unit_converted_convergence = get_unit_converter(self.units)(convergence) # type: ignore[arg-type]
# Otherwise the data_fraction column is converted as well.
unit_converted_convergence["data_fraction"] = convergence["data_fraction"]
self.convergence = unit_converted_convergence
logger.info(f"Plot convergence analysis to {dF_t} under {self.out}.")
ax = plot_convergence(self.convergence, units=self.units, ax=ax)
ax.figure.savefig(join(self.out, dF_t)) # type: ignore[union-attr]
return ax