Source code for openmmtools.cache

#!/usr/bin/env python

# =============================================================================
# =============================================================================

"""Provide cache classes to handle creation of OpenMM Context objects."""


# =============================================================================
# =============================================================================

import re
import copy
import collections

    import openmm
    from openmm import unit
except ImportError:  # OpenMM < 7.6
    from simtk import openmm, unit

from openmmtools import integrators

# =============================================================================
# =============================================================================

class LRUCache(object):
    """A simple LRU cache with a dictionary-like interface that supports maximum capacity and expiration.

    It can be configured to have a maximum number of elements (capacity)
    and an element expiration (time_to_live) measured in number of accesses
    to the cache. Both read and write operations count as an access, but
    only successful reads (i.e. those not raising KeyError) increment the
    counter.

    Parameters
    ----------
    capacity : int, optional
        Maximum number of elements in the cache. When set to None, the
        cache has infinite capacity (default is None).
    time_to_live : int, optional
        If an element is not accessed after time_to_live read/write
        operations, the element is removed. When set to None, elements
        do not have an expiration (default is None).

    Attributes
    ----------
    capacity
    time_to_live

    Examples
    --------
    >>> cache = LRUCache(capacity=2, time_to_live=3)

    When the capacity is exceeded, the least recently used element is removed.

    >>> cache['1'] = 1
    >>> cache['2'] = 2
    >>> elem = cache['1']  # read '1', now '2' is the least recently used
    >>> cache['3'] = 3
    >>> len(cache)
    2
    >>> '2' in cache
    False

    After time_to_live read/write operations an element is deleted if it
    is not used.

    >>> elem = cache['3']  # '3' is used, counter is reset
    >>> elem = cache['1']  # access 1
    >>> elem = cache['1']  # access 2
    >>> elem = cache['1']  # access 3
    >>> len(cache)
    1
    >>> '3' in cache
    False

    """

    def __init__(self, capacity=None, time_to_live=None):
        # Entries are kept ordered from least to most recently used.
        self._data = collections.OrderedDict()
        self._capacity = capacity
        self._ttl = time_to_live
        # Running count of accesses, used as the clock for expirations.
        self._n_access = 0

    @property
    def capacity(self):
        """Maximum number of elements that can be cached.

        If None, the capacity is unlimited.

        """
        return self._capacity

    @capacity.setter
    def capacity(self, new_capacity):
        # None means unlimited capacity, so there is nothing to evict.
        # (Comparing an int with None would raise TypeError in Python 3.)
        if new_capacity is not None:
            # Remove excess elements, least recently used first.
            while len(self._data) > new_capacity:
                self._data.popitem(last=False)
        self._capacity = new_capacity

    @property
    def time_to_live(self):
        """Number of read/write operations before a cached element expires.

        If None, elements have no expiration.

        """
        return self._ttl

    @time_to_live.setter
    def time_to_live(self, new_time_to_live):
        # Update entries only if we are changing the ttl.
        if new_time_to_live == self._ttl:
            return

        # Update expiration of cache entries.
        for entry in self._data.values():
            if self._ttl is None:
                # If there was no time to live before, just let entries
                # expire in new_time_to_live accesses.
                entry.expiration = self._n_access + new_time_to_live
            elif new_time_to_live is None:
                # If we don't want expiration anymore, delete the field.
                # This way we save memory in case there are a lot of entries.
                del entry.expiration
            else:
                # Otherwise just add/subtract the difference.
                entry.expiration += new_time_to_live - self._ttl

        # Purge cache only if there is a time to live.
        if new_time_to_live is not None:
            self._remove_expired()
        self._ttl = new_time_to_live

    def empty(self):
        """Purge the cache."""
        self._data = collections.OrderedDict()

    def __getitem__(self, key):
        # When we access data, push element at the
        # end to make it the most recently used.
        # A missing key raises KeyError here, before the counter changes.
        entry = self._data.pop(key)
        self._data[key] = entry

        # We increment the number of accesses only on successful reads.
        self._n_access += 1

        # Update expiration and cleanup expired values.
        if self._ttl is not None:
            entry.expiration = self._n_access + self._ttl
            self._remove_expired()
        return entry.value

    def __setitem__(self, key, value):
        self._n_access += 1

        # When we access data, push element at the
        # end to make it the most recently used.
        try:
            self._data.pop(key)
        except KeyError:
            # Remove first item if we hit maximum capacity.
            if self._capacity is not None and len(self._data) >= self._capacity:
                self._data.popitem(last=False)

        # Determine expiration and clean up expired.
        if self._ttl is None:
            ttl = None
        else:
            ttl = self._ttl + self._n_access
            self._remove_expired()
        self._data[key] = _CacheEntry(value, ttl)

    def __len__(self):
        return len(self._data)

    def __contains__(self, item):
        return item in self._data

    def __iter__(self):
        return iter(self._data)

    def _remove_expired(self):
        """Remove all expired cache entries.

        Assumes that entries were created with an expiration attribute.

        """
        keys_to_remove = []
        for key, entry in self._data.items():
            if entry.expiration <= self._n_access:
                keys_to_remove.append(key)
            else:
                # Later entries have been accessed later
                # and they surely haven't expired yet.
                break
        # Delete after the scan: the dict cannot be mutated while iterating.
        for key in keys_to_remove:
            del self._data[key]
# ============================================================================= # GENERAL CONTEXT CACHE # =============================================================================
class ContextCache(object):
    """LRU cache hosting the minimum amount of incompatible Contexts.

    Two Contexts are compatible if they are in a compatible ThermodynamicState,
    and have compatible integrators. In general, two integrators are compatible
    if they have the same serialized state, but ContextCache can decide to store
    a single Context to optimize memory when two integrators differ by only few
    parameters that can be set after the Context is initialized. These parameters
    include all the global variables defined by a ``CustomIntegrator``.

    You can force ``ContextCache`` to consider an integrator global variable
    incompatible by adding it to the blacklist
    ``ContextCache.INCOMPATIBLE_INTEGRATOR_ATTRIBUTES``. Similarly, you can add
    other attributes that should be considered compatible through the whitelist
    ``ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES``. If an attribute in that
    dictionary is not found in the integrator, the cache will search for a
    corresponding getter and setter.

    Parameters
    ----------
    platform : openmm.Platform, optional
        The OpenMM platform to use to create Contexts. If None, OpenMM
        tries to select the fastest one available (default is None).
    platform_properties : dict, optional
        A dictionary of platform properties for the OpenMM platform.
        Only valid if the platform is not None (default is None).
    **kwargs
        Parameters to pass to the underlying LRUCache constructor such
        as capacity and time_to_live.

    Attributes
    ----------
    platform
    capacity
    time_to_live

    Warnings
    --------
    Python instance attributes are not copied when ``ContextCache.get_context()``
    is called. You can force this by adding them to the whitelist
    ``ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES``, but if modifying your
    Python attributes won't modify the OpenMM serialization, this will likely
    cause problems so this is discouraged unless you know exactly what you
    are doing.

    Examples
    --------
    >>> from openmm import unit
    >>> from openmmtools import testsystems
    >>> from openmmtools.states import ThermodynamicState
    >>> alanine = testsystems.AlanineDipeptideExplicit()
    >>> thermodynamic_state = ThermodynamicState(alanine.system, 310*unit.kelvin)
    >>> time_step = 1.0*unit.femtosecond

    Two compatible thermodynamic states generate only a single cached Context.
    ContextCache can also (in few explicitly supported cases) recycle the same
    Context even if the integrators differ by some parameters.

    >>> context_cache = ContextCache()
    >>> context1, integrator1 = context_cache.get_context(thermodynamic_state,
    ...                                                   openmm.VerletIntegrator(time_step))
    >>> thermodynamic_state.temperature = 300*unit.kelvin
    >>> time_step2 = 2.0*unit.femtosecond
    >>> context2, integrator2 = context_cache.get_context(thermodynamic_state,
    ...                                                   openmm.VerletIntegrator(time_step2))
    >>> id(context1) == id(context2)
    True
    >>> len(context_cache)
    1

    When we switch to NPT the states are not compatible and so neither the
    Contexts are.

    >>> integrator2 = openmm.VerletIntegrator(2.0*unit.femtosecond)
    >>> thermodynamic_state_npt = copy.deepcopy(thermodynamic_state)
    >>> thermodynamic_state_npt.pressure = 1.0*unit.atmosphere
    >>> context3, integrator3 = context_cache.get_context(thermodynamic_state_npt,
    ...                                                   openmm.VerletIntegrator(time_step))
    >>> id(context1) == id(context3)
    False
    >>> len(context_cache)
    2

    You can set a capacity and a time to live for contexts like in a normal
    LRUCache.

    >>> context_cache = ContextCache(capacity=1, time_to_live=5)
    >>> context2, integrator2 = context_cache.get_context(thermodynamic_state,
    ...                                                   openmm.VerletIntegrator(time_step))
    >>> context3, integrator3 = context_cache.get_context(thermodynamic_state_npt,
    ...                                                   openmm.VerletIntegrator(time_step))
    >>> len(context_cache)
    1

    See Also
    --------
    LRUCache
    states.ThermodynamicState.is_state_compatible

    """

    def __init__(self, platform=None, platform_properties=None, **kwargs):
        self._validate_platform_properties(platform, platform_properties)
        self._platform = platform
        self._platform_properties = platform_properties
        self._lru = LRUCache(**kwargs)

    def __len__(self):
        return len(self._lru)

    @property
    def platform(self):
        """The OpenMM platform to use to create Contexts.

        If None, OpenMM tries to select the fastest one available. This
        can be set only if the cache is empty.

        """
        return self._platform

    @platform.setter
    def platform(self, new_platform):
        if len(self._lru) > 0:
            raise RuntimeError('Cannot change platform of a non-empty ContextCache')
        # Platform properties are meaningless without a platform.
        if new_platform is None:
            self._platform_properties = None
        self._validate_platform_properties(new_platform, self._platform_properties)
        self._platform = new_platform

    def set_platform(self, new_platform, platform_properties=None):
        """Set the platform and its properties together.

        Parameters
        ----------
        new_platform : openmm.Platform or None
            The OpenMM platform to use to create Contexts.
        platform_properties : dict, optional
            The platform properties to use with ``new_platform``
            (default is None).

        Raises
        ------
        RuntimeError
            If the cache is not empty.
        ValueError
            If the properties are rejected by ``_validate_platform_properties``.

        """
        if len(self._lru) > 0:
            raise RuntimeError('Cannot change platform of a non-empty ContextCache')
        self._validate_platform_properties(new_platform, platform_properties)
        self._platform = new_platform
        self._platform_properties = platform_properties

    @property
    def capacity(self):
        """The maximum number of Context cached.

        If None, the capacity is unlimited.

        """
        return self._lru.capacity

    @capacity.setter
    def capacity(self, new_capacity):
        self._lru.capacity = new_capacity

    @property
    def time_to_live(self):
        """The Contexts expiration date in number of accesses to the LRUCache.

        If None, Contexts do not expire.

        """
        return self._lru.time_to_live

    @time_to_live.setter
    def time_to_live(self, new_time_to_live):
        self._lru.time_to_live = new_time_to_live

    def empty(self):
        """Clear up cache and remove all Contexts."""
        self._lru.empty()

    def get_context(self, thermodynamic_state, integrator=None):
        """Return a context in the given thermodynamic state.

        In general, the Context must be considered newly initialized. This
        means that positions and velocities must be set afterwards.

        If the integrator is not provided, this will search the cache for
        any Context in the given ThermodynamicState, regardless of its
        integrator. In this case, the method guarantees that two consecutive
        calls with the same thermodynamic state will retrieve the same
        context.

        This creates a new Context if no compatible one has been cached.
        If a compatible Context exists, the ThermodynamicState is applied
        to it, and the Context integrator state is changed to match the
        one passed as an argument. As a consequence, the returned integrator
        is guaranteed to be in the same state as the one provided, but it
        can be a different instance. This is to minimize the number of
        Contexts objects cached that use the same or very similar integrator.

        Parameters
        ----------
        thermodynamic_state : states.ThermodynamicState
            The thermodynamic state of the system.
        integrator : openmm.Integrator, optional
            The integrator for the context (default is None).

        Returns
        -------
        context : openmm.Context
            The context in the given thermodynamic system.
        context_integrator : openmm.Integrator
            The integrator to be used to propagate the Context. Can be
            a different instance from the one passed as an argument.

        Warnings
        --------
        Python instance attributes are not copied when ``get_context()``
        is called. You can force this by adding them to the whitelist
        ``ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES``, but if modifying
        the attributes won't modify the OpenMM serialization, this will
        likely cause problems so this is discouraged unless you know exactly
        what you're doing.

        """
        context = None

        # If the user does not require a specific integrator, look for
        # any cached Context in a matching thermodynamic state.
        if integrator is None:
            thermodynamic_state_id = self._generate_state_id(thermodynamic_state)
            matching_context_ids = [context_id for context_id in self._lru
                                    if context_id[0] == thermodynamic_state_id]
            if len(matching_context_ids) == 0:
                # We have to create a new Context.
                integrator = self._get_default_integrator(thermodynamic_state.temperature)
            elif len(matching_context_ids) == 1:
                # Only one match.
                context = self._lru[matching_context_ids[0]]
            else:
                # Multiple matches, prefer the non-default Integrator.
                # The cache iterates from least to most recently used, so
                # scanning in reverse picks the most recently used match and
                # two consecutive calls retrieve the same integrator.
                default_integrator_id = self._default_integrator_id()
                for context_id in reversed(matching_context_ids):
                    if context_id[1] != default_integrator_id:
                        context = self._lru[context_id]
                        break

        if context is None:
            # Determine the Context id matching the pair state-integrator.
            context_id = self._generate_context_id(thermodynamic_state, integrator)

            # Search for previously cached compatible Contexts or create new one.
            try:
                context = self._lru[context_id]
            except KeyError:
                context = thermodynamic_state.create_context(integrator, self._platform,
                                                             self._platform_properties)
                self._lru[context_id] = context
        context_integrator = context.getIntegrator()

        # Update state of system and integrator of the cached context.
        # We don't have to copy the state of the integrator if the user
        # didn't ask for a specific one.
        if integrator is not None:
            self._copy_integrator_state(integrator, context_integrator)
        thermodynamic_state.apply_to_context(context)
        return context, context_integrator

    def __getstate__(self):
        # this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
        # Only the configuration is serialized; cached Contexts are not.
        if self.platform is not None:
            platform_serialization = self.platform.getName()
        else:
            platform_serialization = None
        return dict(platform=platform_serialization, capacity=self.capacity,
                    time_to_live=self.time_to_live,
                    platform_properties=self._platform_properties)

    def __setstate__(self, serialization):
        # this serialization format was introduced in openmmtools > 0.18.3 (pull request #437)
        if serialization['platform'] is None:
            self._platform = None
        else:
            self._platform = openmm.Platform.getPlatformByName(serialization['platform'])
        # Serializations that lack the key are restored without properties.
        self._platform_properties = serialization.get("platform_properties")
        self._lru = LRUCache(serialization['capacity'], serialization['time_to_live'])

    def __eq__(self, other):
        """Two ContextCache objects are equal if they have the same values in their public attributes."""
        # Check types are compatible
        if not isinstance(other, ContextCache):
            return False
        # Check all inner public attributes have the same values (excluding methods)
        my_inner_attrs = [attr_ for attr_ in dir(self)
                          if not attr_.startswith('_') and not callable(getattr(self, attr_))]
        return all(getattr(self, attr) == getattr(other, attr) for attr in my_inner_attrs)

    # Instances are mutable and define __eq__, so they are deliberately
    # unhashable. Python 3 already implies this; it is spelled out here.
    __hash__ = None

    # -------------------------------------------------------------------------
    # Internal usage
    # -------------------------------------------------------------------------

    # Each element is the name of the integrator attribute used before
    # get/set, and its standard value used to check for compatibility.
    COMPATIBLE_INTEGRATOR_ATTRIBUTES = {
        'StepSize': 0.001,
        'ConstraintTolerance': 1e-05,
        'Temperature': 273,
        'Friction': 5,
        'RandomNumberSeed': 0,
    }

    INCOMPATIBLE_INTEGRATOR_ATTRIBUTES = {
        '_restorable__class_hash',
    }

    @classmethod
    def _check_integrator_compatibility_configuration(cls):
        """Verify that the user didn't specify the same attributes as both compatible and incompatible."""
        shared_attributes = set(cls.COMPATIBLE_INTEGRATOR_ATTRIBUTES)
        shared_attributes = shared_attributes.intersection(cls.INCOMPATIBLE_INTEGRATOR_ATTRIBUTES)
        if len(shared_attributes) != 0:
            raise RuntimeError('These integrator attributes have been specified both as '
                               'compatible and incompatible: {}'.format(shared_attributes))

    @classmethod
    def _set_integrator_compatible_variables(cls, integrator, reference_value):
        """Set all the global variables to the specified reference.

        If the argument reference_value is another integrator, the global
        variables will be copied. If integrator is not a CustomIntegrator,
        the function has no effect.

        The function doesn't copy the global variables that are included
        in the blacklist INCOMPATIBLE_INTEGRATOR_ATTRIBUTES.

        """
        # Check if the integrator has no global variables.
        try:
            n_global_variables = integrator.getNumGlobalVariables()
        except AttributeError:
            return

        # Check if we'll have to copy the values from a reference integrator.
        is_reference_integrator = isinstance(reference_value, integrator.__class__)

        for global_variable_idx in range(n_global_variables):
            # Do not set variables that should be incompatible.
            global_variable_name = integrator.getGlobalVariableName(global_variable_idx)
            if global_variable_name in cls.INCOMPATIBLE_INTEGRATOR_ATTRIBUTES:
                continue

            # Either copy the value from the reference integrator or just set it.
            if is_reference_integrator:
                value = reference_value.getGlobalVariable(global_variable_idx)
            else:
                value = reference_value
            integrator.setGlobalVariable(global_variable_idx, value)

    @classmethod
    def _copy_integrator_state(cls, copied_integrator, integrator):
        """Copy the compatible parameters of copied_integrator to integrator.

        Simply using __getstate__ and __setstate__ doesn't work because
        __setstate__ set also the bound Context. We can assume the two
        integrators are of the same class since get_context() found that
        they match the hash.

        """
        # Check that there are no contrasting settings for the attribute compatibility.
        cls._check_integrator_compatibility_configuration()

        # Restore temperature getter/setter before copying attributes.
        integrators.ThermostatedIntegrator.restore_interface(integrator)
        integrators.ThermostatedIntegrator.restore_interface(copied_integrator)
        assert integrator.__class__ == copied_integrator.__class__

        # Copy all compatible global variables.
        cls._set_integrator_compatible_variables(integrator, copied_integrator)

        # Copy other compatible attributes through getters/setters.
        for attribute in cls.COMPATIBLE_INTEGRATOR_ATTRIBUTES:
            try:  # getter/setter
                value = getattr(copied_integrator, 'get' + attribute)()
            except AttributeError:
                pass
            else:  # getter/setter
                getattr(integrator, 'set' + attribute)(value)

    @classmethod
    def _standardize_integrator(cls, integrator):
        """Return a standard copy of the integrator.

        This is used to determine if the same context can be used with
        different integrators that differ by only compatible parameters.

        """
        # Check that there are no contrasting settings for the attribute compatibility.
        cls._check_integrator_compatibility_configuration()

        standard_integrator = copy.deepcopy(integrator)
        integrators.ThermostatedIntegrator.restore_interface(standard_integrator)

        # Set all compatible global variables to 0, except those in the blacklist.
        cls._set_integrator_compatible_variables(standard_integrator, 0.0)

        # Copy other compatible attributes through getters/setters overwriting
        # eventual global variables with a different standard value.
        for attribute, std_value in cls.COMPATIBLE_INTEGRATOR_ATTRIBUTES.items():
            try:  # setter
                getattr(standard_integrator, 'set' + attribute)(std_value)
            except AttributeError:
                # Try to set CustomIntegrator global variable
                try:
                    standard_integrator.setGlobalVariableByName(attribute, std_value)
                except Exception:
                    # Deliberate best-effort: the integrator simply does
                    # not expose this attribute in any form.
                    pass
        return standard_integrator

    @staticmethod
    def _generate_state_id(thermodynamic_state):
        """Return a unique key for the ThermodynamicState."""
        # We take advantage of the cached _standard_system_hash property
        # to generate a compatible hash for the thermodynamic state.
        return thermodynamic_state._standard_system_hash

    @classmethod
    def _generate_integrator_id(cls, integrator):
        """Return a unique key for the given Integrator."""
        standard_integrator = cls._standardize_integrator(integrator)
        xml_serialization = openmm.XmlSerializer.serialize(standard_integrator)
        # Ignore per-DOF variables for the purpose of hashing.
        if isinstance(integrator, openmm.CustomIntegrator):
            tag_iter = re.finditer(r'PerDofVariables>', xml_serialization)
            try:
                # Step back one character to include the opening '<'.
                open_tag_index = next(tag_iter).start() - 1
            except StopIteration:  # No DOF variables.
                pass
            else:
                close_tag_index = next(tag_iter).end() + 1
                xml_serialization = (xml_serialization[:open_tag_index]
                                     + xml_serialization[close_tag_index:])
        return hash(xml_serialization)

    @classmethod
    def _generate_context_id(cls, thermodynamic_state, integrator):
        """Return a unique key for a context in the given state.

        We return a tuple containing the ThermodynamicState hash and
        the hash of the serialization of the Integrator. Keeping the two
        separated makes it possible to search for Contexts in a given state
        regardless of the integrator.

        """
        state_id = cls._generate_state_id(thermodynamic_state)
        integrator_id = cls._generate_integrator_id(integrator)
        return state_id, integrator_id

    @staticmethod
    def _get_default_integrator(temperature):
        """Return a new instance of the default integrator."""
        # Use a likely-to-be-used Integrator.
        return integrators.GeodesicBAOABIntegrator(temperature=temperature)

    @classmethod
    def _default_integrator_id(cls):
        """Return the unique key of the default integrator."""
        # Computed once and memoized on the class.
        if cls._cached_default_integrator_id is None:
            default_integrator = cls._get_default_integrator(300*unit.kelvin)
            default_integrator_id = cls._generate_integrator_id(default_integrator)
            cls._cached_default_integrator_id = default_integrator_id
        return cls._cached_default_integrator_id
    _cached_default_integrator_id = None

    @staticmethod
    def _validate_platform_properties(platform=None, platform_properties=None):
        """Check if platform properties are valid for the platform; else raise ValueError.

        Parameters
        ----------
        platform : openmm.Platform, optional
            The platform the properties are meant for.
        platform_properties : dict, optional
            Platform properties to check. If None, nothing is validated.

        Returns
        -------
        bool
            True when validation passes.

        Raises
        ------
        ValueError
            If properties are given without a platform, are not a dict,
            contain non-string values, or are rejected by OpenMM with an
            "Illegal property name" error.

        """
        if platform_properties is None:
            return True
        if platform is None:
            raise ValueError("To set platform_properties, you need to also specify the platform.")
        if not isinstance(platform_properties, dict):
            raise ValueError("platform_properties must be a dictionary")
        for key, value in platform_properties.items():
            if not isinstance(value, str):
                raise ValueError(
                    "All platform properties must be strings. You supplied {}: {} of type {}".format(
                        key, value, type(value)
                    )
                )
        # Create a throwaway Context to check if all properties are
        # accepted by the platform.
        dummy_system = openmm.System()
        dummy_system.addParticle(1)
        dummy_integrator = openmm.VerletIntegrator(1.0*unit.femtoseconds)
        try:
            openmm.Context(dummy_system, dummy_integrator, platform, platform_properties)
        except Exception as e:
            if "Illegal property name" in str(e):
                raise ValueError("Invalid platform property for this platform. {}".format(e)) from e
            # Any other failure is not about the properties: re-raise
            # it unchanged, keeping the original traceback.
            raise
        return True
# =============================================================================
# DUMMY CONTEXT CACHE
# =============================================================================

class DummyContextCache(object):
    """A dummy ContextCache which always create a new Context.

    Parameters
    ----------
    platform : openmm.Platform, optional
        The OpenMM platform to use. If None, OpenMM tries to select
        the fastest one available (default is None).

    Attributes
    ----------
    platform : openmm.Platform
        The OpenMM platform to use. If None, OpenMM tries to select
        the fastest one available.

    Examples
    --------
    Create a new ``Context`` object for alanine dipeptide in vacuum in NPT.

    >>> import openmm
    >>> from openmm import unit
    >>> from openmmtools import states, testsystems
    >>> system = testsystems.AlanineDipeptideVacuum().system
    >>> thermo_state = states.ThermodynamicState(system, temperature=300*unit.kelvin)

    >>> context_cache = DummyContextCache()
    >>> integrator = openmm.VerletIntegrator(1.0*unit.femtosecond)
    >>> context, context_integrator = context_cache.get_context(thermo_state, integrator)

    Or create a ``Context`` with an arbitrary integrator (when you only
    need to compute energies, for example).

    >>> context, context_integrator = context_cache.get_context(thermo_state)

    """

    def __init__(self, platform=None):
        self.platform = platform

    def get_context(self, thermodynamic_state, integrator=None):
        """Build a brand-new context for the given thermodynamic state.

        Nothing is cached: every call creates a fresh ``Context``.

        Parameters
        ----------
        thermodynamic_state : states.ThermodynamicState
            The thermodynamic state of the system.
        integrator : openmm.Integrator, optional
            The integrator to bind to the new context. If ``None``, an
            arbitrary integrator is used. Currently, this is a
            ``LangevinIntegrator`` with "V R O R V" splitting, but this
            might change in the future. Default is ``None``.

        Returns
        -------
        context : openmm.Context
            The new context in the given thermodynamic system.
        context_integrator : openmm.Integrator
            The integrator bound to the context that can be used for
            propagation. This is identical to the ``integrator`` argument
            if it was passed.

        """
        bound_integrator = integrator
        if bound_integrator is None:
            # No integrator requested: fall back to an arbitrary one
            # at the temperature of the thermodynamic state.
            bound_integrator = integrators.LangevinIntegrator(
                timestep=1.0*unit.femtoseconds,
                splitting="V R O R V",
                temperature=thermodynamic_state.temperature
            )
        new_context = thermodynamic_state.create_context(bound_integrator, self.platform)
        return new_context, bound_integrator

    def __getstate__(self):
        # The platform is serialized by name (or None when unset).
        platform_name = None if self.platform is None else self.platform.getName()
        return {'platform': platform_name}

    def __setstate__(self, serialization):
        platform_name = serialization['platform']
        if platform_name is not None:
            self.platform = openmm.Platform.getPlatformByName(platform_name)
        else:
            self.platform = None


# =============================================================================
# GLOBAL CONTEXT CACHE
# =============================================================================

global_context_cache = ContextCache(capacity=None, time_to_live=None)
"""A shared ContextCache that minimizes Context object creating when using MCMCMove."""


# =============================================================================
# CACHE ENTRY (MODULE INTERNAL USAGE)
# =============================================================================

class _CacheEntry(object):
    """A cached value with an optional ``expiration`` attribute."""

    def __init__(self, value, expiration=None):
        self.value = value

        # The expiration attribute exists only when one was requested;
        # skipping it saves memory when the cache holds many entries.
        if expiration is None:
            return
        self.expiration = expiration


if __name__ == '__main__':
    import doctest
    doctest.testmod()