#!/usr/local/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
Multistatereporter
==================
Master multi-thermodynamic state reporter module. Handles all disk I/O
reporting operations for any MultiStateSampler-derived class.
COPYRIGHT
Current version by Andrea Rizzi <andrea.rizzi@choderalab.org>, Levi N. Naden <levi.naden@choderalab.org> and
John D. Chodera <john.chodera@choderalab.org> while at Memorial Sloan Kettering Cancer Center.
Original version by John D. Chodera <jchodera@gmail.com> while at the University of
California Berkeley.
LICENSE
This code is licensed under the latest available version of the MIT License.
"""
# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================
import os
import copy
import time
import uuid
import yaml
import warnings
import logging
import collections
import numpy as np
import netCDF4 as netcdf
from typing import Union, Any
try:
from openmm import unit
except ImportError: # OpenMM < 7.6
from simtk import unit
import openmmtools
from openmmtools.utils import deserialize, with_timer, serialize, quantity_from_string
from openmmtools import states
logger = logging.getLogger(__name__)
# ==============================================================================
# MULTISTATE SAMPLER REPORTER
# ==============================================================================
[docs]
class MultiStateReporter(object):
"""Handle storage write/read operations and different format conventions.
You can use this object to programmatically inspect the data generated by
ReplicaExchange.
Parameters
----------
storage : str
The path to the storage file for analysis.
A second checkpoint file will be determined from either ``checkpoint_storage`` or automatically based on
the storage option
In the future this will be able to take Storage classes as well.
open_mode : str or None
The mode of the file between 'r', 'w', and 'a' (or equivalently 'r+').
If None, the storage file won't be open on construction, and a call to
:func:`Reporter.open` will be needed before attempting read/write operations.
checkpoint_interval : int >= 1, Default: 50
The frequency at which checkpointing information is written relative to analysis information.
This is a multiple of the iteration at which energies are written, hence why it must be greater than or equal to 1.
Checkpoint information is not written on iterations where ``iteration % checkpoint_interval != 0``.
checkpoint_storage : str or None, optional
Optional name of the checkpoint file. This file is used to save trajectory information and other less
frequently accessed data.
This should NOT be a full path, only a file name; it is placed in the same directory as ``storage``.
If None, the checkpoint name is derived from ``storage`` by inserting "_checkpoint" before its extension
(e.g. "output.nc" becomes "output_checkpoint.nc").
The reporter internally tracks which data goes into which file, so this is transparent to all other classes.
In the future, this will be able to take Storage classes as well.
analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple)
If specified, it will serialize positions and velocities for the specified particles, at every iteration, in the
reporter storage (.nc) file. If empty, no positions or velocities will be stored in this file for any atoms.
Attributes
----------
filepath
checkpoint_interval
is_periodic
n_states
n_replicas
analysis_particle_indices
"""
[docs]
def __init__(self, storage, open_mode=None,
             checkpoint_interval=50, checkpoint_storage=None,
             analysis_particle_indices=()):
    """Create a reporter for the given storage path.

    See the class docstring for a description of the parameters.

    Raises
    ------
    ValueError
        If ``checkpoint_interval`` is not an integer greater than or equal to 1.
    """
    # Set the file handles before anything that can raise, so that close()
    # (invoked by __del__) also works on a partially constructed reporter.
    self._storage_checkpoint = None
    self._storage_analysis = None

    # Handle checkpointing. bool is a subclass of int but is not a valid
    # interval; NumPy integer scalars are accepted and converted to int.
    if isinstance(checkpoint_interval, bool) or not isinstance(checkpoint_interval, (int, np.integer)):
        raise ValueError("checkpoint_interval must be an integer!")
    if checkpoint_interval < 1:
        # An interval of 0 would make ``iteration % checkpoint_interval`` fail later on.
        raise ValueError("checkpoint_interval must be greater than or equal to 1!")

    # Warn that API is experimental (Logger.warn is a deprecated alias of warning).
    logger.warning('Warning: The openmmtools.multistate API is experimental and may change in future releases')

    # The checkpoint file always lives next to the analysis file.
    dirname, filename = os.path.split(storage)
    if checkpoint_storage is None:
        basename, ext = os.path.splitext(filename)
        addon = "_checkpoint"
        checkpoint_storage = os.path.join(dirname, basename + addon + ext)
        logger.debug("Initial checkpoint file automatically chosen as {}".format(checkpoint_storage))
    else:
        checkpoint_storage = os.path.join(dirname, checkpoint_storage)

    self._storage_analysis_file_path = storage
    self._storage_checkpoint_file_path = checkpoint_storage
    self._checkpoint_interval = int(checkpoint_interval)
    # Cast to tuple no matter what 1-D-like input was given
    self._analysis_particle_indices = tuple(analysis_particle_indices)

    # TODO: Maybe we want to expose this flag to control overwriting/appending
    # Flag to check whether to overwrite real time statistics file -- Defaults to append.
    # Set before open() so the instance is fully configured when files are opened.
    self._overwrite_statistics = False

    if open_mode is not None:
        self.open(open_mode)
@property
def filepath(self):
    """Path of the primary (analysis) storage file.

    Exposed so that classes outside the reporter can reference the file,
    for instance in error messages.
    """
    primary_path = self._storage_analysis_file_path
    return primary_path
@property
def _storage(self):
    """Tuple of all storage objects, primary (analysis) file first, then the subfiles.

    Avoids spelling out the list of storage objects at every call site.
    """
    primary = self._storage_analysis
    subfiles = (self._storage_checkpoint,)
    return (primary,) + subfiles
@property
def _storage_paths(self):
    """Tuple of storage file paths, primary (analysis) file first, then the subfiles."""
    primary_path = self._storage_analysis_file_path
    subfile_paths = (self._storage_checkpoint_file_path,)
    return (primary_path,) + subfile_paths
@property
def _storage_dict(self):
    """Mapping from storage role ('checkpoint', 'analysis') to its storage object."""
    return dict(checkpoint=self._storage_checkpoint, analysis=self._storage_analysis)
@property
def n_states(self):
    """Size of the ``state`` dimension of the analysis file, or None when the storage is not open."""
    if self.is_open():
        return self._storage_analysis.dimensions['state'].size
    return None
@property
def n_replicas(self):
    """Size of the ``replica`` dimension of the analysis file, or None when the storage is not open."""
    if self.is_open():
        return self._storage_analysis.dimensions['replica'].size
    return None
@property
def is_periodic(self):
    """True if the analysis file stores ``box_vectors``, False if not, None when the storage is not open."""
    if not self.is_open():
        return None
    return 'box_vectors' in self._storage_analysis.variables
@property
def analysis_particle_indices(self):
    """Tuple of indices of the particles whose positions/velocities are also stored in the analysis file."""
    indices = self._analysis_particle_indices
    return indices
@property
def checkpoint_interval(self):
    """Number of iterations between two writes of checkpoint information."""
    interval = self._checkpoint_interval
    return interval
def storage_exists(self, skip_size=False):
    """
    Check whether all the storage files (primary file and subfiles) exist on disk.

    This masks the subfiles from everything outside the reporter.

    Parameters
    ----------
    skip_size : bool, Optional, Default: False
        If True, a zero-size file still counts as existing. Helpful right after initializing a
        storage file whose contents are still only in memory (e.g. just opened NetCDF files).

    Returns
    -------
    files_exist : bool
        True only if every storage file exists (and, unless ``skip_size``, is non-empty).
    """
    def file_is_usable(path):
        if not os.path.exists(path):
            return False
        has_content = os.path.getsize(path) > 0
        return has_content or skip_size

    return all(file_is_usable(path) for path in self._storage_paths)
def is_open(self):
    """Return True if the primary storage file is open and the reporter can read/write."""
    primary = self._storage[0]
    return False if primary is None else primary.isopen()
def _are_subfiles_open(self):
    """Internal check that every subfile (all storage objects but the primary) exists and is open."""
    subfiles = self._storage[1:]
    if any(subfile is None for subfile in subfiles):
        return False
    return np.all([subfile.isopen() for subfile in subfiles])
def open(self, mode='r', convention='ReplicaExchange', netcdf_format='NETCDF4'):
    """
    Open the storage file for reading/writing.

    Creates and pre-formats the required files if they don't exist.
    This is not necessary if you have indicated in the constructor to open.

    Parameters
    ----------
    mode : str, Optional, Default: 'r'
        The mode of the file between 'r', 'w', and 'a' (or equivalently 'r+').
    convention : str, Optional, Default: 'ReplicaExchange'
        NetCDF convention to write
    netcdf_format : str, Optional, Default: 'NETCDF4'
        The NetCDF file format to use

    Raises
    ------
    IOError
        If the analysis file cannot be found, or if the checkpoint file carries
        a UUID different from the analysis file's one.
    """
    # Ensure we don't have already another file
    # open (possibly in a different mode).
    self.close()

    # Create directory if we want to write.
    # TODO: We probably want to check here specifically for w when we want to write
    if mode != 'r':
        for storage_path in self._storage_paths:
            # normpath() transforms '' to '.' for makedirs().
            storage_dir = os.path.normpath(os.path.dirname(storage_path))
            os.makedirs(storage_dir, exist_ok=True)

    # Analysis file.
    # ---------------
    self._storage_analysis = self._open_dataset_robustly(self._storage_analysis_file_path,
                                                         mode, version=netcdf_format)

    # The analysis netcdf file holds a reference UUID so that we can check
    # that the secondary netcdf files (currently only the checkpoint
    # file) have the same UUID to verify that the user isn't erroneously
    # trying to associate the analysis file to the incorrect checkpoint.
    try:
        # Check if we have previously created the file.
        primary_uuid = self._storage_analysis.UUID
    except AttributeError:
        # This is a new file. Use uuid4 to avoid assigning hostname information.
        primary_uuid = str(uuid.uuid4())
        self._storage_analysis.UUID = primary_uuid

    # Initialize dataset, if needed.
    self._initialize_storage_file(self._storage_analysis, 'analysis', convention)

    # Checkpoint file.
    # -----------------
    # A missing checkpoint file is tolerated (None is returned) so the analysis
    # file can still be inspected on its own.
    msg = ('Could not locate checkpoint subfile. This is okay for analysis if the '
           'solvent trajectory is not needed, but not for production simulation!')
    self._storage_checkpoint = self._open_dataset_robustly(self._storage_checkpoint_file_path,
                                                           mode, catch_io_error=True,
                                                           io_error_warning=msg,
                                                           version=netcdf_format)
    if self._storage_checkpoint is not None:
        # Check that the checkpoint file has the same UUID of the analysis file.
        # An explicit comparison (instead of ``assert``) keeps the check active under ``python -O``.
        checkpoint_uuid = getattr(self._storage_checkpoint, 'UUID', None)
        if checkpoint_uuid is None:
            # This is a new file. Assign UUID.
            self._storage_checkpoint.UUID = primary_uuid
        elif checkpoint_uuid != primary_uuid:
            raise IOError('Checkpoint UUID does not match analysis UUID! '
                          'This checkpoint file came from another simulation!\n'
                          'Analysis UUID: {}; Checkpoint UUID: {}'.format(
                              primary_uuid, checkpoint_uuid))
        # Initialize dataset, if needed.
        self._initialize_storage_file(self._storage_checkpoint, 'checkpoint', convention)

    # Further checkpoint interval checks.
    # -----------------------------------
    if self._storage_analysis is not None:
        # The same number will be on checkpoint file as well, but its not guaranteed to be present.
        # Cast to int so the reporter always exposes a builtin int.
        on_file_interval = int(self._storage_analysis.CheckpointInterval)
        if on_file_interval != self._checkpoint_interval:
            logger.debug("checkpoint_interval != on-file checkpoint interval! "
                         "Using on file analysis interval of {}.".format(on_file_interval))
            self._checkpoint_interval = on_file_interval

        # Check the special particle indices
        # Handle the "variable does not exist" case
        if 'analysis_particle_indices' not in self._storage_analysis.variables:
            n_particles = len(self._analysis_particle_indices)
            # This dimension won't exist if the above statement does not either
            self._storage_analysis.createDimension('analysis_particles', n_particles)
            ncvar_analysis_particles = \
                self._storage_analysis.createVariable('analysis_particle_indices', int, 'analysis_particles')
            ncvar_analysis_particles[:] = self._analysis_particle_indices
            ncvar_analysis_particles.long_name = ("analysis_particle_indices[analysis_particles] is the indices of "
                                                  "the particles with extra information stored about them in the "
                                                  "analysis file.")

        # Now handle the "variable does exist but does not match the provided ones"
        # Although redundant if it was just created, its an easy check to make
        stored_analysis_particles = self._storage_analysis.variables['analysis_particle_indices'][:]
        stored_indices = tuple(stored_analysis_particles.astype(int))
        if self._analysis_particle_indices != stored_indices:
            logger.debug("analysis_particle_indices != on-file analysis_particle_indices! "
                         "Using on file analysis indices of {}".format(stored_analysis_particles))
            self._analysis_particle_indices = stored_indices
def _open_dataset_robustly(self, *args, n_attempts=5, sleep_time=2,
                           catch_io_error=False, io_error_warning=None,
                           **kwargs):
    """Open a ``netcdf.Dataset``, retrying a few times before giving up.

    This may help with MPI concurrency and locking issues that randomly pop
    up with HDF5: every attempt but the last one swallows the error and waits
    ``sleep_time`` seconds before retrying.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to ``netcdf.Dataset``; ``args[0]`` must be the file path.
    n_attempts : int, Default: 5
        Total number of attempts, including the final one.
    sleep_time : float, Default: 2
        Seconds to wait after each failed non-final attempt.
    catch_io_error : bool, Default: False
        If True and the file does not exist once the earlier attempts have
        failed, return None instead of raising.
    io_error_warning : str or None
        Message logged as a warning when None is returned for a missing file.

    Returns
    -------
    dataset : netcdf.Dataset or None

    Raises
    ------
    IOError
        If the file does not exist and ``catch_io_error`` is False.
    Exception
        Whatever ``netcdf.Dataset`` raises on the final attempt.

    Notes
    -----
    Before the final attempt (when ``n_attempts > 1``) the process-wide
    environment variable ``HDF5_USE_FILE_LOCKING`` is set to ``'FALSE'``
    (see choderalab/yank#1165).
    NOTE(review): with ``n_attempts == 1`` no attempt precedes the existence
    check, so a file that does not exist yet is reported as missing even in
    write mode.
    """
    attempt = 0
    while attempt < n_attempts - 1:
        try:
            return netcdf.Dataset(*args, **kwargs)
        except Exception as err:  # Safe: the final attempt below lets errors propagate.
            logger.debug(f"exception thrown {err}")
            logger.debug('Attempt {}/{} to open {} failed. Retrying '
                         'in {} seconds'.format(attempt + 1, n_attempts, args[0], sleep_time))
            time.sleep(sleep_time)
        attempt += 1

    # No earlier attempt succeeded: a missing file is either tolerated or an IOError.
    if not os.path.isfile(args[0]):
        if not catch_io_error:
            raise IOError(f"{args[0]} does not exist")
        if io_error_warning is not None:
            logger.warning(io_error_warning)
        return None

    # Final attempt: disable HDF5 file locking, then let any error propagate.
    if n_attempts > 1:
        os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
    return netcdf.Dataset(*args, **kwargs)
def _initialize_storage_file(self, ncfile, nc_name, convention):
    """Initialize the common dimensions, global attributes and variables of a dataset.

    A dataset that already has a ``scalar`` dimension is considered initialized
    and is left untouched.

    Parameters
    ----------
    ncfile : netcdf.Dataset
        The dataset to initialize.
    nc_name : str
        Role of the file (e.g. 'analysis' or 'checkpoint'), stored as ``DataUsedFor``.
    convention : str
        NetCDF convention name, stored as ``Conventions``.

    Returns
    -------
    initialized : bool
        True if the dataset was initialized by this call, False if it had
        already been initialized before.
    """
    if 'scalar' in ncfile.dimensions:
        return False

    # Create common dimensions.
    ncfile.createDimension('scalar', 1)  # Scalar dimension.
    ncfile.createDimension('iteration', 0)  # Unlimited number of iterations.
    ncfile.createDimension('spatial', 3)  # Number of spatial dimensions.

    # Set global attributes. The module-level ``openmmtools`` import already
    # provides the version; no function-level import is needed.
    version = openmmtools.__version__
    ncfile.program = f"openmmtools {version}"
    ncfile.programVersion = version
    ncfile.Conventions = convention
    ncfile.ConventionVersion = '0.2'
    ncfile.DataUsedFor = nc_name
    ncfile.CheckpointInterval = self._checkpoint_interval

    # Create and initialize the global variables
    nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar')
    nc_last_good_iter[0] = 0
    return True
def close(self):
    """Close the storage files.

    Every storage object that is still open is synced and closed, and the
    corresponding ``self._storage_<name>`` handle is reset to None so that
    ``is_open()`` and later calls see the reporter as closed.
    """
    for storage_name, storage in self._storage_dict.items():
        if storage is None:
            continue
        if storage.isopen():
            storage.sync()
            storage.close()
        # The handles are named ``_storage_analysis``/``_storage_checkpoint``; the
        # separating underscore is required to reset the right attribute.
        setattr(self, '_storage_' + storage_name, None)
def sync(self):
    """Flush any pending buffer of every existing storage object to disk."""
    existing_storages = [storage for storage in self._storage if storage is not None]
    for storage in existing_storages:
        storage.sync()
def __del__(self):
    """Synchronize and close the storage.

    ``__init__`` may raise before the storage handles are assigned; in that
    case there is nothing to close, and calling ``close()`` would raise an
    ``AttributeError`` that Python reports during garbage collection.
    """
    if not all(hasattr(self, name) for name in ('_storage_analysis', '_storage_checkpoint')):
        return
    self.close()
def read_end_thermodynamic_states(self):
    """Read thermodynamic states at the ends of the protocol.

    If the analysis file has an ``unsampled_states`` group, the end states are
    read from there; otherwise from ``thermodynamic_states``. Only the first
    (index 0) and last (index ``n_states - 1``) states are deserialized.

    Returns
    -------
    end_thermodynamic_states : list of ThermodynamicState
        unsampled_states, if present, or first and last sampled states
    """
    end_thermodynamic_states = list()
    if 'unsampled_states' in self._storage_analysis.groups:
        state_type = 'unsampled_states'
    else:
        state_type = 'thermodynamic_states'

    # Read thermodynamic end states
    # Cache: '<state_type>/<state_id>' -> serialized standard system of that state.
    states_serializations = dict()
    # Number of stored states = number of variables in the group.
    n_states = len(self._storage_analysis.groups[state_type].variables)

    def extract_serialized_state(inner_type, inner_id):
        """Inner function to help extract the correct serialized state while minimizing number of disk reads.

        States written as "compatible" with an earlier state do not store their
        own standard system; it is restored here from the cache or, on a cache
        miss, by recursively reading the referenced state.

        Parameters
        ----------
        inner_type : str, 'unsampled_states' or 'thermodynamic_states'
            Where to read the data from, inherited from parent function's property or on the recursive loop
        inner_id : int
            Which state to pull data from

        Returns
        -------
        inner_serialized_state : dict
            The serialized state with its ``standard_system`` filled in.
        """
        inner_serialized_state = self.read_dict('{}/state{}'.format(inner_type, inner_id))

        def isolate_thermodynamic_state(isolating_serialized_state):
            """Helper function to find true bottom level thermodynamic state from any level of nesting, reduce code
            """
            isolating_serial_thermodynamic_state = isolating_serialized_state
            while 'thermodynamic_state' in isolating_serial_thermodynamic_state:
                # The while loop is necessary for nested CompoundThermodynamicStates.
                isolating_serial_thermodynamic_state = isolating_serial_thermodynamic_state['thermodynamic_state']
            return isolating_serial_thermodynamic_state

        serialized_thermodynamic_state = isolate_thermodynamic_state(inner_serialized_state)

        # Check if the standard state is in a previous state.
        try:
            standard_system_name = serialized_thermodynamic_state.pop('_Reporter__compatible_state')
        except KeyError:
            # Cache the standard system serialization for future usage.
            standard_system_name = '{}/{}'.format(inner_type, inner_id)
            states_serializations[standard_system_name] = serialized_thermodynamic_state['standard_system']
        else:
            # The system serialization can be retrieved from another state.
            # Because the unsampled states rely on the thermodynamic states for their deserialization, we have
            # to try a secondary/recursive loop to get the thermodynamic states
            # However, this loop happens less often as the states_serializations dict fills up.
            try:
                serialized_standard_system = states_serializations[standard_system_name]
            except KeyError:
                # The reference name has the form '<state_type>/<state_id>'.
                loop_type, loop_id = standard_system_name.split('/')
                looped_standard_state = extract_serialized_state(loop_type, loop_id)
                looped_serial_thermodynamic_state = isolate_thermodynamic_state(looped_standard_state)
                serialized_standard_system = looped_serial_thermodynamic_state['standard_system']
            serialized_thermodynamic_state['standard_system'] = serialized_standard_system
        return inner_serialized_state

    # NOTE(review): with a single stored state, index 0 is read twice and two
    # separate objects are returned.
    for state_id in [0, n_states-1]:
        serialized_state = extract_serialized_state(state_type, state_id)
        # Create ThermodynamicState object.
        end_thermodynamic_states.append(deserialize(serialized_state))
    return end_thermodynamic_states
@with_timer('Reading thermodynamic states from storage')
def read_thermodynamic_states(self):
    """Retrieve the stored thermodynamic states from the checkpoint file.

    Returns
    -------
    thermodynamic_states : list of ThermodynamicStates
        The previously stored thermodynamic states. During the simulation
        these are swapped among replicas.
    unsampled_states : list of ThermodynamicState
        The previously stored unsampled thermodynamic states.

    See Also
    --------
    read_replica_thermodynamic_states
    """
    def innermost_thermodynamic_state(serialized):
        # Nested CompoundThermodynamicStates wrap the actual thermodynamic state.
        while 'thermodynamic_state' in serialized:
            serialized = serialized['thermodynamic_state']
        return serialized

    # Thermodynamic states are parsed first because the unsampled
    # states may refer to them for the serialized system.
    deserialized_by_type = collections.OrderedDict([('thermodynamic_states', list()),
                                                    ('unsampled_states', list())])
    # Caches standard_system_name -> serialized standard system.
    system_cache = dict()

    for state_type, state_list in deserialized_by_type.items():
        if state_type not in self._storage_analysis.groups:
            # Only the unsampled states are optional.
            assert state_type == 'unsampled_states'
            continue

        n_stored = len(self._storage_analysis.groups[state_type].variables)
        for state_id in range(n_stored):
            serialized_state = self.read_dict('{}/state{}'.format(state_type, state_id))
            serialized_thermodynamic_state = innermost_thermodynamic_state(serialized_state)

            if '_Reporter__compatible_state' in serialized_thermodynamic_state:
                # The system serialization is borrowed from a previously read state.
                reference_name = serialized_thermodynamic_state.pop('_Reporter__compatible_state')
                serialized_thermodynamic_state['standard_system'] = system_cache[reference_name]
            else:
                # Full serialization: remember its system for later states.
                reference_name = '{}/{}'.format(state_type, state_id)
                system_cache[reference_name] = serialized_thermodynamic_state['standard_system']

            state_list.append(deserialize(serialized_state))

    return [deserialized_by_type['thermodynamic_states'], deserialized_by_type['unsampled_states']]
@with_timer('Storing thermodynamic states')
def write_thermodynamic_states(self, thermodynamic_states, unsampled_states):
    """Store all the ThermodynamicStates to the checkpoint file.

    A state whose ``is_state_compatible`` check succeeds against an already
    fully-serialized state is stored without its standard system and with a
    ``_Reporter__compatible_state`` reference to that state instead.

    Parameters
    ----------
    thermodynamic_states : list of ThermodynamicState
        The thermodynamic states to store.
    unsampled_states : list of ThermodynamicState
        The unsampled thermodynamic states to store.

    See Also
    --------
    write_replica_thermodynamic_states
    """
    # Fully-serialized state -> '<state_type>/<state_id>' name of its serialization.
    reference_names = dict()

    def innermost(serialized):
        while 'thermodynamic_state' in serialized:
            serialized = serialized['thermodynamic_state']
        return serialized

    for state_type, state_list in (('thermodynamic_states', thermodynamic_states),
                                   ('unsampled_states', unsampled_states)):
        for state_id, state in enumerate(state_list):
            compatible_state = next((candidate for candidate in reference_names
                                     if candidate.is_state_compatible(state)), None)

            if compatible_state is not None:
                # Skip the (large) system and point at the compatible state instead.
                serialized_state = serialize(state, skip_system=True)
                serialized_thermodynamic_state = innermost(serialized_state)
                serialized_thermodynamic_state.pop('standard_system')  # Remove the unneeded system object
                serialized_thermodynamic_state['_Reporter__compatible_state'] = reference_names[compatible_state]
            else:
                # No compatible state found: full serialization.
                serialized_state = serialize(state)
                serialized_thermodynamic_state = innermost(serialized_state)
                reference_name = '{}/{}'.format(state_type, state_id)
                n_bytes = len(serialized_thermodynamic_state['standard_system'])
                reference_names[state] = reference_name
                logger.debug("Serialized state {} is {}B | {:.3f}KB | {:.3f}MB".format(
                    reference_name, n_bytes, n_bytes/1024.0,
                    n_bytes/1024.0/1024.0))

            # Fixed dimension improves compression.
            self._write_dict('{}/state{}'.format(state_type, state_id),
                             serialized_state, fixed_dimension=True)
def read_sampler_states(self, iteration, analysis_particles_only=False):
    """Retrieve the stored sampler states for the given iteration.

    Full sampler states are read from the checkpoint file, which only holds
    iterations on the checkpoint interval (otherwise None is returned).
    With ``analysis_particles_only`` the subset stored in the analysis file
    is read instead, regardless of the checkpoint interval.

    Parameters
    ----------
    iteration : int
        The iteration at which to read the data.
    analysis_particles_only : bool, Optional, Default: False
        If set, return the trajectory of ONLY the analysis particles flagged on original creation of the files.

    Returns
    -------
    sampler_states : list of SamplerStates or None
        The previously stored sampler states for each replica, or None when reading
        full states at an iteration that is not on the ``checkpoint_interval``.

    Raises
    ------
    ValueError
        If ``analysis_particles_only`` is set but no analysis particles were flagged.
    """
    if not analysis_particles_only:
        return self._read_sampler_states_from_given_file(iteration, storage_file='checkpoint',
                                                         obey_checkpoint_interval=True)
    if len(self._analysis_particle_indices) == 0:
        raise ValueError("No particles were flagged for special analysis! "
                         "No such trajectory would have been written!")
    return self._read_sampler_states_from_given_file(iteration, storage_file='analysis',
                                                     obey_checkpoint_interval=False)
@with_timer('Storing sampler states')
def write_sampler_states(self, sampler_states: list, iteration: int):
    """Store all sampler states for a given iteration.

    Full sampler states go to the checkpoint file, subject to the checkpoint
    interval. If analysis particles were flagged, their positions/velocities
    (plus box vectors) are additionally written to the analysis file at every
    iteration.

    Parameters
    ----------
    sampler_states : list of SamplerStates
        The sampler states to store for each replica.
    iteration : int
        The iteration at which to store the data.
    """
    self._write_sampler_states_to_given_file(sampler_states, iteration, storage_file='checkpoint',
                                             obey_checkpoint_interval=True)

    particle_indices = self._analysis_particle_indices
    if len(particle_indices) == 0:
        return

    def restrict_to_analysis_particles(sampler_state):
        # Index as [indices, :] so tuple and list indices behave the same
        # (ndarray[tuple] differs from ndarray[list]).
        subset_positions = sampler_state.positions[particle_indices, :]
        if sampler_state._unitless_velocities is None:
            subset_velocities = None
        else:
            subset_velocities = sampler_state.velocities[particle_indices, :]
        return states.SamplerState(subset_positions, velocities=subset_velocities,
                                   box_vectors=sampler_state.box_vectors)

    analysis_states = [restrict_to_analysis_particles(sampler_state) for sampler_state in sampler_states]
    self._write_sampler_states_to_given_file(analysis_states, iteration, storage_file='analysis',
                                             obey_checkpoint_interval=False)
def read_replica_thermodynamic_states(self, iteration=slice(None)):
    """Retrieve the ThermodynamicState index of each replica from the analysis file.

    Parameters
    ----------
    iteration : int or slice
        The iteration(s) at which to read the data. The default slice(None) fetches all iterations at once.
        The value is first passed through ``self._map_iteration_to_good``.

    Returns
    -------
    state_indices : np.ndarray of np.int64
        At the given iteration, replica ``i`` propagated the system in
        SamplerState ``sampler_states[i]`` and ThermodynamicState
        ``thermodynamic_states[states_indices[i]]``.
        With a slice, an array of shape ``[len(slice), len(sampler_states)]``.
    """
    good_iteration = self._map_iteration_to_good(iteration)
    logger.debug('read_replica_thermodynamic_states: iteration = {}'.format(good_iteration))
    states_variable = self._storage_analysis.variables['states']
    return states_variable[good_iteration].astype(np.int64)
def write_replica_thermodynamic_states(self, state_indices, iteration):
    """Store the ThermodynamicState index of each replica in the analysis file.

    The ``states`` variable (and ``replica`` dimension) is created on first use.

    Parameters
    ----------
    state_indices : list of int of size n_replicas
        At the given iteration, replica ``i`` propagated the system in
        SamplerState ``sampler_states[i]`` and ThermodynamicState
        ``thermodynamic_states[replica_thermodynamic_states[i]]``.
    iteration : int
        The iteration at which to store the data.
    """
    analysis_file = self._storage_analysis
    if 'states' not in analysis_file.variables:
        # Lazily create the schema: dimension first, then the variable.
        n_replicas = len(state_indices)
        self._ensure_dimension_exists('replica', n_replicas)
        ncvar_states = analysis_file.createVariable(
            'states', 'i4', ('iteration', 'replica'),
            zlib=False, chunksizes=(1, n_replicas)
        )
        ncvar_states.units = 'none'
        ncvar_states.long_name = ("states[iteration][replica] is the thermodynamic state index "
                                  "(0..n_states-1) of replica 'replica' of iteration 'iteration'.")

    analysis_file.variables['states'][iteration, :] = state_indices[:]
def read_mcmc_moves(self):
    """Return the MCMCMoves used by the MultiStateSampler simulation.

    The number of moves is taken from the ``mcmc_moves`` group of the
    analysis file; each move is read with ``read_dict`` and deserialized.

    Returns
    -------
    mcmc_moves : list of MCMCMove
        The MCMCMoves used to propagate the simulation.
    """
    n_moves = len(self._storage_analysis.groups['mcmc_moves'].variables)
    return [deserialize(self.read_dict('mcmc_moves/move{}'.format(move_index)))
            for move_index in range(n_moves)]
def write_mcmc_moves(self, mcmc_moves):
    """Store the MCMCMoves of a MultiStateSampler (or subclass) simulation.

    Each move is serialized and written with ``write_dict`` under
    ``mcmc_moves/move<index>``.

    Parameters
    ----------
    mcmc_moves : list of MCMCMove
        The MCMCMoves used to propagate the simulation.
    """
    for index, mcmc_move in enumerate(mcmc_moves):
        storage_path = 'mcmc_moves/move{}'.format(index)
        self.write_dict(storage_path, serialize(mcmc_move))
def read_energies(self, iteration=slice(None)):
    """Retrieve the energy matrices at the given iteration(s) from the analysis file.

    Parameters
    ----------
    iteration : int or slice
        The iteration(s) at which to read the data. The default slice(None) fetches all iterations at once.
        The value is first passed through ``self._map_iteration_to_good``.

    Returns
    -------
    energy_thermodynamic_states : numpy.ndarray of float64
        ``[..., i, j]`` is the reduced potential of SamplerState ``i`` evaluated at ThermodynamicState ``j``.
    energy_neighborhoods : numpy.ndarray of int8
        ``[..., replica_index, state_index]`` is 1 if the energy was computed for this state, 0 otherwise.
        Files written before neighborhoods existed are reported as all ones.
    energy_unsampled_states : numpy.ndarray
        ``[..., i, j]`` is the reduced potential of SamplerState ``i`` evaluated at unsampled state ``j``;
        the last axis has length 0 when the file has no unsampled states.
    """
    iteration = self._map_iteration_to_good(iteration)
    nc_variables = self._storage_analysis.variables

    energy_thermodynamic_states = np.array(nc_variables['energies'][iteration, :, :], np.float64)

    if 'neighborhoods' in nc_variables:
        energy_neighborhoods = np.array(nc_variables['neighborhoods'][iteration, :, :], 'i1')
    else:
        # Pre-neighborhoods file: assume global neighborhoods.
        energy_neighborhoods = np.ones(energy_thermodynamic_states.shape, 'i1')

    if 'unsampled_energies' in nc_variables:
        energy_unsampled_states = np.array(nc_variables['unsampled_energies'][iteration, :, :], np.float64)
    else:
        # No unsampled thermodynamic states: empty trailing axis.
        energy_unsampled_states = np.zeros(energy_thermodynamic_states.shape[:-1] + (0,))

    return energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states
def write_energies(self, energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states, iteration):
    """Store the energy matrix at the given iteration on the analysis file

    The ``energies`` and ``neighborhoods`` variables (and the ``replica`` and
    ``state`` dimensions) are created on first use. The ``unsampled_energies``
    variable and ``unsampled`` dimension are created only when there are
    unsampled states to store.

    Parameters
    ----------
    energy_thermodynamic_states : n_replicas x n_states numpy.ndarray
        ``energy_thermodynamic_states[i][j]`` is the reduced potential computed at
        SamplerState ``sampler_states[i]`` and ThermodynamicState ``thermodynamic_states[j]``.
    energy_neighborhoods : n_replicas x n_states numpy.ndarray
        energy_neighborhoods[replica_index, state_index] is 1 if the energy was computed for this state,
        0 otherwise
    energy_unsampled_states : n_replicas x n_unsampled_states numpy.ndarray
        ``energy_unsampled_states[i][j]`` is the reduced potential computed at SamplerState
        ``sampler_states[i]`` and ThermodynamicState ``unsampled_thermodynamic_states[j]``.
        If the second dimension has length 0, nothing is written for unsampled states.
    iteration : int
        The iteration at which to store the data.
    """
    storage = self._storage_analysis
    n_replicas, n_states = energy_thermodynamic_states.shape
    # Take the count from the shape: indexing row 0 would fail for a (0, k) array.
    n_unsampled_states = energy_unsampled_states.shape[1]

    # Create dimensions and variables if they weren't created by other functions.
    self._ensure_dimension_exists('replica', n_replicas)
    self._ensure_dimension_exists('state', n_states)
    if 'energies' not in storage.variables:
        ncvar_energies = storage.createVariable(
            'energies', 'f8', ('iteration', 'replica', 'state'),
            zlib=False, chunksizes=(1, n_replicas, n_states)
        )
        ncvar_energies.units = 'kT'
        ncvar_energies.long_name = ("energies[iteration][replica][state] is the reduced (unitless) "
                                    "energy of replica 'replica' from iteration 'iteration' evaluated "
                                    "at the thermodynamic state 'state'.")
    if 'neighborhoods' not in storage.variables:
        ncvar_neighborhoods = storage.createVariable(
            'neighborhoods', 'i1', ('iteration', 'replica', 'state'),
            zlib=False, fill_value=1,  # old-style files will be upgraded to have all states
            chunksizes=(1, n_replicas, n_states)
        )
        ncvar_neighborhoods.long_name = ("neighborhoods[iteration][replica][state] is 1 if "
                                         "this energy was computed during this iteration.")
    # The unsampled-energy schema is only needed when there is something to put in it.
    if n_unsampled_states > 0 and 'unsampled_energies' not in storage.variables:
        self._ensure_dimension_exists('unsampled', n_unsampled_states)
        ncvar_unsampled = storage.createVariable(
            'unsampled_energies', 'f8', ('iteration', 'replica', 'unsampled'),
            zlib=False, chunksizes=(1, n_replicas, n_unsampled_states)
        )
        ncvar_unsampled.units = 'kT'
        ncvar_unsampled.long_name = ("unsampled_energies[iteration][replica][state] is the reduced "
                                     "(unitless) energy of replica 'replica' from iteration 'iteration' "
                                     "evaluated at unsampled thermodynamic state 'state'.")

    # Store values
    storage.variables['energies'][iteration, :, :] = energy_thermodynamic_states
    storage.variables['neighborhoods'][iteration, :, :] = energy_neighborhoods
    if n_unsampled_states > 0:
        storage.variables['unsampled_energies'][iteration, :, :] = energy_unsampled_states[:, :]
def read_mixing_statistics(self, iteration=slice(None)):
    """Read the per-iteration swap statistics from the analysis file.

    Parameters
    ----------
    iteration : int or slice, optional
        Iteration(s) to read; limited to iterations up to the last good one.

    Returns
    -------
    n_accepted_matrix : numpy.ndarray of int64
        ``n_accepted_matrix[i][j]`` counts accepted moves from
        ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` between
        ``iteration-1`` and ``iteration`` (not cumulative). A slice adds a
        leading iteration axis.
    n_proposed_matrix : numpy.ndarray of int64
        Same layout, counting proposed moves.
    """
    good_iteration = self._map_iteration_to_good(iteration)
    nc_variables = self._storage_analysis.variables
    # Read 'accepted' first, then 'proposed', casting both to 64-bit integers.
    n_accepted, n_proposed = (
        nc_variables[name][good_iteration, :, :].astype(np.int64)
        for name in ('accepted', 'proposed')
    )
    return n_accepted, n_proposed
def write_mixing_statistics(self, n_accepted_matrix, n_proposed_matrix, iteration):
    """Store the mixing statistics for the given iteration on the analysis file

    The ``accepted`` and ``proposed`` variables (and the ``state`` dimension)
    are created together the first time this is called.

    Parameters
    ----------
    n_accepted_matrix : numpy.ndarray with shape (n_states, n_states)
        ``n_accepted_matrix[i][j]`` is the number of accepted moves from
        state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going
        from ``iteration-1`` to ``iteration`` (not cumulative).
    n_proposed_matrix : numpy.ndarray with shape (n_states, n_states)
        ``n_proposed_matrix[i][j]`` is the number of proposed moves from
        state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going
        from ``iteration-1`` to ``iteration`` (not cumulative).
    iteration : int
        The iteration for which to store the data.
    """
    storage = self._storage_analysis
    # Create schema if necessary; both variables are always created together.
    if 'accepted' not in storage.variables:
        n_states = n_accepted_matrix.shape[0]
        # Create dimension if it doesn't already exist
        self._ensure_dimension_exists('state', n_states)
        ncvar_accepted = storage.createVariable(
            'accepted', 'i4', ('iteration', 'state', 'state'),
            zlib=False, chunksizes=(1, n_states, n_states)
        )
        ncvar_proposed = storage.createVariable(
            'proposed', 'i4', ('iteration', 'state', 'state'),
            zlib=False, chunksizes=(1, n_states, n_states)
        )
        ncvar_accepted.units = 'none'
        ncvar_proposed.units = 'none'
        ncvar_accepted.long_name = ("accepted[iteration][i][j] is the number of accepted transitions "
                                    "between states i and j from iteration 'iteration-1'.")
        ncvar_proposed.long_name = ("proposed[iteration][i][j] is the number of proposed transitions "
                                    "between states i and j from iteration 'iteration-1'.")
    # Store statistics.
    storage.variables['accepted'][iteration, :, :] = n_accepted_matrix[:, :]
    storage.variables['proposed'][iteration, :, :] = n_proposed_matrix[:, :]
def read_timestamp(self, iteration=slice(None)):
    """Fetch the stored timestamp(s) for the requested iteration(s).

    The value comes from the analysis file; the checkpoint file carries a
    matching timestamp only on checkpoint iterations.

    Parameters
    ----------
    iteration : int or slice, optional
        Iteration(s) to read; limited to iterations up to the last good one.

    Returns
    -------
    timestamp : str
        The time at which the iteration was stored (one per iteration when a
        slice is given).
    """
    good_iteration = self._map_iteration_to_good(iteration)
    timestamp_variable = self._storage_analysis.variables['timestamp']
    return timestamp_variable[good_iteration]
def write_timestamp(self, iteration: int):
    """Store a timestamp for the given iteration on both analysis and checkpoint file.

    The ``timestamp`` variable is created on every storage file that lacks it.
    The analysis file always receives the timestamp. The checkpoint file
    receives it only when ``iteration`` is on the ``checkpoint_interval``; it is
    then written at the corresponding checkpoint index.

    Parameters
    ----------
    iteration : int
        The iteration for which to write the timestamp.
    """
    # Create variable if needed (the storage key itself is not used).
    for storage in self._storage_dict.values():
        if 'timestamp' not in storage.variables:
            storage.createVariable('timestamp', str, ('iteration',),
                                   zlib=False, chunksizes=(1,))
    timestamp = time.ctime()
    self._storage_analysis.variables['timestamp'][iteration] = timestamp
    checkpoint_iteration = self._calculate_checkpoint_iteration(iteration)
    if checkpoint_iteration is not None:
        self._storage_checkpoint.variables['timestamp'][checkpoint_iteration] = timestamp
def read_dict(self, path: str) -> Union[dict, Any]:
    """Restore a dictionary from the storage file.

    The method supports reading only specific dictionary keywords in
    path notation. If the dictionary is large, this can be quicker.
    However, note that, depending on how the dictionary was saved,
    this may end up reading the whole dictionary anyway.

    Parameters
    ----------
    path : str
        The path to the dictionary or a keyword in the dictionary.

    Returns
    -------
    data : dict or specified data
        The restored data as a dict, or the data stored at key, depending on `path`

    Examples
    --------
    >>> import os
    >>> import openmmtools as mmtools
    >>> data = {'info': [1, 2, 3]}
    >>> with mmtools.utils.temporary_directory() as temp_dir_path:
    ...     temp_file_path = os.path.join(temp_dir_path, 'temp.nc')
    ...     reporter = MultiStateReporter(temp_file_path, open_mode='w')
    ...     reporter.write_dict('data', data)
    ...     reporter.read_dict('data')
    {'info': [1, 2, 3]}
    """
    storage = 'analysis'
    # Get lowest possible NC variable/group. The path might refer
    # to the keyword of a dictionary that was saved in a single variable,
    # in which case the trailing components are collected in dict_path.
    dict_path = []
    nc_element = None
    while nc_element is None:
        try:
            nc_element = self._resolve_nc_path(path, storage)
        except KeyError:
            # Try the higher level.
            path, dict_key = path.rsplit(sep='/', maxsplit=1)
            dict_path.insert(0, dict_key)

    if isinstance(nc_element, netcdf.Group):
        # Nested representation: every subgroup/variable is one dictionary key.
        data = {}
        for elements in [nc_element.groups, nc_element.variables]:
            for key in elements:
                data[key] = self.read_dict(path + '/' + key)
    else:
        # Otherwise this is a variable holding a YAML-serialized value.
        if nc_element.dtype == 'S1':
            # Variables stored with a fixed dimension are arrays of single chars.
            # ndarray.tostring() was removed in NumPy 2.0; tobytes() is the
            # byte-identical replacement.
            data_chars = nc_element[:]
            data_str = data_chars.tobytes().decode()
        else:
            data_str = str(nc_element[0])
        data = yaml.load(data_str, Loader=_DictYamlLoader)

    # The title is kept in a global attribute, not in the stored metadata.
    if path == 'metadata':
        data['title'] = self._storage_dict[storage].title

    # Resolve the rest of the path that was saved un-nested.
    for dict_key in dict_path:
        data = data[dict_key]
    return data
def write_dict(self, path: str, data: dict):
    """Persist a dictionary on the analysis file.

    ``'metadata'`` is handled specially:
    - Its ``'title'`` entry is moved into the dataset's ``title`` attribute.
    - The remaining entries are written in nested, fixed-dimension form.
    The caller's dict is not modified.

    Parameters
    ----------
    path : str
        The path to the dictionary in the storage file.
    data : dict
        The dict to store.
    """
    storage_name = 'analysis'
    if path != 'metadata':
        self._write_dict(path, data, storage_name=storage_name)
        return

    # NetCDF conventions expect the dataset title as a global attribute, but
    # users provide it through metadata. Work on a copy so the caller's dict
    # keeps its 'title' key.
    metadata = copy.deepcopy(data)
    self._storage_dict[storage_name].title = metadata.pop('title')
    # Metadata is large and read-only, and AlchemicalPhase reads individual
    # keys (e.g. the sampler name) when resuming. Store it compressed and nested
    # so single keys can be read without loading everything.
    self._write_dict(path, metadata, storage_name=storage_name,
                     fixed_dimension=True, nested=True)
def read_checkpoint_iterations(self):
    """
    List every iteration on which a checkpoint was written.

    Returns
    -------
    checkpoints : np.ndarray of int
        ``0, interval, 2*interval, ...`` up to and including the last checkpoint iteration.
    """
    final_checkpoint = self.read_last_iteration(last_checkpoint=True)
    return np.arange(0, final_checkpoint + 1, self._checkpoint_interval, dtype=int)
def read_last_iteration(self, last_checkpoint=True):
    """
    Return the last iteration that was written in sequential order.

    Parameters
    ----------
    last_checkpoint : bool, optional
        When True (default), return the newest iteration that also falls on
        the checkpoint interval instead.

    Returns
    -------
    last_iteration : int
        Plain Python int.

    Raises
    ------
    RuntimeError
        If no checkpoint iteration at or below the stored last iteration exists.
    """
    # Cast so callers always receive a Python int rather than a numpy scalar.
    newest = int(self._storage_analysis.variables['last_iteration'][0])
    if not last_checkpoint:
        return newest
    # Scan downwards, including iteration 0, for the first checkpoint iteration.
    checkpointed = (candidate for candidate in range(newest, -1, -1)
                    if self._calculate_checkpoint_iteration(candidate) is not None)
    found = next(checkpointed, None)
    if found is None:
        raise RuntimeError("Could not find a checkpoint! This should not happen "
                           "as the 0th iteration should always be written! "
                           "Please open a ticket on the YANK GitHub page if you see this error message!")
    return found
def write_last_iteration(self, iteration):
    """
    Record ``iteration`` as the newest sequentially-complete iteration.

    Resuming and analysis rely on this marker to ignore data past the last
    good iteration, e.g. left over from an interrupted run. Call it as the
    final step of any ``write_iteration``-like routine.

    Parameters
    ----------
    iteration : int
        Iteration at which the last good data point was written.
    """
    # Flush everything already written before the marker moves forward...
    self.sync()
    marker = self._storage_analysis.variables['last_iteration']
    marker[0] = iteration
    # ...and flush again so the marker itself is persisted.
    self.sync()
def read_logZ(self, iteration):
    """
    Read the dimensionless logZ stored for ``iteration``.

    Parameters
    ----------
    iteration : int or None
        Iteration whose logZ should be read. ``None`` reads the static
        (non-per-iteration) copy.

    Returns
    -------
    logZ : np.array with shape [n_states]
        Dimensionless logZ

    Raises
    ------
    ValueError
        Propagated from ``read_online_analysis_data`` when no matching logZ
        data exists on disk.
    """
    online_data = self.read_online_analysis_data(iteration, "logZ")
    return online_data["logZ"]
def write_logZ(self, iteration: int, logZ: np.ndarray):
    """
    Write logZ both as per-iteration history and as the latest static value.

    Parameters
    ----------
    iteration : int
        Iteration under which the per-iteration copy is stored.
    logZ : np.array with shape [n_states]
        Dimensionless log Z
    """
    payload = {'logZ': logZ}
    self.write_online_data_dynamic_and_static(iteration, **payload)
def read_online_analysis_data(self, iteration, *keys: str):
    """
    Fetch online-analysis variables from the ``online_analysis`` group.

    Parameters
    ----------
    iteration : int or None
        Iteration to read. ``None`` reads the static (non-per-iteration)
        variable of each name instead of its ``_history`` counterpart.
    keys : str
        Names of the variables to fetch.

    Returns
    -------
    online_analysis_data : dict
        Maps each key that was found to its data.

    Warnings
    --------
    RuntimeWarning : If some, but not all, keys were found. The warning lists
        keys that exist only in the other (static vs per-iteration) layout,
        and keys that are missing entirely.

    Raises
    ------
    ValueError : If the online analysis group was never written, if no key was
        found at all, or if keys exist only in the other layout.
    """
    try:
        online_group = self._storage_analysis.groups["online_analysis"]
    except KeyError:
        raise ValueError("Online Analysis information was never written!")

    found = {}
    other_layout_only = []
    absent = []
    for key in keys:
        try:
            found[key] = self._read_1d_online_data(iteration, key, online_group)
        except IndexError:
            if self._find_alternate_variable(iteration, key, online_group):
                other_layout_only.append(key)
            else:
                absent.append(key)

    if not found:
        if not other_layout_only:
            raise ValueError("None of the requested keys could be found on disk!")
        # Nothing usable, but same-named variables exist in the other layout.
        template = ("No variables found on disk with{} per-iteration data, but we did find the following "
                    "variables of the same name with{} per-iteration data. Possibly you meant those?"
                    )
        template += "".join("\n\t-{}".format(name) for name in other_layout_only)
        fillers = ("out", "") if iteration is None else ("", "out")
        raise ValueError(template.format(*fillers))

    if other_layout_only or absent:
        prefix = "" if iteration is None else "non-"
        lines = [("Some requested variables were found, others were missing or found on disk under {}per"
                  "-iteration data:").format(prefix)]
        lines.extend("\t{}per-iteration: {}".format(prefix, name) for name in other_layout_only)
        lines.extend("\tMissing: {}".format(name) for name in absent)
        warnings.warn("\n".join(lines), RuntimeWarning)
    return found
def write_online_analysis_data(self, iteration: Union[int, None], **kwargs):
    """
    Write 1-D or scalar numeric online-analysis data into the ``online_analysis`` group.

    Storage follows a logic similar to, but not identical with, :func:`write_dict`.
    - An integer ``iteration`` stores each value per-iteration, at that iteration.
    - ``None`` stores it independently of the iteration dimension.
    Each keyword is one variable name/value pair; at least one is required.
    Values must be numeric scalars or 1-D sequences (numpy array, list, tuple
    of numbers), not strings or dicts. The type is inferred from the data.

    Parameters
    ----------
    iteration : int or None
        Iteration to write under, or ``None`` for non-per-iteration data.
    kwargs : pairs of name:value of numeric 1-D or scalar data
        Name of variable and value to write to disk

    Raises
    ------
    TypeError : If no values are given to ``**kwargs``
    ValueError : If ``iteration`` is not an integer or None
    """
    # Validate both arguments before touching the file.
    self._resolve_iteration_args(iteration)
    self._resolve_kwargs_exist(kwargs)
    online_group = self._ensure_group_exists_and_get("online_analysis")
    for variable_name in kwargs:
        self._write_1d_online_data(iteration, variable_name, kwargs[variable_name], online_group)
def write_online_data_dynamic_and_static(self, iteration: int, **kwargs):
    """
    Store the same online-analysis data twice: first as static data, then per-iteration.

    See Also
    --------
    write_online_analysis_data
    """
    for target_iteration in (None, iteration):
        self.write_online_analysis_data(target_iteration, **kwargs)
def write_current_statistics(self, data):
    """
    Append a YAML record of real-time analysis data next to the analysis file.

    The file is ``<analysis file stem>_real_time_analysis.yaml`` in the same
    directory (see :func:`~multistatereporter.MultiStateReporter`). On the first
    call of a reporter session any existing file is removed; later calls append.

    Parameters
    ----------
    data: dict
        Dictionary with the key, value pairs to store in YAML format.
    """
    analysis_dir, analysis_name = os.path.split(self._storage_analysis_file_path)
    stem = os.path.splitext(analysis_name)[0]
    yaml_path = os.path.join(analysis_dir, f"{stem}_real_time_analysis.yaml")
    if self._overwrite_statistics:
        # Fresh session: discard any statistics left by a previous run.
        try:
            os.remove(yaml_path)
        except FileNotFoundError:
            pass
        # Subsequent calls in this session append instead.
        self._overwrite_statistics = False
    with open(yaml_path, "a") as stats_file:
        stats_file.write(yaml.dump([data], sort_keys=False))
# -------------------------------------------------------------------------
# Internal-usage.
# -------------------------------------------------------------------------
def _write_1d_online_data(self, iteration, variable, data, storage):
    """Write one online-analysis variable to ``storage``, creating it on first use.

    Parameters
    ----------
    iteration : int or None
        When not None, the data is written at row ``iteration`` of
        ``<variable>_history``; otherwise it overwrites ``<variable>``.
    variable : str
        Base variable name.
    data : scalar or 1-D sequence of numbers
        Values to write.
    storage : netCDF4.Group
        Group that holds the variable.
    """
    if iteration is not None:
        variable = variable + "_history"
    if variable not in storage.variables:
        variable_parameters = self._determine_netcdf_variable_parameters(iteration, data, storage)
        # Pass arguments lazily so the message is only formatted when DEBUG is enabled.
        logger.debug('Creating new NetCDF variable %s with parameters: %s', variable, variable_parameters)
        storage.createVariable(variable, variable_parameters['dtype'],
                               dimensions=variable_parameters['dims'],
                               chunksizes=variable_parameters['chunksizes'],
                               zlib=False)
    nc_var = storage[variable]
    # Only address a specific row for per-iteration data.
    if iteration is not None:
        nc_var[iteration, :] = data
    else:
        nc_var[:] = data
@staticmethod
def _find_alternate_variable(iteration, variable, storage):
"""Helper function to figure out what went wrong when data not found"""
iter_var = variable + "_history"
if iteration is None and iter_var in storage.variables:
return True
elif iteration is not None and variable in storage.variables:
return True
return False
@staticmethod
def _read_1d_online_data(iteration, variable, storage):
"""Read data on disk given storage object
Returns
-------
data
"""
if iteration is not None:
variable = variable + "_history"
nc_var = storage[variable]
nc_data = nc_var
if iteration is not None:
nc_data = nc_data[iteration]
data = nc_data[:]
if nc_var.dimensions[-1] == "scalar":
return data[0]
else:
return data
def _determine_netcdf_variable_parameters(self, iteration, data, storage):
    """
    Pre-determine the variable information needed to create the variable on the storage layer

    Also makes sure the data dimension exists on ``storage``.

    Parameters
    ----------
    iteration : int or None
        When not None, the variable gets a leading ``iteration`` dimension.
    data : scalar or 1-D sequence of numbers
        Sample of the data to be stored; determines size and dtype.
        - Python scalars, NumPy scalars and 0-d arrays are treated as scalars.
        - Anything else must support ``len()``.
    storage : netCDF4.Dataset or netCDF4.Group
        Storage layer on which the dimension is created.

    Returns
    -------
    dict
        ``{'dtype': ..., 'dims': tuple, 'chunksizes': tuple}``
    """
    # np.ndim also classifies 0-d arrays as scalars; len() would fail on them.
    is_scalar = np.ndim(data) == 0
    if is_scalar:
        size = 1
        # _read_1d_online_data unwraps values on the "scalar" dimension, so
        # scalars written here are read back as scalars.
        data_dim = "scalar"
    else:
        size = len(data)
        data_dim = "dim_size{}".format(size)

    # Prefer the numpy dtype; otherwise fall back to the Python type.
    dtype = getattr(data, 'dtype', None)
    if dtype is None:
        dtype = type(data) if is_scalar else type(data[0])

    self._ensure_dimension_exists(data_dim, size, storage=storage)
    if iteration is not None:
        dims = ("iteration", data_dim)
        chunks = (1, size)
    else:
        dims = (data_dim,)
        chunks = (size,)
    return {'dtype': dtype, 'dims': dims, 'chunksizes': chunks}
@staticmethod
def _resolve_iteration_args(iteration_arg):
"""
Ensure iterations given as iterations are integer or None
"""
err_message = "Only an int or None is allowed for iteration"
# Ensures int check if its a core int, np.int32, np.int64, or signed/unsigned variants
if iteration_arg is not None and not np.issubdtype(type(iteration_arg), np.integer):
raise ValueError(err_message)
@staticmethod
def _resolve_kwargs_exist(kwargs):
"""
Ensure keyword args exist (at least 1)
"""
if len(kwargs) == 0:
raise TypeError("There must be at least 1 keyword arg!")
def _resolve_nc_path(self, path, storage):
    """Follow a ``/``-separated path to a NetCDF group or variable.

    Every component except the last must name a group. The last component is
    looked up first among groups, then among variables.

    Raises
    ------
    KeyError
        If any component cannot be found.
    """
    *group_names, leaf_name = path.split('/')
    current = self._storage_dict[storage]
    for group_name in group_names:
        current = current.groups[group_name]
    try:
        return current.groups[leaf_name]
    except KeyError:
        return current.variables[leaf_name]
def _calculate_checkpoint_iteration(self, iteration):
    """Map an analysis-file iteration to its index on the checkpoint file.

    Returns
    -------
    int or None
        ``iteration // checkpoint_interval`` as a Python int when ``iteration``
        is a multiple of the checkpoint interval, otherwise None.
    """
    interval = self._checkpoint_interval
    if iteration % interval != 0:
        return None
    # NetCDF variables can't be assigned using numpy integers.
    return int(iteration // interval)
def _map_iteration_to_good(self, iteration):
    """
    Translate an iteration index so it never reaches past the last good iteration.

    ``iteration`` is applied as a numpy index to ``[0, 1, ..., last_good]``.
    Slices are therefore clipped, negative ints count back from the last good
    iteration, and an int beyond it raises IndexError.

    Parameters
    ----------
    iteration : int or slice
        Index requested by a ``read_*`` method.

    Returns
    -------
    cast_iteration : int or numpy.ndarray of int
        Concrete iteration(s) to read.
    """
    last_good = self.read_last_iteration(last_checkpoint=False)
    valid_iterations = np.arange(last_good + 1, dtype=int)
    return valid_iterations[iteration]
def _ensure_dimension_exists(self, dim_name, dim_size, storage=None):
    """
    Ensure a dimension exists and is of the appropriate size,
    creating it if it does not already exist.

    Parameters
    ----------
    dim_name : str
        The dimension name
    dim_size : int or None
        The dimension size. ``0`` or ``None`` requests an unlimited dimension,
        following the netCDF4 ``createDimension`` convention.
    storage : netCDF4.Dataset or netCDF4.Group, optional, default: None
        Storage layer to check the dimension against. If none, the _storage_analysis is used

    Raises
    ------
    ValueError
        If the dimension already exists and either
        - it is not unlimited although an unlimited one was requested, or
        - its size differs from ``dim_size``.
    """
    if storage is None:
        storage = self._storage_analysis
    if dim_name not in storage.dimensions:
        storage.createDimension(dim_name, dim_size)
        return
    # Check dimension matches expected size
    dimension = storage.dimensions[dim_name]
    if dim_size is None or dim_size == 0:
        if not dimension.isunlimited():
            raise ValueError("NetCDF dimension {} already exists: was previously size {} (not unlimited), "
                             "but tried to redeclare it as unlimited".format(dimension.name, dimension.size))
    elif int(dimension.size) != int(dim_size):
        raise ValueError("NetCDF dimension {} already exists: was previously size {}, but tried to "
                         "redeclare it with dimension {}".format(dimension.name, dimension.size, dim_size))
def _ensure_group_exists_and_get(self, group_name, storage=None):
    """
    Return the named group of ``storage``, creating it first if it does not exist.

    Parameters
    ----------
    group_name : str
        The group name
    storage : netCDF4.Dataset or netCDF4.Group, optional, default: None
        Storage object to check against. If None, the analysis storage is used.

    Returns
    -------
    group : netCDF4.Group
        The existing or newly created group.
    """
    if storage is None:
        storage = self._storage_analysis
    if group_name not in storage.groups:
        storage.createGroup(group_name)
    return storage.groups[group_name]
@staticmethod
def _initialize_sampler_variables_on_file(dataset, n_atoms, n_replicas, is_periodic):
"""
Initialize the NetCDF variables on the storage file needed to store sampler states.
Does nothing if file already initialized
Parameters
----------
dataset : NetCDF4 Dataset
Dataset to validate
n_atoms : int
Number of atoms which will be stored
n_replicas : int
Number of Sampler states which will be written
is_periodic : bool
True if system is periodic; False otherwise.
"""
if 'positions' not in dataset.variables:
# Create dimensions. Replica dimension could have been created before.
dataset.createDimension('atom', n_atoms)
if 'replica' not in dataset.dimensions:
dataset.createDimension('replica', n_replicas)
# Define position variables.
ncvar_positions = dataset.createVariable('positions', 'f4',
('iteration', 'replica', 'atom', 'spatial'),
zlib=True, chunksizes=(1, n_replicas, n_atoms, 3))
ncvar_positions.units = 'nm'
ncvar_positions.long_name = ("positions[iteration][replica][atom][spatial] is position of "
"coordinate 'spatial' of atom 'atom' from replica 'replica' for "
"iteration 'iteration'.")
# Define velocities variables.
ncvar_velocities = dataset.createVariable('velocities', 'f4',
('iteration', 'replica', 'atom', 'spatial'),
zlib=True, chunksizes=(1, n_replicas, n_atoms, 3))
ncvar_velocities.units = 'nm / ps'
ncvar_velocities.long_name = ("velocities[iteration][replica][atom][spatial] is velocity of "
"coordinate 'spatial' of atom 'atom' from replica 'replica' for "
"iteration 'iteration'.")
# Define variables for periodic systems
if is_periodic:
ncvar_box_vectors = dataset.createVariable('box_vectors', 'f4',
('iteration', 'replica', 'spatial', 'spatial'),
zlib=False, chunksizes=(1, n_replicas, 3, 3))
ncvar_volumes = dataset.createVariable('volumes', 'f8', ('iteration', 'replica'),
zlib=False, chunksizes=(1, n_replicas))
ncvar_box_vectors.units = 'nm'
ncvar_volumes.units = 'nm**3'
ncvar_box_vectors.long_name = ("box_vectors[iteration][replica][i][j] is dimension j of box "
"vector i for replica 'replica' from iteration "
"'iteration-1'.")
ncvar_volumes.long_name = ("volume[iteration][replica] is the box volume for replica "
"'replica' from iteration 'iteration-1'.")
def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: int,
                                        storage_file='checkpoint', obey_checkpoint_interval=True):
    """
    Internal function to handle writing sampler states more generically to target file

    The sampler-state schema is created on the target file if needed, regardless
    of the checkpoint interval. Then, if the write is due, it stores for every
    replica:
    - positions and velocities;
    - box vectors and volumes, for periodic systems.

    Parameters
    ----------
    sampler_states : list of states.SamplerStates
        The sampler states to store for each replica.
    iteration : int
        The iteration at which to store the data.
    storage_file : string, Optional, Default: 'checkpoint'
        Name of storage file we're writing to. Must match a valid key of self._storage_dict
    obey_checkpoint_interval : bool, Optional, Default: True
        Tells this (attempted) write to obey the checkpoint interval or not.
        If True, no write out will be done if iteration is not on the checkpoint interval;
        when it is, data are written at index ``iteration // checkpoint_interval``.
        If False, the write WILL occur, at index ``iteration``.
    """
    storage = self._storage_dict[storage_file]
    # Check if the schema must be initialized, do this regardless of the checkpoint_interval for consistency
    # The system is treated as periodic when the first sampler state has box vectors.
    is_periodic = True if (sampler_states[0].box_vectors is not None) else False
    n_particles = sampler_states[0].n_particles
    n_replicas = len(sampler_states)
    self._initialize_sampler_variables_on_file(storage, n_particles,
                                               n_replicas, is_periodic)
    if obey_checkpoint_interval:
        # None when iteration is not on the checkpoint interval.
        write_iteration = self._calculate_checkpoint_iteration(iteration)
    else:
        write_iteration = iteration
    # Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval
    if write_iteration is not None:
        # Store sampler states.
        # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
        positions = np.zeros([n_replicas, n_particles, 3])
        for replica_index, sampler_state in enumerate(sampler_states):
            # Store positions in memory first (converted to unitless nanometers)
            x = sampler_state.positions / unit.nanometers
            positions[replica_index, :, :] = x[:, :]
        # Store positions
        storage.variables['positions'][write_iteration, :, :, :] = positions
        # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
        velocities = np.zeros([n_replicas, n_particles, 3])
        for replica_index, sampler_state in enumerate(sampler_states):
            # Replicas without velocities keep the zero rows allocated above.
            if sampler_state._unitless_velocities is not None:
                # Store velocities in memory first (converted to unitless nm/ps)
                x = sampler_state.velocities / (unit.nanometer/unit.picoseconds)  # _unitless_velocities
                velocities[replica_index, :, :] = x[:, :]
        # Store velocities
        # TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
        #  sampler_state different from origin.
        if 'velocities' not in storage.variables:
            # The schema initializer above is skipped when 'positions' already exists, so a file
            # created without velocities gets the variable here, mirroring the positions layout.
            # create variable with expected dimensions and shape
            storage.createVariable('velocities', storage.variables['positions'].dtype,
                                   dimensions=storage.variables['positions'].dimensions)
        storage.variables['velocities'][write_iteration, :, :, :] = velocities

        if is_periodic:
            # Store box vectors and volume.
            # Allocate whole write to memory first
            box_vectors = np.zeros([n_replicas, 3, 3])
            volumes = np.zeros([n_replicas])
            for replica_index, sampler_state in enumerate(sampler_states):
                box_vectors[replica_index, :, :] = sampler_state.box_vectors / unit.nanometers
                volumes[replica_index] = sampler_state.volume / unit.nanometers ** 3
            storage.variables['box_vectors'][write_iteration, :, :, :] = box_vectors
            storage.variables['volumes'][write_iteration, :] = volumes
    else:
        logger.debug("Iteration {} not on the Checkpoint Interval of {}. "
                     "Sampler State not written.".format(iteration, self._checkpoint_interval))
def _read_sampler_states_from_given_file(self, iteration, storage_file='checkpoint', obey_checkpoint_interval=True):
    """
    Internal function to handle reading sampler states from a general storage file

    Parameters
    ----------
    iteration : int
        Iteration on which to read data from
    storage_file : string, Optional, Default: 'checkpoint'
        Name of storage file we're reading from. Must match a valid key of self._storage_dict
    obey_checkpoint_interval : bool, Optional, Default: True
        Tells this (attempted) read to obey the checkpoint interval or not.
        If True, nothing is read (None is returned) if iteration is not on the checkpoint interval;
        otherwise data are read at index ``iteration // checkpoint_interval``.
        If False, the read will be attempted regardless, at the index of ``iteration``.
        WARNING: If the storage file you specify was written on the checkpoint interval and you set
        obey_checkpoint_interval=False, you will get undefined behavior!

    Returns
    -------
    sampler_states : list of SamplerStates or None
        The previously stored sampler states for each replica.
        If the iteration is not on the checkpoint_interval and the file only writes on the checkpoint_interval,
        None is returned
    """
    storage = self._storage_dict[storage_file]
    # Always mapped, even when obeying the checkpoint interval. Besides giving the
    # index used when not obeying it, this indexes np.arange(last_good + 1), so an
    # int past the last good iteration raises IndexError here.
    read_iteration = self._map_iteration_to_good(iteration)
    if obey_checkpoint_interval:
        # None when iteration is not on the checkpoint interval.
        read_iteration = self._calculate_checkpoint_iteration(iteration)
    if read_iteration is not None:
        # TODO: Restore n_replicas instead
        n_replicas = storage.dimensions['replica'].size
        sampler_states = list()
        for replica_index in range(n_replicas):
            # Restore positions (stored unitless, in nanometers).
            x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64)
            positions = unit.Quantity(x, unit.nanometers)
            # Restore velocities
            # try-catch exception, enabling reading legacy/older serialized objects from openmmtools<0.21.3
            try:
                x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
                velocities = unit.Quantity(x, unit.nanometer / unit.picoseconds)
            except KeyError:  # Velocities key/variable not found in serialization (openmmtools<=0.21.2)
                # pass zeros as velocities when key is not found (<0.21.3 behavior)
                # NOTE(review): positions is a unit.Quantity here; confirm np.zeros_like yields
                # a value SamplerState accepts as velocities.
                velocities = np.zeros_like(positions)
            if 'box_vectors' in storage.variables:
                # Restore box vectors.
                x = storage.variables['box_vectors'][read_iteration, replica_index, :, :].astype(np.float64)
                box_vectors = unit.Quantity(x, unit.nanometers)
            else:
                # Non-periodic file: no box vectors were stored.
                box_vectors = None
            # Create SamplerState.
            sampler_states.append(states.SamplerState(positions=positions, velocities=velocities, box_vectors=box_vectors))
        return sampler_states
    else:
        return None
def _write_dict(self, path, data, storage_name='analysis',
                fixed_dimension=False, nested=False):
    """Serialize a dictionary as YAML and save it in one of the storage files.

    Parameters
    ----------
    path : str
        Path of the variable inside the storage file. In nested mode this is
        the prefix under which each key is stored.
    data : dict
        The dictionary to store. It is dumped with ``_DictYamlDumper``, which
        also knows how to represent openmm Quantities and numpy arrays.
    storage_name : 'analysis' or 'checkpoint'
        Key of the storage file (in ``self._storage_dict``) to write into.
    fixed_dimension : bool, default: False
        If True, the YAML text is stored as an array of single characters
        along a fixed-length dimension named
        ``"fixedL{}".format(len(yaml_text))``, where ``yaml_text`` is the
        dumped YAML string; the dimension is created if it does not exist.
        This seems to allow NetCDF to actually compress the strings.
        The dimension is only chosen when the variable is first created, so
        do NOT use this flag for data whose serialized length changes
        between writes. Use it only for static data.
        If False, the text is stored in a variable-length string variable
        on the ``'scalar'`` dimension.
    nested : bool, default False
        If True and ``data`` is a non-empty dict, every entry is written
        recursively at ``path + '/' + key`` instead of as a single YAML
        document. This makes it possible to retrieve a single keyword
        without reading the whole dictionary. All keys must be strings.
        The stored dict can be overwritten, but ONLY if its structure does
        not change, because groups/variables cannot be deleted from a
        NetCDF file.

    Raises
    ------
    ValueError
        In nested mode, if the dictionary has a non-string key.
    """
    nc_file = self._storage_dict[storage_name]

    # Nested mode: store each entry of a non-empty dict under its own path.
    # Inner dicts recurse further down the path.
    store_entries_separately = nested and isinstance(data, dict) and len(data) > 0
    if store_entries_separately:
        for entry_key, entry_value in data.items():
            if not isinstance(entry_key, str):
                raise ValueError('Cannot store dict in nested form with non-string keys.')
            self._write_dict(path + '/' + entry_key, entry_value, storage_name,
                             fixed_dimension, nested)
        return

    yaml_text = yaml.dump(data, Dumper=_DictYamlDumper)

    # Reuse the variable if this path was already written; otherwise create it.
    try:
        nc_variable = self._resolve_nc_path(path, storage_name)
    except KeyError:
        if fixed_dimension:
            text_length = len(yaml_text)
            dimension_name = "fixedL{}".format(text_length)
            # A dimension of this length may already exist; reuse it if so.
            if dimension_name not in nc_file.dimensions:
                nc_file.createDimension(dimension_name, text_length)
            variable_type = 'S1'
        else:
            variable_type, dimension_name = str, 'scalar'
        nc_variable = nc_file.createVariable(path, variable_type,
                                             dimension_name, zlib=False)

    if fixed_dimension:
        # One single-byte character per element of the fixed dimension.
        payload = np.array(list(yaml_text), dtype='S1')
    else:
        # A one-element object array holding the whole YAML string.
        payload = np.empty(1, 'O')
        payload[0] = yaml_text
    nc_variable[:] = payload
# ==============================================================================
# MODULE INTERNAL USAGE UTILITIES
# ==============================================================================
class _DictYamlLoader(yaml.CLoader):
    """PyYAML Loader that reads the ``!Quantity`` and ``!ndarray`` tags.

    These are the tags written by ``_DictYamlDumper``:

    - ``!Quantity`` is a mapping with a ``unit`` string and a unitless
      ``value``.
    - ``!ndarray`` is a mapping with a dtype string ``type``, a ``shape``
      and nested-list ``values``.
    """

    def __init__(self, *args, **kwargs):
        super(_DictYamlLoader, self).__init__(*args, **kwargs)
        self.add_constructor(u'!Quantity', self.quantity_constructor)
        self.add_constructor(u'!ndarray', self.ndarray_constructor)

    @staticmethod
    def quantity_constructor(loader, node):
        """Rebuild a Quantity from a ``!Quantity`` mapping node.

        The ``unit`` string is parsed with ``quantity_from_string`` and then
        multiplied with the stored ``value``.
        """
        # deep=True (as in ndarray_constructor) so that container values such
        # as lists are fully constructed before being combined with the unit.
        # A shallow construction may let PyYAML defer populating them.
        loaded_mapping = loader.construct_mapping(node, deep=True)
        data_unit = quantity_from_string(loaded_mapping['unit'])
        data_value = loaded_mapping['value']
        return data_value * data_unit

    @staticmethod
    def ndarray_constructor(loader, node):
        """Rebuild a numpy array from a ``!ndarray`` mapping node.

        Arrays of any rank are supported, including 0-d arrays, whose
        ``values`` entry is a bare scalar, and arrays with a zero-length axis.
        """
        loaded_mapping = loader.construct_mapping(node, deep=True)
        data_type = np.dtype(loaded_mapping['type'])
        data_shape = loaded_mapping['shape']
        data_values = loaded_mapping['values']
        data = np.ndarray(shape=data_shape, dtype=data_type)
        # Arrays with a zero-length axis hold no elements. Their tolist()
        # output (e.g. [] for shape (0, 3)) cannot be broadcast back, so skip
        # filling them.
        if 0 not in data_shape:
            # Ellipsis instead of [:] so that 0-d arrays (shape ()) can be
            # filled too; [:] raises IndexError on a 0-dimensional array.
            data[...] = data_values
        return data
class _DictYamlDumper(yaml.CDumper):
    """PyYAML Dumper for openmm Quantities and numpy arrays.

    - A Quantity is emitted as a ``!Quantity`` mapping holding its unit
      (as a string) and its unitless value.
    - A numpy array is emitted as an ``!ndarray`` mapping holding its dtype
      string, its shape and its contents converted to native Python types.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        custom_representers = (
            (unit.Quantity, self.quantity_representer),
            (np.ndarray, self.ndarray_representer),
        )
        for represented_class, representer in custom_representers:
            self.add_representer(represented_class, representer)

    @staticmethod
    def quantity_representer(dumper, data):
        """Emit ``data`` as a ``!Quantity`` mapping with ``unit`` and ``value`` keys."""
        quantity_unit = data.unit
        unitless_value = data / quantity_unit
        mapping = {'unit': str(quantity_unit), 'value': unitless_value}
        return dumper.represent_mapping(u'!Quantity', mapping)

    @staticmethod
    def ndarray_representer(dumper, data):
        """Convert a numpy array to native Python types and emit it as ``!ndarray``."""
        mapping = {
            'type': str(data.dtype),
            'shape': data.shape,
            'values': data.tolist(),
        }
        return dumper.represent_mapping(u'!ndarray', mapping)