Source code for openmmtools.multistate.multistatereporter

#!/usr/bin/env python

# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================

"""
MultiStateReporter
==================

Master multi-thermodynamic state reporter module. Handles all disk I/O
reporting operations for any MultiStateSampler-derived classes.

COPYRIGHT

Current version by Andrea Rizzi <andrea.rizzi@choderalab.org>, Levi N. Naden <levi.naden@choderalab.org> and
John D. Chodera <john.chodera@choderalab.org> while at Memorial Sloan Kettering Cancer Center.

Original version by John D. Chodera <jchodera@gmail.com> while at the University of
California Berkeley.

LICENSE

This code is licensed under the latest available version of the MIT License.

"""

# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================

import os
import copy
import time
import uuid
import yaml
import warnings
import logging
import collections

import numpy as np
import netCDF4 as netcdf

from typing import Union, Any

try:
    from openmm import unit
except ImportError:  # OpenMM < 7.6
    from simtk import unit

import openmmtools
from openmmtools.utils import deserialize, with_timer, serialize, quantity_from_string
from openmmtools import states


logger = logging.getLogger(__name__)


# ==============================================================================
# MULTISTATE SAMPLER REPORTER
# ==============================================================================

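# In brief: a MultiStateReporter manages two NetCDF files that are paired by a shared
# UUID global attribute. The "analysis" file holds energies, replica state indices,
# mixing statistics, online-analysis data and, optionally, positions/velocities for
# the ``analysis_particle_indices`` subset at every iteration. The "checkpoint" file
# holds full positions, velocities and box vectors, written only on iterations that
# are multiples of ``checkpoint_interval``.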
class MultiStateReporter(object):
    """Handle storage write/read operations and different format conventions.

    You can use this object to programmatically inspect the data generated by
    ReplicaExchange.

    Parameters
    ----------
    storage : str
        The path to the storage file for analysis. A second checkpoint file will be
        determined from either ``checkpoint_storage`` or automatically based on the
        ``storage`` option. In the future this will be able to take Storage classes
        as well.
    open_mode : str or None
        The mode of the file between 'r', 'w', and 'a' (or equivalently 'r+').
        If None, the storage file won't be opened on construction, and a call to
        :func:`Reporter.open` will be needed before attempting read/write operations.
    checkpoint_interval : int >= 1, Default: 50
        The frequency at which checkpointing information is written relative to
        analysis information. This is a multiple of the iteration at which energies
        are written, hence why it must be greater than or equal to 1.
        Checkpoint information cannot be written on iterations for which
        ``iteration % checkpoint_interval != 0``.
    checkpoint_storage : str or None, optional
        Optional name of the checkpoint file. This file is used to save trajectory
        information and other less frequently accessed data.
        This should NOT be a full path, just a filename.
        If None, the derived checkpoint name is the same as ``storage``, less any
        extension, with "_checkpoint.nc" appended.
        The reporter internally tracks what data goes into which file, so it's
        transparent to all other classes.
        In the future, this will be able to take Storage classes as well.
    analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple)
        If specified, positions and velocities for the specified particles are
        serialized at every iteration in the reporter storage (.nc) file.
        If empty, no positions or velocities are stored in this file for any atoms.

    Attributes
    ----------
    filepath
    checkpoint_interval
    is_periodic
    n_states
    n_replicas
    analysis_particle_indices

    """
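
    # Illustrative read-only usage (a minimal sketch; 'experiment.nc' is a hypothetical
    # file produced by a previous simulation):
    #
    #     reporter = MultiStateReporter('experiment.nc', open_mode='r')
    #     energies, neighborhoods, unsampled = reporter.read_energies()
    #     state_indices = reporter.read_replica_thermodynamic_states()
    #     last_iteration = reporter.read_last_iteration()
    #     reporter.close()
    #
    # If ``analysis_particle_indices`` were set when the file was created, the subset
    # trajectory can be read at any iteration with
    # ``reporter.read_sampler_states(iteration, analysis_particles_only=True)``.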

    def __init__(self, storage, open_mode=None,
                 checkpoint_interval=50, checkpoint_storage=None,
                 analysis_particle_indices=()):
        # Warn that the API is experimental.
        logger.warning('Warning: The openmmtools.multistate API is experimental and may change in future releases')
        # Handle checkpointing.
        if type(checkpoint_interval) != int:
            raise ValueError("checkpoint_interval must be an integer!")
        dirname, filename = os.path.split(storage)
        if checkpoint_storage is None:
            basename, ext = os.path.splitext(filename)
            addon = "_checkpoint"
            checkpoint_storage = os.path.join(dirname, basename + addon + ext)
            logger.debug("Initial checkpoint file automatically chosen as {}".format(checkpoint_storage))
        else:
            checkpoint_storage = os.path.join(dirname, checkpoint_storage)
        self._storage_analysis_file_path = storage
        self._storage_checkpoint_file_path = checkpoint_storage
        self._storage_checkpoint = None
        self._storage_analysis = None
        self._checkpoint_interval = checkpoint_interval
        # Cast to tuple no matter what 1-D-like input was given.
        self._analysis_particle_indices = tuple(analysis_particle_indices)
        if open_mode is not None:
            self.open(open_mode)

        # TODO: Maybe we want to expose this flag to control overwriting/appending.
        # Flag to check whether to overwrite the real-time statistics file -- defaults to append.
        self._overwrite_statistics = False
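
        # Checkpoint-file naming, following the derivation above (paths are hypothetical):
        #
        #     MultiStateReporter('output/repex.nc')
        #         analysis file:   'output/repex.nc'
        #         checkpoint file: 'output/repex_checkpoint.nc'
        #
        #     MultiStateReporter('output/repex.nc', checkpoint_storage='chk.nc')
        #         checkpoint file: 'output/chk.nc'  (always placed in the same directory)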

    @property
    def filepath(self):
        """
        Return the string file name of the primary storage file.

        Classes outside the Reporter can access the file string for error messages and such.
        """
        return self._storage_analysis_file_path

    @property
    def _storage(self):
        """
        Return an iterable of the storage objects, avoids having the [list, of, storage, objects] everywhere.
        Object 0 is always the primary file, all others are subfiles.
        """
        return self._storage_analysis, self._storage_checkpoint

    @property
    def _storage_paths(self):
        """
        Return an iterable of paths to the storage files.
        Object 0 is always the primary file, all others are subfiles.
        """
        return self._storage_analysis_file_path, self._storage_checkpoint_file_path

    @property
    def _storage_dict(self):
        """Return an iterable dictionary of the self._storage_X objects"""
        return {'checkpoint': self._storage_checkpoint, 'analysis': self._storage_analysis}

    @property
    def n_states(self):
        if not self.is_open():
            return None
        return self._storage_analysis.dimensions['state'].size

    @property
    def n_replicas(self):
        if not self.is_open():
            return None
        return self._storage_analysis.dimensions['replica'].size

    @property
    def is_periodic(self):
        if not self.is_open():
            return None
        if 'box_vectors' in self._storage_analysis.variables:
            return True
        return False

    @property
    def analysis_particle_indices(self):
        """Return the tuple of indices of the particles for which additional information is stored for analysis"""
        return self._analysis_particle_indices

    @property
    def checkpoint_interval(self):
        """Return the checkpoint interval"""
        return self._checkpoint_interval

    def storage_exists(self, skip_size=False):
        """
        Check if the storage files exist on disk.

        Reads information on the primary file to check the existence of the others.

        Parameters
        ----------
        skip_size : bool, Optional, Default: False
            Skip the check of the file size. Helpful if you have just initialized a storage file but written
            nothing to it yet and/or it's still entirely in memory (e.g. just opened NetCDF files).

        Returns
        -------
        files_exist : bool
            If the primary storage file and its related subfiles exist, returns True.
            If the primary file or any subfiles do not exist, returns False.
        """
        # This function serves as a way to mask the subfiles from everything outside the reporter.
        for file_path in self._storage_paths:
            if not os.path.exists(file_path):
                return False  # Return if any files do not exist
            elif not os.path.getsize(file_path) > 0 and not skip_size:
                return False  # File is 0 size
        return True

    def is_open(self):
        """Return True if the Reporter is ready to read/write."""
        if self._storage[0] is None:
            return False
        else:
            return self._storage[0].isopen()

    def _are_subfiles_open(self):
        """Internal function to check if subfiles are open"""
        open_check_list = []
        for storage in self._storage[1:]:
            if storage is None:
                return False
            else:
                open_check_list.append(storage.isopen())
        return np.all(open_check_list)

    def open(self, mode='r', convention='ReplicaExchange', netcdf_format='NETCDF4'):
        """
        Open the storage file for reading/writing.

        Creates and pre-formats the required files if they don't exist.

        This is not necessary if you have indicated an ``open_mode`` in the constructor.

        Parameters
        ----------
        mode : str, Optional, Default: 'r'
            The mode of the file between 'r', 'w', and 'a' (or equivalently 'r+').
        convention : str, Optional, Default: 'ReplicaExchange'
            NetCDF convention to write.
        netcdf_format : str, Optional, Default: 'NETCDF4'
            The NetCDF file format to use.

        """
        # Ensure we don't already have another file open (possibly in a different mode).
        self.close()

        # Create directory if we want to write.
        # TODO: We probably want to check here specifically for w when we want to write
        if mode != 'r':
            for storage_path in self._storage_paths:
                # normpath() transforms '' to '.' for makedirs().
                storage_dir = os.path.normpath(os.path.dirname(storage_path))
                os.makedirs(storage_dir, exist_ok=True)

        # Analysis file.
        # --------------

        # Open analysis file.
        self._storage_analysis = self._open_dataset_robustly(self._storage_analysis_file_path, mode,
                                                             version=netcdf_format)

        # The analysis netcdf file holds a reference UUID so that we can check
        # that the secondary netcdf files (currently only the checkpoint
        # file) have the same UUID to verify that the user isn't erroneously
        # trying to associate the analysis file to the incorrect checkpoint.
        try:
            # Check if we have previously created the file.
            primary_uuid = self._storage_analysis.UUID
        except AttributeError:
            # This is a new file. Use uuid4 to avoid assigning hostname information.
            primary_uuid = str(uuid.uuid4())
            self._storage_analysis.UUID = primary_uuid

        # Initialize dataset, if needed.
        self._initialize_storage_file(self._storage_analysis, 'analysis', convention)

        # Checkpoint file.
        # ----------------

        # Open checkpoint netcdf file.
        msg = ('Could not locate checkpoint subfile. This is okay for analysis if the '
               'solvent trajectory is not needed, but not for production simulation!')
        self._storage_checkpoint = self._open_dataset_robustly(self._storage_checkpoint_file_path, mode,
                                                               catch_io_error=True, io_error_warning=msg,
                                                               version=netcdf_format)
        if self._storage_checkpoint is not None:
            # Check that the checkpoint file has the same UUID as the analysis file.
            try:
                assert self._storage_checkpoint.UUID == primary_uuid
            except AttributeError:
                # This is a new file. Assign UUID.
                self._storage_checkpoint.UUID = primary_uuid
            except AssertionError:
                raise IOError('Checkpoint UUID does not match analysis UUID! '
                              'This checkpoint file came from another simulation!\n'
                              'Analysis UUID: {}; Checkpoint UUID: {}'.format(
                                  primary_uuid, self._storage_checkpoint.UUID))

            # Initialize dataset, if needed.
            self._initialize_storage_file(self._storage_checkpoint, 'checkpoint', convention)

        # Further checkpoint interval checks.
        # -----------------------------------
        if self._storage_analysis is not None:
            # The same number will be on the checkpoint file as well, but it's not guaranteed to be present.
            on_file_interval = self._storage_analysis.CheckpointInterval
            if on_file_interval != self._checkpoint_interval:
                logger.debug("checkpoint_interval != on-file checkpoint interval! "
                             "Using on-file checkpoint interval of {}.".format(on_file_interval))
                self._checkpoint_interval = on_file_interval

            # Check the special particle indices.
            # Handle the "variable does not exist" case.
            if 'analysis_particle_indices' not in self._storage_analysis.variables:
                n_particles = len(self._analysis_particle_indices)
                # This dimension won't exist if the above statement does not either.
                self._storage_analysis.createDimension('analysis_particles', n_particles)
                ncvar_analysis_particles = \
                    self._storage_analysis.createVariable('analysis_particle_indices', int, 'analysis_particles')
                ncvar_analysis_particles[:] = self._analysis_particle_indices
                ncvar_analysis_particles.long_name = ("analysis_particle_indices[analysis_particles] is the indices "
                                                      "of the particles with extra information stored about them in "
                                                      "the analysis file.")

            # Now handle the "variable does exist but does not match the provided ones" case.
            # Although redundant if it was just created, it's an easy check to make.
            stored_analysis_particles = self._storage_analysis.variables['analysis_particle_indices'][:]
            if self._analysis_particle_indices != tuple(stored_analysis_particles.astype(int)):
                logger.debug("analysis_particle_indices != on-file analysis_particle_indices! "
                             "Using on-file analysis indices of {}".format(stored_analysis_particles))
                self._analysis_particle_indices = tuple(stored_analysis_particles.astype(int))

    def _open_dataset_robustly(self, *args, n_attempts=5, sleep_time=2,
                               catch_io_error=False, io_error_warning=None,
                               **kwargs):
        """Attempt to open the dataset multiple times if it raises an error.

        This may be useful to solve some MPI concurrency and locking issues
        that routinely and randomly pop up with HDF5. Some sleep time is added
        between attempts (in seconds).

        If the file is not found and catch_io_error is True, None is returned.
        """
        # Catch eventual errors n_attempts - 1 times.
        for attempt in range(n_attempts-1):
            try:
                return netcdf.Dataset(*args, **kwargs)
            except Exception as err:
                # This should be safe since we will raise the error below on the last attempt.
                logger.debug(f"exception thrown {err}")
                logger.debug('Attempt {}/{} to open {} failed. Retrying '
                             'in {} seconds'.format(attempt+1, n_attempts, args[0], sleep_time))
                time.sleep(sleep_time)

        # Check if the file exists and warn if asked;
        # raise IOError otherwise.
        if not os.path.isfile(args[0]):
            if catch_io_error:
                if io_error_warning is not None:
                    logger.warning(io_error_warning)
                return None
            raise IOError(f"{args[0]} does not exist")

        # At the very last attempt, we try setting the environment variable
        # controlling the locking mechanism of HDF5 (see choderalab/yank#1165).
        if n_attempts > 1:
            os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

        # Last attempt finally raises any error.
        return netcdf.Dataset(*args, **kwargs)

    def _initialize_storage_file(self, ncfile, nc_name, convention):
        """Helper function to initialize dimensions and global attributes.

        If the dataset has been initialized before, nothing happens. Return
        True if the file has been initialized before and False otherwise.
        """
        from openmmtools import __version__

        if 'scalar' not in ncfile.dimensions:
            # Create common dimensions.
            ncfile.createDimension('scalar', 1)  # Scalar dimension.
            ncfile.createDimension('iteration', 0)  # Unlimited number of iterations.
            ncfile.createDimension('spatial', 3)  # Number of spatial dimensions.

            # Set global attributes.
            ncfile.program = f"openmmtools {openmmtools.__version__}"
            ncfile.programVersion = __version__
            ncfile.Conventions = convention
            ncfile.ConventionVersion = '0.2'
            ncfile.DataUsedFor = nc_name
            ncfile.CheckpointInterval = self._checkpoint_interval

            # Create and initialize the global variables.
            nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar')
            nc_last_good_iter[0] = 0
            return True
        else:
            return False

    def close(self):
        """Close the storage files"""
        for storage_name, storage in self._storage_dict.items():
            if storage is not None:
                if storage.isopen():
                    storage.sync()
                    storage.close()
                setattr(self, '_storage' + storage_name, None)

    def sync(self):
        """Force any buffer to be flushed to the file"""
        for storage in self._storage:
            if storage is not None:
                storage.sync()

    def __del__(self):
        """Synchronize and close the storage."""
        self.close()

    def read_end_thermodynamic_states(self):
        """Read thermodynamic states at the ends of the protocol.

        Returns
        -------
        end_thermodynamic_states : list of ThermodynamicState
            unsampled_states, if present, or first and last sampled states

        """
        end_thermodynamic_states = list()
        if 'unsampled_states' in self._storage_analysis.groups:
            state_type = 'unsampled_states'
        else:
            state_type = 'thermodynamic_states'

        # Read thermodynamic end states.
        states_serializations = dict()
        n_states = len(self._storage_analysis.groups[state_type].variables)

        def extract_serialized_state(inner_type, inner_id):
            """Inner function to help extract the correct serialized state while minimizing the number of disk reads

            Parameters
            ----------
            inner_type : str, 'unsampled_states' or 'thermodynamic_states'
                Where to read the data from, inherited from the parent function's property or from the recursive loop
            inner_id : int
                Which state to pull data from
            """
            inner_serialized_state = self.read_dict('{}/state{}'.format(inner_type, inner_id))

            def isolate_thermodynamic_state(isolating_serialized_state):
                """Helper function to find the true bottom-level thermodynamic state from any level of nesting"""
                isolating_serial_thermodynamic_state = isolating_serialized_state
                while 'thermodynamic_state' in isolating_serial_thermodynamic_state:
                    # The while loop is necessary for nested CompoundThermodynamicStates.
                    isolating_serial_thermodynamic_state = isolating_serial_thermodynamic_state['thermodynamic_state']
                return isolating_serial_thermodynamic_state

            serialized_thermodynamic_state = isolate_thermodynamic_state(inner_serialized_state)

            # Check if the standard state is in a previous state.
            try:
                standard_system_name = serialized_thermodynamic_state.pop('_Reporter__compatible_state')
            except KeyError:
                # Cache the standard system serialization for future usage.
                standard_system_name = '{}/{}'.format(inner_type, inner_id)
                states_serializations[standard_system_name] = serialized_thermodynamic_state['standard_system']
            else:
                # The system serialization can be retrieved from another state.
                # Because the unsampled states rely on the thermodynamic states for their deserialization, we have
                # to try a secondary/recursive loop to get the thermodynamic states.
                # However, this loop happens less often as the states_serializations dict fills up.
try: serialized_standard_system = states_serializations[standard_system_name] except KeyError: loop_type, loop_id = standard_system_name.split('/') looped_standard_state = extract_serialized_state(loop_type, loop_id) looped_serial_thermodynamic_state = isolate_thermodynamic_state(looped_standard_state) serialized_standard_system = looped_serial_thermodynamic_state['standard_system'] serialized_thermodynamic_state['standard_system'] = serialized_standard_system return inner_serialized_state for state_id in [0, n_states-1]: serialized_state = extract_serialized_state(state_type, state_id) # Create ThermodynamicState object. end_thermodynamic_states.append(deserialize(serialized_state)) return end_thermodynamic_states @with_timer('Reading thermodynamic states from storage') def read_thermodynamic_states(self): """Retrieve the stored thermodynamic states from the checkpoint file. Returns ------- thermodynamic_states : list of ThermodynamicStates The previously stored thermodynamic states. During the simulation these are swapped among replicas. unsampled_states : list of ThermodynamicState The previously stored unsampled thermodynamic states. See Also -------- read_replica_thermodynamic_states """ # We have to parse the thermodynamic states first because the # unsampled states may refer to them for the serialized system. states = collections.OrderedDict([('thermodynamic_states', list()), ('unsampled_states', list())]) # Caches standard_system_name: serialized_standard_system states_serializations = dict() # Read state information. for state_type, state_list in states.items(): # There may not be unsampled states. if state_type not in self._storage_analysis.groups: assert state_type == 'unsampled_states' continue # We keep looking for states until we can't find them anymore. n_states = len(self._storage_analysis.groups[state_type].variables) for state_id in range(n_states): serialized_state = self.read_dict('{}/state{}'.format(state_type, state_id)) # Find the thermodynamic state representation. serialized_thermodynamic_state = serialized_state while 'thermodynamic_state' in serialized_thermodynamic_state: # The while loop is necessary for nested CompoundThermodynamicStates. serialized_thermodynamic_state = serialized_thermodynamic_state['thermodynamic_state'] # Check if the standard state is in a previous state. try: standard_system_name = serialized_thermodynamic_state.pop('_Reporter__compatible_state') except KeyError: # Cache the standard system serialization for future usage. standard_system_name = '{}/{}'.format(state_type, state_id) states_serializations[standard_system_name] = serialized_thermodynamic_state['standard_system'] else: # The system serialization can be retrieved from another state. serialized_standard_system = states_serializations[standard_system_name] serialized_thermodynamic_state['standard_system'] = serialized_standard_system # Create ThermodynamicState object. states[state_type].append(deserialize(serialized_state)) return [states['thermodynamic_states'], states['unsampled_states']] @with_timer('Storing thermodynamic states') def write_thermodynamic_states(self, thermodynamic_states, unsampled_states): """Store all the ThermodynamicStates to the checkpoint file. Parameters ---------- thermodynamic_states : list of ThermodynamicState The thermodynamic states to store. unsampled_states : list of ThermodynamicState The unsampled thermodynamic states to store. 
See Also -------- write_replica_thermodynamic_states """ # Store all thermodynamic states as serialized dictionaries. stored_states = dict() def unnest_thermodynamic_state(serialized): while 'thermodynamic_state' in serialized: serialized = serialized['thermodynamic_state'] return serialized for state_type, states in [('thermodynamic_states', thermodynamic_states), ('unsampled_states', unsampled_states)]: for state_id, state in enumerate(states): # Check if any compatible state has been found found_compatible_state = False for compare_state in stored_states: if compare_state.is_state_compatible(state): serialized_state = serialize(state, skip_system=True) serialized_thermodynamic_state = unnest_thermodynamic_state(serialized_state) serialized_thermodynamic_state.pop('standard_system') # Remove the unneeded system object reference_state_name = stored_states[compare_state] serialized_thermodynamic_state['_Reporter__compatible_state'] = reference_state_name found_compatible_state = True break # If no compatible state is found, do full serialization if not found_compatible_state: serialized_state = serialize(state) serialized_thermodynamic_state = unnest_thermodynamic_state(serialized_state) serialized_standard_system = serialized_thermodynamic_state['standard_system'] reference_state_name = '{}/{}'.format(state_type, state_id) len_serialization = len(serialized_standard_system) # Store new compatibility data stored_states[state] = reference_state_name logger.debug("Serialized state {} is {}B | {:.3f}KB | {:.3f}MB".format( reference_state_name, len_serialization, len_serialization/1024.0, len_serialization/1024.0/1024.0)) # Finally write the dictionary with fixed dimension to improve compression. self._write_dict('{}/state{}'.format(state_type, state_id), serialized_state, fixed_dimension=True) def read_sampler_states(self, iteration, analysis_particles_only=False): """Retrieve the stored sampler states on the checkpoint file If the iteration is not on the checkpoint interval, None is returned. Exception to this is if``analysis_particles_only``, see the Returns for output behavior. Parameters ---------- iteration : int The iteration at which to read the data. analysis_particles_only : bool, Optional, Default: False If set, will return the trajectory of ONLY the analysis particles flagged on original creation of the files Returns ------- sampler_states : list of SamplerStates or None The previously stored sampler states for each replica. If the iteration is not on the ``checkpoint_interval`` and the ``analysis_particles_only`` is not set, None is returned If ``analysis_particles_only`` is set, will return only the subset of particles which were defined by the ``analysis_particle_indices`` on reporter creation """ if analysis_particles_only: if len(self._analysis_particle_indices) == 0: raise ValueError("No particles were flagged for special analysis! " "No such trajectory would have been written!") return self._read_sampler_states_from_given_file(iteration, storage_file='analysis', obey_checkpoint_interval=False) else: return self._read_sampler_states_from_given_file(iteration, storage_file='checkpoint', obey_checkpoint_interval=True) @with_timer('Storing sampler states') def write_sampler_states(self, sampler_states: list, iteration: int): """Store all sampler states for a given iteration on the checkpoint file If the iteration is not on the checkpoint interval, only the ``analysis_particle_indices`` data is written, if set. 
Parameters ---------- sampler_states : list of SamplerStates The sampler states to store for each replica. iteration : int The iteration at which to store the data. """ # Case of no special atoms, write to normal checkpoint self._write_sampler_states_to_given_file(sampler_states, iteration, storage_file='checkpoint', obey_checkpoint_interval=True) if len(self._analysis_particle_indices) > 0: # Special case, pre-process the sampler_states sampler_subset = [] for sampler_state in sampler_states: positions = sampler_state.positions # Subset positions # Need the [arg, :] to get uniform behavior with tuple and list for arg # since a ndarray[tuple] is different than ndarray[list] position_subset = positions[self._analysis_particle_indices, :] velocities_subset = None if sampler_state._unitless_velocities is not None: velocities = sampler_state.velocities velocities_subset = velocities[self._analysis_particle_indices, :] sampler_subset.append(states.SamplerState(position_subset, velocities=velocities_subset, box_vectors=sampler_state.box_vectors)) self._write_sampler_states_to_given_file(sampler_subset, iteration, storage_file='analysis', obey_checkpoint_interval=False) def read_replica_thermodynamic_states(self, iteration=slice(None)): """Retrieve the indices of the ThermodynamicStates for each replica on the analysis file Parameters ---------- iteration : int or slice The iteration(s) at which to read the data. The slice(None) allows fetching all iterations at once. Returns ------- state_indices : np.ndarray of int At the given iteration, replica ``i`` propagated the system in SamplerState ``sampler_states[i]`` and ThermodynamicState ``thermodynamic_states[states_indices[i]]``. If a slice is given, returns shape ``[len(slice), `len(sampler_states)]`` """ iteration = self._map_iteration_to_good(iteration) logger.debug('read_replica_thermodynamic_states: iteration = {}'.format(iteration)) return self._storage_analysis.variables['states'][iteration].astype(np.int64) def write_replica_thermodynamic_states(self, state_indices, iteration): """Store the indices of the ThermodynamicStates for each replica on the analysis file Parameters ---------- state_indices : list of int of size n_replicas At the given iteration, replica ``i`` propagated the system in SamplerState ``sampler_states[i]`` and ThermodynamicState ``thermodynamic_states[replica_thermodynamic_states[i]]``. iteration : int The iteration at which to store the data. """ # Initialize schema if needed. if 'states' not in self._storage_analysis.variables: n_replicas = len(state_indices) # Create dimension if they don't exist. self._ensure_dimension_exists('replica', n_replicas) # Create variables and attach units and description. ncvar_states = self._storage_analysis.createVariable( 'states', 'i4', ('iteration', 'replica'), zlib=False, chunksizes=(1, n_replicas) ) ncvar_states.units = 'none' ncvar_states.long_name = ("states[iteration][replica] is the thermodynamic state index " "(0..n_states-1) of replica 'replica' of iteration 'iteration'.") # Store thermodynamic states indices. self._storage_analysis.variables['states'][iteration, :] = state_indices[:] def read_mcmc_moves(self): """Return the MCMCMoves of the :class:`yank.multistate.MultiStateSampler` simulation on the checkpoint Returns ------- mcmc_moves : list of MCMCMove The MCMCMoves used to propagate the simulation. """ n_moves = len(self._storage_analysis.groups['mcmc_moves'].variables) # Retrieve all moves in order. 
mcmc_moves = list() for move_id in range(n_moves): serialized_move = self.read_dict('mcmc_moves/move{}'.format(move_id)) mcmc_moves.append(deserialize(serialized_move)) return mcmc_moves def write_mcmc_moves(self, mcmc_moves): """Store the MCMCMoves of the :class:`yank.multistate.MultiStateSampler` simulation or subclasses on the checkpoint Parameters ---------- mcmc_moves : list of MCMCMove The MCMCMoves used to propagate the simulation. """ for move_id, move in enumerate(mcmc_moves): serialized_move = serialize(move) self.write_dict('mcmc_moves/move{}'.format(move_id), serialized_move) def read_energies(self, iteration=slice(None)): """Retrieve the energy matrix at the given iteration on the analysis file Parameters ---------- iteration : int or slice The iteration(s) at which to read the data. The slice(None) allows fetching all iterations at once. Returns ------- energy_thermodynamic_states : n_replicas x n_states numpy.ndarray ``energy_thermodynamic_states[iteration, i, j]`` is the reduced potential computed at SamplerState ``sampler_states[iteration, i]`` and ThermodynamicState ``thermodynamic_states[iteration, j]``. energy_neighborhoods : n_replicas x n_states numpy.ndarray energy_neighborhoods[replica_index, state_index] is 1 if the energy was computed for this state, 0 otherwise energy_unsampled_states : n_replicas x n_unsampled_states numpy.ndarray ``energy_unsampled_states[iteration, i, j]`` is the reduced potential computed at SamplerState ``sampler_states[iteration, i]`` and ThermodynamicState ``unsampled_thermodynamic_states[iteration, j]``. """ # Determine last consistent iteration iteration = self._map_iteration_to_good(iteration) # Retrieve energies at all thermodynamic states energy_thermodynamic_states = np.array(self._storage_analysis.variables['energies'][iteration, :, :], np.float64) # Retrieve neighborhoods, assuming global neighborhoods if reading a pre-neighborhoods file try: energy_neighborhoods = np.array(self._storage_analysis.variables['neighborhoods'][iteration, :, :], 'i1') except KeyError: energy_neighborhoods = np.ones(energy_thermodynamic_states.shape, 'i1') # Read energies at unsampled states, if present try: energy_unsampled_states = np.array(self._storage_analysis.variables['unsampled_energies'][iteration, :, :], np.float64) except KeyError: # There are no unsampled thermodynamic states. unsampled_shape = energy_thermodynamic_states.shape[:-1] + (0,) energy_unsampled_states = np.zeros(unsampled_shape) return energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states def write_energies(self, energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states, iteration): """Store the energy matrix at the given iteration on the analysis file Parameters ---------- energy_thermodynamic_states : n_replicas x n_states numpy.ndarray ``energy_thermodynamic_states[i][j]`` is the reduced potential computed at SamplerState ``sampler_states[i]`` and ThermodynamicState ``thermodynamic_states[j]``. energy_neighborhoods : n_replicas x n_states numpy.ndarray energy_neighborhoods[replica_index, state_index] is 1 if the energy was computed for this state, 0 otherwise energy_unsampled_states : n_replicas x n_unsampled_states numpy.ndarray ``energy_unsampled_states[i][j]`` is the reduced potential computed at SamplerState ``sampler_states[i]`` and ThermodynamicState ``unsampled_thermodynamic_states[j]``. iteration : int The iteration at which to store the data. 
""" n_replicas, n_states = energy_thermodynamic_states.shape # Create dimensions and variables if they weren't created by other functions. self._ensure_dimension_exists('replica', n_replicas) self._ensure_dimension_exists('state', n_states) if 'energies' not in self._storage_analysis.variables: ncvar_energies = self._storage_analysis.createVariable( 'energies', 'f8', ('iteration', 'replica', 'state'), zlib=False, chunksizes=(1, n_replicas, n_states) ) ncvar_energies.units = 'kT' ncvar_energies.long_name = ("energies[iteration][replica][state] is the reduced (unitless) " "energy of replica 'replica' from iteration 'iteration' evaluated " "at the thermodynamic state 'state'.") if 'neighborhoods' not in self._storage_analysis.variables: ncvar_neighborhoods = self._storage_analysis.createVariable( 'neighborhoods', 'i1', ('iteration', 'replica', 'state'), zlib=False, fill_value=1, # old-style files will be upgraded to have all states chunksizes=(1, n_replicas, n_states) ) ncvar_neighborhoods.long_name = ("neighborhoods[iteration][replica][state] is 1 if " "this energy was computed during this iteration.") if 'unsampled_energies' not in self._storage_analysis.variables: # Check if we have unsampled states. if energy_unsampled_states.shape[1] > 0: n_unsampled_states = len(energy_unsampled_states[0]) self._ensure_dimension_exists('unsampled', n_unsampled_states) if 'unsampled_energies' not in self._storage_analysis.variables: # Create variable for thermodynamic state energies with units and descriptions. ncvar_unsampled = self._storage_analysis.createVariable( 'unsampled_energies', 'f8', ('iteration', 'replica', 'unsampled'), zlib=False, chunksizes=(1, n_replicas, n_unsampled_states) ) ncvar_unsampled.units = 'kT' ncvar_unsampled.long_name = ("unsampled_energies[iteration][replica][state] is the reduced " "(unitless) energy of replica 'replica' from iteration 'iteration' " "evaluated at unsampled thermodynamic state 'state'.") # Store values self._storage_analysis.variables['energies'][iteration,:,:] = energy_thermodynamic_states self._storage_analysis.variables['neighborhoods'][iteration,:,:] = energy_neighborhoods if energy_unsampled_states.shape[1] > 0: self._storage_analysis.variables['unsampled_energies'][iteration, :, :] = energy_unsampled_states[:, :] def read_mixing_statistics(self, iteration=slice(None)): """Retrieve the mixing statistics for the given iteration on the analysis file Parameters ---------- iteration : int or slice The iteration(s) at which to read the data. Returns ------- n_accepted_matrix : numpy.ndarray with shape (n_states, n_states) ``n_accepted_matrix[i][j]`` is the number of accepted moves from state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going from ``iteration-1`` to ``iteration`` (not cumulative). n_proposed_matrix : numpy.ndarray with shape (n_states, n_states) ``n_proposed_matrix[i][j]`` is the number of proposed moves from state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going from ``iteration-1`` to ``iteration`` (not cumulative). 
""" iteration = self._map_iteration_to_good(iteration) n_accepted_matrix = self._storage_analysis.variables['accepted'][iteration, :, :].astype(np.int64) n_proposed_matrix = self._storage_analysis.variables['proposed'][iteration, :, :].astype(np.int64) return n_accepted_matrix, n_proposed_matrix def write_mixing_statistics(self, n_accepted_matrix, n_proposed_matrix, iteration): """Store the mixing statistics for the given iteration on the analysis file Parameters ---------- n_accepted_matrix : numpy.ndarray with shape (n_states, n_states) ``n_accepted_matrix[i][j]`` is the number of accepted moves from state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going from iteration-1 to iteration (not cumulative). n_proposed_matrix : numpy.ndarray with shape (n_states, n_states) ``n_proposed_matrix[i][j]`` is the number of proposed moves from state ``thermodynamic_states[i]`` to ``thermodynamic_states[j]`` going from ``iteration-1`` to ``iteration`` (not cumulative). iteration : int The iteration for which to store the data. """ # Create schema if necessary. if 'accepted' not in self._storage_analysis.variables: n_states = n_accepted_matrix.shape[0] # Create dimension if it doesn't already exist self._ensure_dimension_exists('state', n_states) # Create variables with units and descriptions. ncvar_accepted = self._storage_analysis.createVariable( 'accepted', 'i4', ('iteration', 'state', 'state'), zlib=False, chunksizes=(1, n_states, n_states) ) ncvar_proposed = self._storage_analysis.createVariable( 'proposed', 'i4', ('iteration', 'state', 'state'), zlib=False, chunksizes=(1, n_states, n_states) ) ncvar_accepted.units = 'none' ncvar_proposed.units = 'none' ncvar_accepted.long_name = ("accepted[iteration][i][j] is the number of proposed transitions " "between states i and j from iteration 'iteration-1'.") ncvar_proposed.long_name = ("proposed[iteration][i][j] is the number of proposed transitions " "between states i and j from iteration 'iteration-1'.") # Store statistics. self._storage_analysis.variables['accepted'][iteration, :, :] = n_accepted_matrix[:, :] self._storage_analysis.variables['proposed'][iteration, :, :] = n_proposed_matrix[:, :] def read_timestamp(self, iteration=slice(None)): """Return the timestamp for the given iteration. Read from the analysis file, although there is a paired timestamp on the checkpoint file as well Parameters ---------- iteration : int or slice The iteration(s) at which to read the data. Returns ------- timestamp : str The timestamp at which the iteration was stored. """ iteration = self._map_iteration_to_good(iteration) return self._storage_analysis.variables['timestamp'][iteration] def write_timestamp(self, iteration: int): """Store a timestamp for the given iteration on both analysis and checkpoint file. If the iteration is not on the ``checkpoint_interval``, no timestamp is written on the checkpoint file Parameters ---------- iteration : int The iteration at which to read the data. """ # Create variable if needed. 
for storage_key, storage in self._storage_dict.items(): if 'timestamp' not in storage.variables: storage.createVariable('timestamp', str, ('iteration',), zlib=False, chunksizes=(1,)) timestamp = time.ctime() self._storage_analysis.variables['timestamp'][iteration] = timestamp checkpoint_iteration = self._calculate_checkpoint_iteration(iteration) if checkpoint_iteration is not None: self._storage_checkpoint.variables['timestamp'][checkpoint_iteration] = timestamp def read_dict(self, path: str) -> Union[dict, Any]: """Restore a dictionary from the storage file. The method supports reading only specific dictionary keywords in path notation. If the dictionary is large, this can be quicker. However, note that, depending on how the dictionary was saved, this may end up reading the whole dictionary anyway. Parameters ---------- path : str The path to the dictionary or a keyword in the dictionary. Returns ------- data : dict or specified data The restored data as a dict, or the data stored at key, depending on `path` Examples -------- >>> import os >>> import openmmtools as mmtools >>> data = {'info': [1, 2, 3]} >>> with mmtools.utils.temporary_directory() as temp_dir_path: ... temp_file_path = os.path.join(temp_dir_path, 'temp.nc') ... reporter = MultiStateReporter(temp_file_path, open_mode='w') ... reporter.write_dict('data', data) ... reporter.read_dict('data') {'info': [1, 2, 3]} """ storage = 'analysis' # Get lowest possible NC variable/group. The path might refer # to the keyword of a dictionary that was saved in a single variable. dict_path = [] nc_element = None while nc_element is None: try: nc_element = self._resolve_nc_path(path, storage) except KeyError: # Try the higher level. path, dict_key = path.rsplit(sep='/', maxsplit=1) dict_path.insert(0, dict_key) # If this is a group, the dictionary has been nested. if isinstance(nc_element, netcdf.Group): data = {} for elements in [nc_element.groups, nc_element.variables]: for key in elements: data.update({key: self.read_dict(path + '/' + key)}) # Otherwise this is a variable. else: if nc_element.dtype == 'S1': # Handle variables stored in fixed_dimensions data_chars = nc_element[:] data_str = data_chars.tostring().decode() else: data_str = str(nc_element[0]) data = yaml.load(data_str, Loader=_DictYamlLoader) # Restore the title in the metadata. if path == 'metadata': data['title'] = self._storage_dict[storage].title # Resolve the rest of the path that was saved un-nested. for dict_key in dict_path: data = data[dict_key] return data def write_dict(self, path: str, data: dict): """Store the contents of a dict. Parameters ---------- path : str The path to the dictionary in the storage file. data : dict The dict to store. """ storage_name = 'analysis' if path == 'metadata': # General NetCDF conventions assume the title of the dataset # to be specified as a global attribute, but the user can # specify their own titles only in metadata. data = copy.deepcopy(data) self._storage_dict[storage_name].title = data.pop('title') # Metadata is pretty big, read-only attribute (it contains the # reference state and the topography), and AlchemicalPhase has # to read the name of the sampler for resuming and checking the # status, so we store it compressed and in nested form. 
self._write_dict(path, data, storage_name=storage_name, fixed_dimension=True, nested=True) else: self._write_dict(path, data, storage_name=storage_name) def read_checkpoint_iterations(self): """ Utility function to provide all iterations on which a checkpoint was written Returns ------- checkpoints : np.ndarray of int All checkpoints from initial iteration to current """ return np.array(range(0, self.read_last_iteration(last_checkpoint=True)+1, self._checkpoint_interval), dtype=int) def read_last_iteration(self, last_checkpoint=True): """ Read the last iteration from file which was written in sequential order. Parameters ---------- last_checkpoint : bool, optional If True, returns the last checkpoint iteration (default is True). Returns ------- last_iteration : int Last iteration which was sequentially written. """ # Make sure this is returned as Python int. last_iteration = int(self._storage_analysis.variables['last_iteration'][0]) # Get last checkpoint. if last_checkpoint: # -1 for stop ensures the 0th index is searched. for i in range(last_iteration, -1, -1): if self._calculate_checkpoint_iteration(i) is not None: return i raise RuntimeError("Could not find a checkpoint! This should not happen " "as the 0th iteration should always be written! " "Please open a ticket on the YANK GitHub page if you see this error message!") return last_iteration def write_last_iteration(self, iteration): """ Tell the reporter what the last iteration which was written in sequential order was to allow resuming and analysis only on valid data. Call this as the last step of any ``write_iteration``-like routine to ensure analysis will not use junk data left over from an interrupted simulation past the last checkpoint. The reporter is sync'ed both before and after writing the last iteration to ensure minimal data corruption Parameters ---------- iteration : int Iteration at which the last good data point was written. """ self.sync() self._storage_analysis.variables['last_iteration'][0] = iteration self.sync() def read_logZ(self, iteration): """ Read logZ at a given iteration from file. Parameters ---------- iteration : int iteration to read the free energies from if the iteration was not written at a the given iteration, then the free_energies are all 0 Returns ------- logZ : np.array with shape [n_states] Dimensionless logZ """ data = self.read_online_analysis_data(iteration, "logZ") return data['logZ'] def write_logZ(self, iteration: int, logZ: np.ndarray): """ Write logZ Parameters ---------- iteration : int, Iteration at which to save the free energy. Reads the current energy up to this value and stores it in the analysis reporter logZ : np.array with shape [n_states] Dimensionless log Z """ self.write_online_data_dynamic_and_static(iteration, logZ=logZ) def read_online_analysis_data(self, iteration, *keys: str): """ Parameters ---------- iteration : int or None Iteration to fetch data at. If ``None``, then assumes static data and will attempt to get the entry with the name written assuming no iteration-specific data. 
keys : str Variables to fetch data from Returns ------- online_analysis_data : dict Data requested by *keys argument from online analysis, if they exist on disk Warnings -------- RuntimeWarning : If some keys were not found as requested Raises ------ ValueError : If no requested keys were found in the storage or if no online analysis data was written """ collected_variables = {} collected_iteration_failure = [] collected_not_found = [] try: storage = self._storage_analysis.groups["online_analysis"] except KeyError: raise ValueError("Online Analysis information was never written!") for variable in keys: try: data = self._read_1d_online_data(iteration, variable, storage) collected_variables[variable] = data except IndexError: if self._find_alternate_variable(iteration, variable, storage): collected_iteration_failure.append(variable) else: collected_not_found.append(variable) # Nothing found if not collected_variables and not collected_iteration_failure: raise ValueError("None of the requested keys could be found on disk!") # Found some things possibly named wrong, still nothing to return elif not collected_variables: base_error = ("No variables found on disk with{} per-iteration data, but we did find the following " "variables of the same name with{} per-iteration data. Possibly you meant those?" ) for failure in collected_iteration_failure: base_error += "\n\t-{}".format(failure) if iteration is None: raise ValueError(base_error.format("out", "")) else: raise ValueError(base_error.format("", "out")) elif collected_iteration_failure or collected_not_found: base_warn = ("Some requested variables were found, others were missing or found on disk under {}per" "-iteration data:") if iteration is None: iteration_str = "" else: iteration_str = "non-" base_warn = base_warn.format(iteration_str) for failure in collected_iteration_failure: base_warn += "\n\t{}per-iteration: {}".format(iteration_str, failure) for missing in collected_not_found: base_warn += "\n\tMissing: {}".format(missing) warnings.warn(base_warn, RuntimeWarning) return collected_variables def write_online_analysis_data(self, iteration: Union[int, None], **kwargs): """ Write semi-arbitrary 1-D numeric online analysis data to storage with optional per-iteration flag. This function helps generalize what information is stored by any given reporter, while still enforcing a regular input and output. The logic of what to store and how is similar, but not exact to the :func:`write_dict`. ``iteration`` accepts an integer as to indicate this this information should be written on a per-iteration basis. The iteration the data are written to is the integer argument. Pass ``None`` if this is *not* per-iteration data and stored independent of the iteration dimension ``**kwargs`` are processed as the variable/value pairs to store and there must be *at least one* This should be 1-D or scalar numerical value (e.g. numpy array, list, or tuple of numbers; NOT string, dict, etc.). Type is inferred from the first value of data input. 
Parameters ---------- iteration : int or None Optional iteration to write the data under, if ``None``, the variables will not be written on a per-iteration basis kwargs : pairs of name:value of numeric 1-D or scalar data Name of variable and value to write to disk Raises ------ TypeError : If no values are given to ``**kwargs`` ValueError : If ``iteration`` is not an integer """ self._resolve_iteration_args(iteration) self._resolve_kwargs_exist(kwargs) group = self._ensure_group_exists_and_get("online_analysis") for name, value in kwargs.items(): self._write_1d_online_data(iteration, name, value, group) def write_online_data_dynamic_and_static(self, iteration: int, **kwargs): """ Helper function to do a :func:`write_online_analysis_data` call twice, both with and without setting iteration. See Also -------- write_online_analysis_data """ self.write_online_analysis_data(None, **kwargs) self.write_online_analysis_data(iteration, **kwargs) def write_current_statistics(self, data): """ Write real time YAML file with analysis data. A real_time_analysis.yaml file will be generated in the same directory for the reporter netcdf file (see :func:`~multistatereporter.MultiStateReporter` for more information). Overwrites file if it already exists. Parameters ---------- data: dict Dictionary with the key, value pairs to store in YAML format. """ reporter_dir, reporter_filename = os.path.split(self._storage_analysis_file_path) # remove extension from filename yaml_prefix = os.path.splitext(reporter_filename)[0] output_filepath = os.path.join(reporter_dir, f"{yaml_prefix}_real_time_analysis.yaml") # Remove if it is a fresh reporter session if self._overwrite_statistics: try: os.remove(output_filepath) except FileNotFoundError: pass self._overwrite_statistics = False # Append from now on with open(output_filepath, "a") as out_file: out_file.write(yaml.dump([data], sort_keys=False)) # ------------------------------------------------------------------------- # Internal-usage. 
# ------------------------------------------------------------------------- def _write_1d_online_data(self, iteration, variable, data, storage): """Store data on disk given pre-calculated parameters""" if iteration is not None: variable = variable + "_history" if variable not in storage.variables: variable_parameters = self._determine_netcdf_variable_parameters(iteration, data, storage) logger.debug('Creating new NetCDF variable %s with parameters: %s' % (variable, variable_parameters)) # DEBUG storage.createVariable(variable, variable_parameters['dtype'], dimensions=variable_parameters['dims'], chunksizes=variable_parameters['chunksizes'], zlib=False) # Get the variable nc_var = storage[variable] # Only get the specific iteration if specified if iteration is not None: nc_var[iteration, :] = data else: nc_var[:] = data @staticmethod def _find_alternate_variable(iteration, variable, storage): """Helper function to figure out what went wrong when data not found""" iter_var = variable + "_history" if iteration is None and iter_var in storage.variables: return True elif iteration is not None and variable in storage.variables: return True return False @staticmethod def _read_1d_online_data(iteration, variable, storage): """Read data on disk given storage object Returns ------- data """ if iteration is not None: variable = variable + "_history" nc_var = storage[variable] nc_data = nc_var if iteration is not None: nc_data = nc_data[iteration] data = nc_data[:] if nc_var.dimensions[-1] == "scalar": return data[0] else: return data def _determine_netcdf_variable_parameters(self, iteration, data, storage): """ Pre-determine the variable information needed to create the variable on the storage layer """ if np.isscalar(data): # Scalar data size = 1 try: dtype = data.dtype # numpy except AttributeError: dtype = type(data) # python else: # Array data size = len(data) try: dtype = data.dtype # numpy except AttributeError: dtype = type(data[0]) # python data_dim = "dim_size{}".format(size) self._ensure_dimension_exists(data_dim, size, storage=storage) if iteration is not None: dims = ("iteration", data_dim) chunks = (1, size) else: dims = (data_dim,) chunks = (size,) return {'dtype': dtype, 'dims': dims, 'chunksizes': chunks} @staticmethod def _resolve_iteration_args(iteration_arg): """ Ensure iterations given as iterations are integer or None """ err_message = "Only an int or None is allowed for iteration" # Ensures int check if its a core int, np.int32, np.int64, or signed/unsigned variants if iteration_arg is not None and not np.issubdtype(type(iteration_arg), np.integer): raise ValueError(err_message) @staticmethod def _resolve_kwargs_exist(kwargs): """ Ensure keyword args exist (at least 1) """ if len(kwargs) == 0: raise TypeError("There must be at least 1 keyword arg!") def _resolve_nc_path(self, path, storage): """Return the NC group or variable at the end of the path. This can be used to retrieve groups or variables that are nested inside one or more groups. """ path_split = path.split('/') nc_group = self._storage_dict[storage] for group_name in path_split[:-1]: nc_group = nc_group.groups[group_name] # Check if this is a path to a group or a variable. try: return nc_group.groups[path_split[-1]] except KeyError: return nc_group.variables[path_split[-1]] def _calculate_checkpoint_iteration(self, iteration): """Compute the iteration on disk of the checkpoint file matching the iteration linked on the analysis iteration. 
Although this is a simple function, it provides a common function for calculation Returns either the integer index, or None if there is no matched index """ checkpoint_index, remainder = divmod(iteration, self._checkpoint_interval) if remainder == 0: # NetCDF variables can't be assigned using numpy integers. return int(checkpoint_index) return None def _map_iteration_to_good(self, iteration): """ Convert the iteration in 'read_X' functions which take a iteration=slice(None) to avoid returning a slice of data which goes past the last_iteration. This effectively sets the max iteration to the last_good_iteration. Parameters ---------- iteration : int or Slice Iteration to feed into the check Returns ------- cast_iteration : int or Slice of type iteration Iteration, converted as needed to only access certain ranges of data """ # Calculate last stored iteration last_good = self.read_last_iteration(last_checkpoint=False) # Create the artificial index map artificial_map = np.arange(last_good + 1, dtype=int) # Generate true index map from the input cast_iteration = artificial_map[iteration] return cast_iteration def _ensure_dimension_exists(self, dim_name, dim_size, storage=None): """ Ensure a dimension exists and is of the appropriate size, creating it if it does not already exist. A ``ValueError`` is raised if ``dim_size`` does not match the existing dimension size. Parameters ---------- dim_name : str The dimension name dim_size : int The dimension size storage : netCDF4.Dataset or netCDF4.Group, optional, default: None Storage layer to check the dimension against. If none, the _storage_analysis is used """ if storage is None: storage = self._storage_analysis if dim_name not in storage.dimensions: storage.createDimension(dim_name, dim_size) else: # Check dimension matches expected size dimension = storage.dimensions[dim_name] if dim_size == 0: if not dimension.isunlimited(): raise ValueError("NetCDF dimension {} already exists: was previously unlimited, but tried to " "redeclare it with size {}".format(dimension.name, dim_size)) else: if not int(dimension.size) == int(dim_size): raise ValueError("NetCDF dimension {} already exists: was previously size {}, but tried to " "redeclare it with dimension {}".format(dimension.name, dimension.size, dim_size)) def _ensure_group_exists_and_get(self, group_name, storage=None): """ Ensure a group exists and fetch it if it does, creating first if it does not. A ``ValueError`` is raised if ``dimsize`` does not match the existing dimension size. Parameters ---------- group_name : str The group name storage : known storage object or None, default None Storage object to check against, if None, assumes the analysis storage """ if storage is None: storage = self._storage_analysis if group_name not in storage.groups: storage.createGroup(group_name) return storage.groups[group_name] @staticmethod def _initialize_sampler_variables_on_file(dataset, n_atoms, n_replicas, is_periodic): """ Initialize the NetCDF variables on the storage file needed to store sampler states. Does nothing if file already initialized Parameters ---------- dataset : NetCDF4 Dataset Dataset to validate n_atoms : int Number of atoms which will be stored n_replicas : int Number of Sampler states which will be written is_periodic : bool True if system is periodic; False otherwise. """ if 'positions' not in dataset.variables: # Create dimensions. Replica dimension could have been created before. 
dataset.createDimension('atom', n_atoms) if 'replica' not in dataset.dimensions: dataset.createDimension('replica', n_replicas) # Define position variables. ncvar_positions = dataset.createVariable('positions', 'f4', ('iteration', 'replica', 'atom', 'spatial'), zlib=True, chunksizes=(1, n_replicas, n_atoms, 3)) ncvar_positions.units = 'nm' ncvar_positions.long_name = ("positions[iteration][replica][atom][spatial] is position of " "coordinate 'spatial' of atom 'atom' from replica 'replica' for " "iteration 'iteration'.") # Define velocities variables. ncvar_velocities = dataset.createVariable('velocities', 'f4', ('iteration', 'replica', 'atom', 'spatial'), zlib=True, chunksizes=(1, n_replicas, n_atoms, 3)) ncvar_velocities.units = 'nm / ps' ncvar_velocities.long_name = ("velocities[iteration][replica][atom][spatial] is velocity of " "coordinate 'spatial' of atom 'atom' from replica 'replica' for " "iteration 'iteration'.") # Define variables for periodic systems if is_periodic: ncvar_box_vectors = dataset.createVariable('box_vectors', 'f4', ('iteration', 'replica', 'spatial', 'spatial'), zlib=False, chunksizes=(1, n_replicas, 3, 3)) ncvar_volumes = dataset.createVariable('volumes', 'f8', ('iteration', 'replica'), zlib=False, chunksizes=(1, n_replicas)) ncvar_box_vectors.units = 'nm' ncvar_volumes.units = 'nm**3' ncvar_box_vectors.long_name = ("box_vectors[iteration][replica][i][j] is dimension j of box " "vector i for replica 'replica' from iteration " "'iteration-1'.") ncvar_volumes.long_name = ("volume[iteration][replica] is the box volume for replica " "'replica' from iteration 'iteration-1'.") def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: int, storage_file='checkpoint', obey_checkpoint_interval=True): """ Internal function to handle writing sampler states more generically to target file Parameters ---------- sampler_states : list of states.SamplerStates The sampler states to store for each replica. iteration : int The iteration at which to store the data. storage_file : string, Optional, Default: 'checkpoint' Name of storage file we're writing to. Must match a valid key of self._storage_dict obey_checkpoint_interval : bool, Optional, Default: False Tells this (attempted) write to obey the checkpoint interval or not. If True, no write out will be done if iteration is not on the checkpoint interval If False, the write WILL occur """ storage = self._storage_dict[storage_file] # Check if the schema must be initialized, do this regardless of the checkpoint_interval for consistency is_periodic = True if (sampler_states[0].box_vectors is not None) else False n_particles = sampler_states[0].n_particles n_replicas = len(sampler_states) self._initialize_sampler_variables_on_file(storage, n_particles, n_replicas, is_periodic) if obey_checkpoint_interval: write_iteration = self._calculate_checkpoint_iteration(iteration) else: write_iteration = iteration # Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval if write_iteration is not None: # Store sampler states. 

    def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: int,
                                            storage_file='checkpoint', obey_checkpoint_interval=True):
        """
        Internal function to handle writing sampler states generically to a target file.

        Parameters
        ----------
        sampler_states : list of states.SamplerStates
            The sampler states to store for each replica.
        iteration : int
            The iteration at which to store the data.
        storage_file : str, optional, default: 'checkpoint'
            Name of the storage file we're writing to. Must match a valid key of self._storage_dict.
        obey_checkpoint_interval : bool, optional, default: True
            Tells this (attempted) write to obey the checkpoint interval or not.
            If True, no write is done when the iteration is not on the checkpoint interval.
            If False, the write WILL occur.

        """
        storage = self._storage_dict[storage_file]
        # Check if the schema must be initialized; do this regardless of the checkpoint_interval for consistency.
        is_periodic = True if (sampler_states[0].box_vectors is not None) else False
        n_particles = sampler_states[0].n_particles
        n_replicas = len(sampler_states)
        self._initialize_sampler_variables_on_file(storage, n_particles, n_replicas, is_periodic)
        if obey_checkpoint_interval:
            write_iteration = self._calculate_checkpoint_iteration(iteration)
        else:
            write_iteration = iteration
        # Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval.
        if write_iteration is not None:
            # Store sampler states.
            # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments.
            positions = np.zeros([n_replicas, n_particles, 3])
            for replica_index, sampler_state in enumerate(sampler_states):
                # Store positions in memory first.
                x = sampler_state.positions / unit.nanometers
                positions[replica_index, :, :] = x[:, :]
            # Store positions.
            storage.variables['positions'][write_iteration, :, :, :] = positions

            # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments.
            velocities = np.zeros([n_replicas, n_particles, 3])
            for replica_index, sampler_state in enumerate(sampler_states):
                if sampler_state._unitless_velocities is not None:
                    # Store velocities in memory first.
                    x = sampler_state.velocities / (unit.nanometer / unit.picoseconds)  # _unitless_velocities
                    velocities[replica_index, :, :] = x[:, :]
            # Store velocities.
            # TODO: This stores velocities as zeros if no velocities are present in the sampler state, making the
            # restored sampler_state different from the original.
            if 'velocities' not in storage.variables:
                # Create the variable with the expected dimensions and shape.
                storage.createVariable('velocities', storage.variables['positions'].dtype,
                                       dimensions=storage.variables['positions'].dimensions)
            storage.variables['velocities'][write_iteration, :, :, :] = velocities

            if is_periodic:
                # Store box vectors and volume.
                # Allocate the whole write in memory first.
                box_vectors = np.zeros([n_replicas, 3, 3])
                volumes = np.zeros([n_replicas])
                for replica_index, sampler_state in enumerate(sampler_states):
                    box_vectors[replica_index, :, :] = sampler_state.box_vectors / unit.nanometers
                    volumes[replica_index] = sampler_state.volume / unit.nanometers ** 3
                storage.variables['box_vectors'][write_iteration, :, :, :] = box_vectors
                storage.variables['volumes'][write_iteration, :] = volumes

        else:
            logger.debug("Iteration {} not on the checkpoint interval of {}. "
                         "Sampler state not written.".format(iteration, self._checkpoint_interval))
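
    # Sketch of the unit handling used when writing and reading sampler states
    # (illustrative only): positions are stripped to plain float arrays in
    # nanometers on the way in, and re-wrapped as openmm unit.Quantity arrays
    # on the way out.
    #
    #     >>> x = np.ones([10, 3]) * unit.angstroms
    #     >>> raw = x / unit.nanometers                     # plain ndarray, values in nm
    #     >>> restored = unit.Quantity(raw, unit.nanometers)
    #     >>> np.allclose(restored / unit.angstroms, 1.0)
    #     True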

    def _read_sampler_states_from_given_file(self, iteration, storage_file='checkpoint',
                                             obey_checkpoint_interval=True):
        """
        Internal function to handle reading sampler states from a general storage file.

        Parameters
        ----------
        iteration : int
            Iteration at which to read data from storage.
        storage_file : str, optional, default: 'checkpoint'
            Name of the storage file we're reading from. Must match a valid key of self._storage_dict.
        obey_checkpoint_interval : bool, optional, default: True
            Tells this (attempted) read to obey the checkpoint interval or not.
            If True, None is returned when the iteration is not on the checkpoint interval.
            If False, the read will be attempted regardless.
            WARNING: If the storage file you specify was written on the checkpoint interval and you set
            obey_checkpoint_interval=False, you will get undefined behavior!

        Returns
        -------
        sampler_states : list of SamplerStates or None
            The previously stored sampler states for each replica.
            If the iteration is not on the checkpoint_interval and the file only writes on the
            checkpoint_interval, None is returned.

        """
        storage = self._storage_dict[storage_file]
        read_iteration = self._map_iteration_to_good(iteration)
        if obey_checkpoint_interval:
            read_iteration = self._calculate_checkpoint_iteration(iteration)
        if read_iteration is not None:
            # TODO: Restore n_replicas instead
            n_replicas = storage.dimensions['replica'].size

            sampler_states = list()
            for replica_index in range(n_replicas):
                # Restore positions.
                x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64)
                positions = unit.Quantity(x, unit.nanometers)
                # Restore velocities. The try/except block enables reading legacy/older serialized
                # objects from openmmtools < 0.21.3.
                try:
                    x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
                    velocities = unit.Quantity(x, unit.nanometer / unit.picoseconds)
                except KeyError:
                    # Velocities variable not found in the serialization (openmmtools <= 0.21.2);
                    # pass zeros as velocities when the variable is not found (< 0.21.3 behavior).
                    velocities = np.zeros_like(positions)
                if 'box_vectors' in storage.variables:
                    # Restore box vectors.
                    x = storage.variables['box_vectors'][read_iteration, replica_index, :, :].astype(np.float64)
                    box_vectors = unit.Quantity(x, unit.nanometers)
                else:
                    box_vectors = None

                # Create SamplerState.
                sampler_states.append(states.SamplerState(positions=positions, velocities=velocities,
                                                          box_vectors=box_vectors))

            return sampler_states
        else:
            return None

    def _write_dict(self, path, data, storage_name='analysis', fixed_dimension=False, nested=False):
        """Store the contents of a dict.

        Parameters
        ----------
        path : str
            The path to the dictionary in the storage file.
        data : dict
            The dict to store.
        storage_name : 'analysis' or 'checkpoint'
            The name of the storage file in which to save the dict.
        fixed_dimension : bool, default: False
            Use a fixed-length dimension instead of a variable-length one. This method seems to allow
            NetCDF to actually compress strings. A unique dimension named ``"fixedL{}".format(len(data))``
            will be created.
            Do NOT use this flag if you expect to constantly change the length of the data fed in.
            Use only for static data.
        nested : bool, default: False
            In nested representation, dictionaries are represented as groups, and values as strings.
            In this mode, it is possible to retrieve a keyword without reading the whole dictionary.
            In this mode, the dictionary can be overwritten, but ONLY if the structure of the dict
            doesn't change, as it's impossible to delete groups/variables from a netcdf database.

        """
        storage_nc = self._storage_dict[storage_name]

        # Save a nested dictionary into a group if requested, unless the dictionary is empty.
        if nested and isinstance(data, dict) and len(data) > 0:
            for key, value in data.items():
                if not isinstance(key, str):
                    raise ValueError('Cannot store dict in nested form with non-string keys.')
                self._write_dict(path + '/' + key, value, storage_name, fixed_dimension, nested)
            return

        # Serialize the dict to a YAML string (see _DictYamlDumper below).
        data_str = yaml.dump(data, Dumper=_DictYamlDumper)

        # Check if we are updating the dictionary or creating it.
        try:
            nc_variable = self._resolve_nc_path(path, storage_name)
        except KeyError:
            if fixed_dimension:
                variable_type = 'S1'
                dimension_name = "fixedL{}".format(len(data_str))
                # Create a new fixed dimension if necessary.
                if dimension_name not in storage_nc.dimensions:
                    storage_nc.createDimension(dimension_name, len(data_str))
            else:
                variable_type = str
                dimension_name = 'scalar'
            # Create variable.
            nc_variable = storage_nc.createVariable(path, variable_type, dimension_name, zlib=False)

        # Assign the value to the variable.
        if fixed_dimension:
            packed_data = np.array(list(data_str), dtype='S1')
        else:
            packed_data = np.empty(1, 'O')
            packed_data[0] = data_str
        nc_variable[:] = packed_data
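

# The sketch below is illustrative only and not part of the module API; the
# helper name is hypothetical. It mirrors the ``fixed_dimension`` branch of
# ``MultiStateReporter._write_dict`` above: the YAML dump of a dict is stored
# as an array of single characters ('S1') on a dimension named after the
# string length, which (per the docstring above) seems to let NetCDF actually
# compress strings.
def _example_write_fixed_length_string(dataset, path, data):
    """Store ``data`` (a dict) as a fixed-length 'S1' character array in ``dataset``."""
    data_str = yaml.dump(data)
    dimension_name = "fixedL{}".format(len(data_str))
    # Create the fixed-length dimension once per string length.
    if dimension_name not in dataset.dimensions:
        dataset.createDimension(dimension_name, len(data_str))
    # One 'S1' element per character of the YAML string.
    nc_variable = dataset.createVariable(path, 'S1', dimension_name, zlib=False)
    nc_variable[:] = np.array(list(data_str), dtype='S1')
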
# ==============================================================================
# MODULE INTERNAL USAGE UTILITIES
# ==============================================================================

class _DictYamlLoader(yaml.CLoader):
    """PyYAML Loader that reads !Quantity and !ndarray tags."""

    def __init__(self, *args, **kwargs):
        super(_DictYamlLoader, self).__init__(*args, **kwargs)
        self.add_constructor(u'!Quantity', self.quantity_constructor)
        self.add_constructor(u'!ndarray', self.ndarray_constructor)

    @staticmethod
    def quantity_constructor(loader, node):
        loaded_mapping = loader.construct_mapping(node)
        data_unit = quantity_from_string(loaded_mapping['unit'])
        data_value = loaded_mapping['value']
        return data_value * data_unit

    @staticmethod
    def ndarray_constructor(loader, node):
        loaded_mapping = loader.construct_mapping(node, deep=True)
        data_type = np.dtype(loaded_mapping['type'])
        data_shape = loaded_mapping['shape']
        data_values = loaded_mapping['values']
        data = np.ndarray(shape=data_shape, dtype=data_type)
        if 0 not in data_shape:
            data[:] = data_values
        return data


class _DictYamlDumper(yaml.CDumper):
    """PyYAML Dumper that handles openmm Quantities and numpy arrays through !Quantity and !ndarray tags."""

    def __init__(self, *args, **kwargs):
        super(_DictYamlDumper, self).__init__(*args, **kwargs)
        self.add_representer(unit.Quantity, self.quantity_representer)
        self.add_representer(np.ndarray, self.ndarray_representer)

    @staticmethod
    def quantity_representer(dumper, data):
        data_unit = data.unit
        data_value = data / data_unit
        data_dump = dict(unit=str(data_unit), value=data_value)
        return dumper.represent_mapping(u'!Quantity', data_dump)

    @staticmethod
    def ndarray_representer(dumper, data):
        """Convert a numpy array to native Python types."""
        data_type = str(data.dtype)
        data_shape = data.shape
        data_values = data.tolist()
        data_dump = dict(type=data_type, shape=data_shape, values=data_values)
        return dumper.represent_mapping(u'!ndarray', data_dump)
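

# Illustrative round trip through the YAML helpers above: a minimal sketch, not
# part of the module API, assuming PyYAML was built with LibYAML (both classes
# derive from yaml.CLoader / yaml.CDumper). The example values are arbitrary.
if __name__ == '__main__':
    # Dump a dict containing an openmm Quantity and a numpy array to a YAML string
    # using the custom !Quantity and !ndarray tags ...
    example_data = {
        'temperature': 300.0 * unit.kelvin,
        'weights': np.array([0.25, 0.75]),
    }
    data_str = yaml.dump(example_data, Dumper=_DictYamlDumper)
    # ... and restore it, recovering the original types via the tag constructors.
    restored = yaml.load(data_str, Loader=_DictYamlLoader)
    assert restored['temperature'] == 300.0 * unit.kelvin
    assert np.allclose(restored['weights'], [0.25, 0.75])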