# Source code for mdsuite.database.data_manager

"""
MDSuite: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the data manager. The data manager handles loading of data as TensorFlow
generators. These generators allow for the full use of the TF data pipelines but can
require special formatting rules.
"""
import logging

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from mdsuite.database.simulation_database import Database

log = logging.getLogger(__name__)


class DataManager:
    """
    Class for the MDS tensor_values fetcher.

    Due to the amount of tensor_values that needs to be collected and the
    possibility to optimize repeated loading, a separate tensor_values fetching
    class is required. This class manages how tensor_values is loaded from the
    MDS database_path and optimizes processes such as pre-loading and parallel
    reading.

    The ``*_generator`` methods do not load data themselves; each returns a
    ``(generator_function, args)`` pair to be consumed by the TensorFlow data
    pipeline.
    """

    def __init__(
        self,
        database: Database = None,
        data_path: list = None,
        data_range: int = None,
        n_batches: int = None,
        batch_size: int = None,
        ensemble_loop: int = None,
        correlation_time: int = 1,
        remainder: int = None,
        atom_selection=np.s_[:],
        minibatch: bool = False,
        atom_batch_size: int = None,
        n_atom_batches: int = None,
        atom_remainder: int = None,
        offset: int = 0,
    ):
        """
        Constructor for the DataManager class.

        Parameters
        ----------
        database : Database
            Database object from which tensor_values should be loaded.
        data_path : list
            Path in the HDF5 database to be loaded.
        data_range : int
            Data range used in the calculator.
        n_batches : int
            Number of (full) batches required.
        batch_size : int
            Size of a batch.
        ensemble_loop : int
            Number of ensembles to be looped over.
        correlation_time : int
            Correlation time used in the calculator.
        remainder : int
            Remainder used in the batching, i.e. the size of the trailing
            partial batch. ``None`` or ``0`` means there is no such batch.
        atom_selection : slice or dict
            Selection of atoms in the calculation. Either a single index
            expression (default ``np.s_[:]``, all atoms) or a dict mapping
            each key to its own index expression (keys presumably match the
            loaded data paths -- confirm against ``Database.load_data``).
        minibatch : bool
            If true, atom-wise batching is required.
        atom_batch_size : int
            Size of an atom-wise batch.
        n_atom_batches : int
            Number of (full) atom-wise batches.
        atom_remainder : int
            Atom-wise remainder used in the atom-wise batching.
        offset : int
            Offset in the data loading if it should not be loaded from the
            start.
        """
        self.database = database
        self.data_path = data_path
        self.minibatch = minibatch
        self.atom_batch_size = atom_batch_size
        self.n_atom_batches = n_atom_batches
        self.atom_remainder = atom_remainder
        self.offset = offset
        self.data_range = data_range
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.remainder = remainder
        self.ensemble_loop = ensemble_loop
        self.correlation_time = correlation_time
        self.atom_selection = atom_selection

    def _batch_window(self, batch: int, batch_number: int, batch_size: int) -> tuple:
        """
        Compute the configuration window loaded by one time batch.

        Parameters
        ----------
        batch : int
            Index of the current batch.
        batch_number : int
            Number of full batches; ``batch == batch_number`` marks the
            trailing remainder batch.
        batch_size : int
            Number of configurations in a full batch.

        Returns
        -------
        tuple
            ``(start, stop, data_size)``: the window bounds (shifted by
            ``self.offset``) and the window length as an int32 tensor.
        """
        start = int(batch * batch_size) + self.offset
        # The trailing remainder batch is shorter than a full batch.
        size = self.remainder if batch == batch_number else batch_size
        # Use int32 for every batch. The remainder batch used to be cast to
        # int16, which did not match the regular batches and overflows above
        # 32767.
        return start, int(start + size), tf.cast(size, dtype=tf.int32)

    def _atom_slice(self, time_index):
        """
        Combine the atom selection with an index along the time axis.

        Parameters
        ----------
        time_index : slice or array-like
            Index expression for the configuration (second) axis.

        Returns
        -------
        tuple or dict
            ``np.s_[atom_selection, time_index]``, or a dict of such index
            tuples (one per key) when ``self.atom_selection`` is a dict.
        """
        if isinstance(self.atom_selection, dict):
            return {
                item: np.s_[selection, time_index]
                for item, selection in self.atom_selection.items()
            }
        return np.s_[self.atom_selection, time_index]

    def batch_generator(
        self,
        dictionary: bool = False,
        system: bool = False,
        remainder: bool = False,
        loop_array: np.ndarray = None,
    ) -> tuple:
        """
        Build a generator object for the batch loop.

        Parameters
        ----------
        dictionary : bool
            If true return a dict. This is default now and could be removed.
        system : bool
            If true, a system parameter is being called for: only the time
            axis is sliced and the atom selection is not applied.
        remainder : bool
            If true, a remainder batch must be computed. Only consulted by the
            mini-batched (atom-wise) generator; the plain generator decides
            from ``self.remainder``.
        loop_array : np.ndarray
            If this is not None, elements of this array will be looped over in
            the batches, which load data at their indices. For example, with
            loop_array = [[1, 4, 7], [10, 13, 16], [19, 21, 24]]
            configurations 1, 4, and 7 will be loaded in the first batch for
            the analysis. This is particularly important in the structural
            properties. Not used by the mini-batched generator.

        Returns
        -------
        tuple
            ``(generator_function, args)`` where ``args`` is
            ``(n_batches, batch_size, database path, data_path, dictionary)``.
            The atom-wise generator is returned when ``self.minibatch`` is set.
        """
        args = (
            self.n_batches,
            self.batch_size,
            self.database.path,
            self.data_path,
            dictionary,
        )

        def generator(
            batch_number: int,
            batch_size: int,
            database_path: str,
            data_path: list,
            dictionary: bool,
        ):
            """
            Generator function for the batch loop.

            Parameters
            ----------
            batch_number : int
                Number of full batches to be looped over.
            batch_size : int
                Size of each batch to load.
            database_path : str
                Path of the database from which to load the tensor_values.
            data_path : list
                Path to the tensor_values in the database.
            dictionary : bool
                If true, tensor_values is returned in a dictionary.

            Yields
            ------
            Whatever ``Database.load_data`` returns for each batch.
            """
            database = Database(database_path)
            # ``or 0`` guards against the default ``remainder=None``, which
            # previously raised a TypeError on comparison.
            loop_over_remainder = (self.remainder or 0) > 0
            for batch in range(batch_number + int(loop_over_remainder)):
                start, stop, data_size = self._batch_window(
                    batch, batch_number, batch_size
                )
                # TODO make default
                if loop_array is not None:
                    select_slice = self._atom_slice(loop_array[batch])
                elif system:
                    select_slice = np.s_[start:stop]
                else:
                    select_slice = self._atom_slice(np.s_[start:stop])
                yield database.load_data(
                    data_path,
                    select_slice=select_slice,
                    dictionary=dictionary,
                    d_size=data_size,
                )

        def atom_generator(
            batch_number: int,
            batch_size: int,
            database_path: str,
            data_path: list,
            dictionary: bool,
        ):
            """
            Generator function for a mini-batched calculation.

            The outer loop runs over atom batches and the inner loop over time
            batches, so every time window is loaded once per atom batch.

            Parameters
            ----------
            batch_number : int
                Number of full batches to be looped over.
            batch_size : int
                Size of each batch to load.
            database_path : str
                Path of the database from which to load the tensor_values.
            data_path : list
                Path to the tensor_values in the database.
            dictionary : bool
                If true, tensor_values is returned in a dictionary.

            Raises
            ------
            ValueError
                If ``atom_selection`` is a dict (on first iteration).
            """
            # Atom selection not currently available for mini-batched calculations
            if isinstance(self.atom_selection, dict):
                raise ValueError(
                    "Atom selection is not currently available "
                    "for mini-batched calculations"
                )
            database = Database(database_path)
            n_atom_loops = self.n_atom_batches + int(bool(self.atom_remainder))
            for atom_batch in tqdm(
                range(n_atom_loops),
                total=n_atom_loops,
                ncols=70,
                desc="batch loop",
            ):
                atom_start = atom_batch * self.atom_batch_size
                if atom_batch == self.n_atom_batches:
                    # Trailing partial atom batch. This must be offset from
                    # atom_start. The previous code added the remainder to the
                    # time-batch start, which the inner loop overwrites.
                    atom_stop = atom_start + self.atom_remainder
                else:
                    atom_stop = atom_start + self.atom_batch_size

                # NOTE(review): uses the ``remainder`` argument of
                # batch_generator, unlike ``generator`` which uses
                # self.remainder; confirm callers pass a consistent value.
                for batch in range(batch_number + int(remainder)):
                    start, stop, data_size = self._batch_window(
                        batch, batch_number, batch_size
                    )
                    select_slice = np.s_[int(atom_start) : int(atom_stop), start:stop]
                    yield database.load_data(
                        data_path,
                        select_slice=select_slice,
                        dictionary=dictionary,
                        d_size=data_size,
                    )

        if self.minibatch:
            return atom_generator, args
        return generator, args

    def ensemble_generator(self, system: bool = False, glob_data: dict = None) -> tuple:
        """
        Build a generator for the ensemble loop.

        Parameters
        ----------
        system : bool
            Currently unused; kept for interface compatibility.
        glob_data : dict
            data to be loaded in ensembles from a tensorflow generator.
            e.g. {b'Na/Positions': tf.Tensor}. Must include a b'data_size' key,
            which sets the number of ensembles and is excluded from the output.
            All keys are in byte arrays. This appears when you pass a dict to
            the tensorflow generator.

        Returns
        -------
        tuple
            ``(dictionary_generator, args)`` where ``args`` is
            ``(ensemble_loop, correlation_time, data_range)``.
        """
        args = (self.ensemble_loop, self.correlation_time, self.data_range)

        def dictionary_generator(ensemble_loop, correlation_time, data_range):
            """
            Generator for the ensemble loop.

            Parameters
            ----------
            ensemble_loop : int
                Number of ensembles to loop over. Ignored: it is recomputed
                from ``glob_data[b"data_size"]``.
            correlation_time : int
                Distance between ensembles.
            data_range : int
                Size of each ensemble.

            Yields
            ------
            dict
                For each ensemble, every non-``data_size`` entry of glob_data
                sliced to ``[:, start:start + data_range]``.
            """
            # At least one ensemble is always produced.
            # NOTE(review): this gives floor((data_size - data_range) /
            # correlation_time) ensembles, while floor(...) + 1 windows of
            # length data_range fit. Confirm against the calculators'
            # normalisation before changing.
            ensemble_loop = int(
                np.clip(
                    (glob_data[b"data_size"] - data_range) / correlation_time, 1, None
                )
            )
            for ensemble in range(ensemble_loop):
                start = ensemble * correlation_time
                stop = start + data_range
                yield {
                    item: glob_data[item][:, start:stop]
                    for item in glob_data
                    if item != b"data_size"
                }

        return dictionary_generator, args