"""
Module to perform the calculation of the angular distribution function (ADF).

The ADF describes the average distribution of angles between three particles of
species a, b, and c. Note that a, b, and c may all be the same species, e.g.
Na-Na-Na.
"""
import itertools
import logging
from abc import ABC
from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from mdsuite.calculators.calculator import call
from mdsuite.calculators.trajectory_calculator import TrajectoryCalculator
from mdsuite.database.mdsuite_properties import mdsuite_properties
from mdsuite.utils.linalg import get_angles
from mdsuite.utils.meta_functions import join_path
from mdsuite.utils.neighbour_list import (
    get_neighbour_list,
    get_triplets,
    get_triu_indicies,
)

log = logging.getLogger(__name__)

@dataclass
class Args:
    """
    Data class for the saved properties.

    The instance is deliberately mutable: the calculator fills in ``stop`` and
    ``species`` after construction when they were passed as ``None``.
    """

    # Number of bins of the angle histogram.
    number_of_bins: int
    # Number of configurations sampled between ``start`` and ``stop``.
    number_of_configurations: int
    correlation_time: int
    # Either a slice (``np.s_[:]`` by default) or a dict keyed by species name.
    atom_selection: Union[slice, dict]
    data_range: int
    # Cutoff radius handed to the triplet search.
    cutoff: float
    # Index of the first configuration.
    start: int
    # Power of the normalisation factor applied to the histogram weights.
    norm_power: Union[int, float]
    # Index of the last configuration; ``None`` until resolved by the calculator.
    stop: Union[int, None]
    # Species (or molecule) names; ``None`` until resolved by the calculator.
    species: Union[list, None]
    # If true, the analysis is performed on molecules.
    molecules: bool
class AngularDistributionFunction(TrajectoryCalculator, ABC):
    """
    Compute the Angular Distribution Function for all species combinations.

    Attributes
    ----------
    loaded_property
        Property loaded from the database (``mdsuite_properties.positions``).
    use_tf_function : bool
        Whether the triplet search is wrapped in ``tf.function``. This can
        speed up the calculation significantly, but may lead to excessive use
        of memory. Tracing is slow, so this might only be useful for a larger
        number of batches.
    adf_minibatch : int
        Number of minibatches used for the triplet search. A value of -1 is
        replaced by the number of atoms.
    bin_range : list
        Lower and upper edge of the angle histogram in radians.
    number_of_atoms : int
        Total number of atoms in the selected species.
    norm_power : int or float
        Power of the normalisation factor applied to the histogram weights.
    sample_configurations : np.ndarray
        Indices of the configurations that are analysed.
    result_keys : list
        Keys of the scalar results (``["max_peak"]``).
    result_series_keys : list
        Keys of the series results (``["angle", "adf"]``).

    See Also
    --------
    mdsuite.calculators.calculator.Calculator class

    Examples
    --------
    experiment.run_computation.AngularDistributionFunction(
        number_of_configurations=100, cutoff=3.2, batch_size=10,
        minibatch=50, start=0, stop=200, number_of_bins=100,
        use_tf_function=False)
    """

    def __init__(self, **kwargs):
        """
        Compute the Angular Distribution Function for all species combinations.

        Parameters
        ----------
        kwargs
            Forwarded unchanged to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.scale_function = {"quadratic": {"outer_scale_factor": 10}}
        self.loaded_property = mdsuite_properties.positions

        self.use_tf_function = None
        self.molecules = None
        self.bin_range = None
        self.number_of_atoms = None
        self.norm_power = None
        self.sample_configurations = None
        self.result_keys = ["max_peak"]
        self.result_series_keys = ["angle", "adf"]
        self._dtype = tf.float32

        self.adf_minibatch = None  # memory management for triples generation per batch.
        # Triplet search callable; built once per run so that an optional
        # tf.function is not re-created (and re-traced) for every batch.
        self._triplet_generator = None

        self.analysis_name = "Angular_Distribution_Function"
        self.x_label = r"$$\text{Angle} / \theta$$"
        self.y_label = r"$$\text{ADF} / a.u.$$"

    @call
    def __call__(
        self,
        batch_size: int = 1,
        minibatch: int = -1,
        number_of_configurations: int = 5,
        cutoff: float = 6.0,
        start: int = 1,
        stop: Union[int, None] = None,
        number_of_bins: int = 500,
        species: Union[list, None] = None,
        use_tf_function: bool = False,
        molecules: bool = False,
        atom_selection=np.s_[:],
        plot: bool = True,
        norm_power: Union[int, float] = 4,
        **kwargs,
    ):
        """
        Store the user arguments for the computation.

        Parameters
        ----------
        batch_size : int
            Number of batches, to split the configurations into.
        minibatch : int
            Number of minibatches for computing the angles to split each batch
            into. The default of -1 is replaced by the number of atoms.
        number_of_configurations : int
            Number of configurations to analyse
        cutoff : float
            cutoff radius for the ADF
        start : int
            Index of the first configuration
        stop : int
            Index of the last configuration. If None, the last configuration
            of the experiment is used.
        number_of_bins : int
            bins for the ADF
        species : list
            A list of species to use. If None, all species (or all molecules,
            if ``molecules`` is true) of the experiment are used.
        use_tf_function : bool, default False
            activate the tf.function decorator for the minibatches. Can speed
            up the calculation significantly, but may lead to excessive use of
            memory! During the first batch, this function will be traced.
            Tracing is slow, so this might only be useful for a larger number
            of batches.
        molecules : bool
            if true, perform the analysis on molecules.
        atom_selection : Union[np.s_, dict]
            Atoms to be used in the analysis.
        plot : bool
            If true, plot the result of the analysis.
        norm_power : int
            The power of the normalization factor applied to the ADF histogram.
            If set to zero no distance normalization will be applied.
        kwargs
            ``batches`` may be given to override the computed number of
            batches.
        """
        # set args that will affect the computation result
        self.args = Args(
            number_of_bins=number_of_bins,
            cutoff=cutoff,
            start=start,
            stop=stop,
            atom_selection=atom_selection,
            data_range=1,
            correlation_time=1,
            molecules=molecules,
            species=species,
            number_of_configurations=number_of_configurations,
            norm_power=norm_power,
        )

        # Parse the user arguments.
        self.use_tf_function = use_tf_function
        self.cutoff = cutoff
        self.plot = plot
        self._batch_size = batch_size  # memory management for all batches
        self.adf_minibatch = minibatch
        self.bin_range = [0.0, 3.15]  # from 0 to a chemists pi
        self.norm_power = norm_power
        self.override_n_batches = kwargs.get("batches")

    def check_input(self):
        """
        Check the inputs and set defaults if necessary.

        Returns
        -------
        Updates the class attributes.
        """
        self._run_dependency_check()
        if self.args.stop is None:
            self.args.stop = self.experiment.number_of_configurations - 1

        # Look the species up in the container that matches the analysis mode.
        # Previously a user-supplied species list was always resolved against
        # ``experiment.species``, even when ``molecules`` was requested.
        if self.args.molecules:
            reference = self.experiment.molecules
        else:
            reference = self.experiment.species

        if self.args.species is None:
            self.args.species = list(reference)
        self._compute_number_of_atoms(reference)

    def _compute_number_of_atoms(self, reference: dict):
        """
        Compute the number of atoms total in the selected set of species.

        Parameters
        ----------
        reference : dict
            Reference dictionary in which to look for indices lists

        Returns
        -------
        Updates the number of atoms attribute and, if the minibatch count was
        left at -1, sets it to the number of atoms.
        """
        # TODO return to dotdict form when molecules is a dotdict
        selection = self.args.atom_selection
        if not self.args.species:
            number_of_atoms = 0
        elif isinstance(selection, dict):
            # With an explicit selection every entry of the selection counts.
            number_of_atoms = sum(len(indices) for indices in selection.values())
        else:
            number_of_atoms = sum(
                reference[name].n_particles for name in self.args.species
            )

        self.number_of_atoms = number_of_atoms
        if self.adf_minibatch == -1:
            self.adf_minibatch = number_of_atoms

    def _prepare_data_structure(self):
        """
        Prepare variables and dicts for the analysis.

        Returns
        -------
        sample_configs : np.ndarray
            Integer configuration indices, evenly spaced between start and stop.
        species_indices : list
            One ``(name, start_index, stop_index)`` tuple per species, giving
            the half-open index range of that species in the concatenated
            atom axis.
        """
        sample_configs = np.linspace(
            self.args.start,
            self.args.stop,
            self.args.number_of_configurations,
            dtype=int,
        )

        species_indices = []
        start_index = 0
        stop_index = 0
        for species in self.args.species:
            try:
                if isinstance(self.args.atom_selection, dict):
                    stop_index += len(self.args.atom_selection[species])
                else:
                    stop_index += self.experiment.species[species].n_particles
            except KeyError:
                # Name is not a known key: fall back to the molecule table.
                stop_index += self.experiment.molecules[species].n_particles
            species_indices.append((species, start_index, stop_index))
            start_index = stop_index

        return sample_configs, species_indices

    def _prepare_triples_generator(self):
        """
        Prepare the triples generator, optionally wrapped in tf.function.

        Returns
        -------
        callable
            Function taking the r_ij matrix and returning the triplet indices.
        """

        def _get_triplets(x):
            return get_triplets(
                x,
                r_cut=self.cutoff,
                n_atoms=self.number_of_atoms,
                n_batches=self.adf_minibatch,
                disable_tqdm=True,
            )

        if self.use_tf_function:
            return tf.function(_get_triplets)
        return _get_triplets

    def _compute_rijk_matrices(self, tmp: tf.Tensor, timesteps: int):
        """
        Compute the rij matrix and the triplet indices.

        Parameters
        ----------
        tmp : tf.Tensor
            Positions handed to the neighbour list.
        timesteps : int
            Number of configurations contained in ``tmp``.

        Returns
        -------
        r_ij_mat : tf.Tensor
            Pairwise distance vectors of shape
            (n_timesteps, n_atoms, n_atoms, 3).
        r_ijk_indices : tf.Tensor
            Triplet indices returned by the triplet search.
        """
        # Reuse the generator of the current run; build it lazily when this
        # method is used without going through ``run_calculator``.
        if self._triplet_generator is None:
            self._triplet_generator = self._prepare_triples_generator()
        _get_triplets = self._triplet_generator

        r_ij_flat = next(
            get_neighbour_list(tmp, cell=self.experiment.box_array, batch_size=1)
        )
        r_ij_indices = get_triu_indicies(self.number_of_atoms)

        # Shape is now (n_atoms, n_atoms, 3, n_timesteps)
        r_ij_mat = tf.scatter_nd(
            indices=tf.transpose(r_ij_indices),
            updates=tf.transpose(r_ij_flat, (1, 2, 0)),
            shape=(self.number_of_atoms, self.number_of_atoms, 3, timesteps),
        )
        # Fill the second triangle with the negated vectors (r_ji = -r_ij).
        r_ij_mat -= tf.transpose(r_ij_mat, (1, 0, 2, 3))
        r_ij_mat = tf.transpose(r_ij_mat, (3, 0, 1, 2))

        r_ijk_indices = _get_triplets(r_ij_mat)

        return r_ij_mat, r_ijk_indices

    @staticmethod
    def _compute_angles(species, r_ijk_indices):
        """
        Build the mask selecting the triplets of one species combination.

        Parameters
        ----------
        species : tuple
            Three ``(name, start_index, stop_index)`` tuples.
        r_ijk_indices : tf.Tensor
            Triplet indices; columns 1, 2 and 3 are compared against the index
            ranges of the three species (column 0 is presumably the
            configuration index -- TODO confirm).

        Returns
        -------
        condition : tf.Tensor
            Boolean mask over the rows of ``r_ijk_indices``.
        name : str
            Name of the combination, e.g. ``"Na-Cl-Na"``.
        """
        (i_name, i_min, i_max), (j_name, j_min, j_max), (k_name, k_min, k_max) = species
        name = f"{i_name}-{j_name}-{k_name}"

        i_condition = tf.logical_and(
            r_ijk_indices[:, 1] >= i_min, r_ijk_indices[:, 1] < i_max
        )
        j_condition = tf.logical_and(
            r_ijk_indices[:, 2] >= j_min, r_ijk_indices[:, 2] < j_max
        )
        k_condition = tf.logical_and(
            r_ijk_indices[:, 3] >= k_min, r_ijk_indices[:, 3] < k_max
        )

        condition = tf.math.logical_and(
            x=tf.math.logical_and(x=i_condition, y=j_condition), y=k_condition
        )

        return condition, name

    def _build_histograms(self, positions, species_indices, angles):
        """
        Build the adf histograms of one batch and add them to ``angles``.

        Parameters
        ----------
        positions : tf.Tensor
            Positions of the batch as returned by ``_format_data``.
        species_indices : list
            ``(name, start_index, stop_index)`` tuples of the species.
        angles : dict
            Histograms accumulated over the previous batches.

        Returns
        -------
        angles : dict
            A dictionary of the triples references and their histogram values.
        """
        configurations = tf.transpose(tf.concat(positions, axis=0), (1, 0, 2))
        timesteps = tf.shape(configurations)[0]

        r_ij_mat, r_ijk_indices = self._compute_rijk_matrices(
            configurations, timesteps
        )

        for species in itertools.combinations_with_replacement(species_indices, 3):
            # Select the part of the r_ijk indices, where the selected species
            # triple occurs.
            condition, name = self._compute_angles(species, r_ijk_indices)
            triplets = tf.gather_nd(r_ijk_indices, tf.where(condition))

            if triplets.shape[0] == 0:
                # A density-normalised histogram of no data is 0 / 0 = NaN,
                # which would poison the sum over all batches. An empty
                # selection therefore contributes nothing.
                histogram = tf.zeros(self.args.number_of_bins, dtype=tf.float32)
            else:
                angle_vals, pre_factor = get_angles(r_ij_mat, triplets)
                pre_factor = 1 / pre_factor**self.norm_power
                histogram, _ = np.histogram(
                    angle_vals,
                    bins=self.args.number_of_bins,
                    range=self.bin_range,
                    weights=pre_factor,
                    density=True,
                )
                histogram = tf.cast(histogram, dtype=tf.float32)

            previous = angles.get(name)
            angles[name] = histogram if previous is None else previous + histogram

        return angles

    def _compute_adfs(self, angles, species_indices):
        """
        Compute the ADF and store it.

        Parameters
        ----------
        angles : dict
            Dict of angle combinations and their histograms.
        species_indices : list
            A list of indices associated with species under consideration.

        Returns
        -------
        Queues the angle axis, the histogram and the position of its maximum
        for every species combination.
        """
        # The angle axis (in degrees) is the same for every combination.
        bin_range_to_angles = np.linspace(
            np.degrees(self.bin_range[0]),
            np.degrees(self.bin_range[1]),
            self.args.number_of_bins,
        )

        for species in itertools.combinations_with_replacement(species_indices, 3):
            name = f"{species[0][0]}-{species[1][0]}-{species[2][0]}"
            hist = angles.get(name).numpy()

            self.selected_species = [_species[0] for _species in species]
            self.data_range = self.args.number_of_configurations
            log.debug(f"species are {species}")

            max_angle = bin_range_to_angles[np.argmax(hist)]

            data = {
                self.result_keys[0]: max_angle,
                self.result_series_keys[0]: bin_range_to_angles.tolist(),
                self.result_series_keys[1]: hist.tolist(),
            }

            self.queue_data(data=data, subjects=self.selected_species)

    def plot_data(self, data):
        """
        Plot data.

        Parameters
        ----------
        data : dict
            Maps the species combination to a dict holding the ``angle`` and
            ``adf`` series.
        """
        for selected_species, val in data.items():
            # Use the stored angle axis so that plotting does not depend on
            # ``self.bin_range``, which is only set once ``__call__`` has run.
            angle_axis = np.array(val[self.result_series_keys[0]])
            adf_values = np.array(val[self.result_series_keys[1]])
            title_value = angle_axis[np.argmax(adf_values)]

            self.run_visualization(
                x_data=angle_axis,
                y_data=adf_values,
                title=f"{selected_species} - Max: {title_value:.3f} degrees ",
            )

    def _format_data(self, batch: tf.Tensor, keys: list) -> tf.Tensor:
        """
        Format the loaded data for use in the ADF calculator.

        The entries of ``batch`` named in ``keys`` are collected, joined along
        the first axis when more than one species is analysed, and cast to the
        dtype of the calculator.

        Parameters
        ----------
        batch : tf.Tensor
            A batch of data to transform.
        keys : list
            Dict keys to extract from the data.

        Returns
        -------
        tf.Tensor
            The collected data cast to ``self.dtype``.
        """
        formatted_data = [batch[item] for item in keys]

        if len(self.args.species) == 1:
            return tf.cast(formatted_data[0], dtype=self.dtype)
        return tf.cast(tf.concat(formatted_data, axis=0), dtype=self.dtype)

    def _correct_batch_properties(self):
        """
        We must fix the batch size parameters set by the parent class.

        Returns
        -------
        Updates the parent class.
        """
        n_configurations = self.args.number_of_configurations

        if self.batch_size > n_configurations:
            self.batch_size = n_configurations
            self.n_batches = 1
        else:
            self.n_batches = n_configurations // self.batch_size

        if self.override_n_batches is not None:
            self.n_batches = self.override_n_batches

        if self.minibatch:
            # One configuration per batch and no batching over atoms.
            self.batch_size = 1
            self.n_batches = n_configurations
            self.memory_manager.atom_batch_size = None
            self.memory_manager.n_atom_batches = None
            self.memory_manager.atom_remainder = None

        self.remainder = 0
        self.minibatch = False

    def prepare_computation(self):
        """
        Run steps necessary to prepare the computation for running.

        Returns
        -------
        dict_keys : list
            Byte-encoded database paths, one per species.
        split_arr : list
            The sampled configurations split into ``n_batches`` arrays.
        """
        # NOTE(review): the second argument of ``join_path`` is missing in the
        # text this block was restored from. The name of the loaded property is
        # assumed here -- confirm against the upstream file.
        property_name = self.loaded_property.name

        path_list = [
            join_path(species_name, property_name)
            for species_name in self.args.species
        ]
        self._prepare_managers(path_list)

        # batch loop correction
        self._correct_batch_properties()

        # Get the correct dict keys.
        dict_keys = [
            str.encode(join_path(species_name, property_name))
            for species_name in self.args.species
        ]

        # Split the configurations into batches.
        split_arr = np.array_split(self.sample_configurations, self.n_batches)

        return dict_keys, split_arr

    def run_calculator(self):
        """
        Run the analysis.

        Returns
        -------
        Queues the results of every species combination.
        """
        self.check_input()
        self.sample_configurations, species_indices = self._prepare_data_structure()
        dict_keys, split_arr = self.prepare_computation()

        # Build the triplet search once for this run: with ``use_tf_function``
        # the traced function is then shared by all batches.
        self._triplet_generator = self._prepare_triples_generator()

        # Get the batch dataset
        batch_ds = self.get_batch_dataset(
            subject_list=self.args.species, loop_array=split_arr, correct=True
        )

        angles = {}

        # Loop over the batches.
        for batch in tqdm(batch_ds, ncols=70, total=self.n_batches):
            positions_tensor = self._format_data(batch=batch, keys=dict_keys)

            angles = self._build_histograms(
                positions=positions_tensor,
                species_indices=species_indices,
                angles=angles,
            )

        self._compute_adfs(angles, species_indices)