Source code for mdsuite.database.calculator_database

"""
MDSuite: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
from __future__ import annotations

import logging
from collections import Counter
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, List

from sqlalchemy import and_

import mdsuite.database.scheme as db
from mdsuite.utils.meta_functions import is_jsonable

if TYPE_CHECKING:
    from mdsuite.experiment import Experiment

log = logging.getLogger(__name__)


[docs]@dataclass class ComputationResults: """Computation Results dataclass.""" data: dict = field(default_factory=dict) subjects: dict = field(default_factory=list)
[docs]@dataclass class Args: """Dummy Class for type hinting.""" pass
[docs]def conv_to_db(val): """Convert the given value to something that can be stored in the database.""" if is_jsonable(val): if not isinstance(val, dict): val = {"serialized_value": val} else: val = {"serialized_value": str(val)} return val
[docs]class CalculatorDatabase: """Database Interactions of the calculator class. This class handles the interaction of the calculator with the project database """ def __init__(self, experiment): """Constructor for the calculator database.""" self.experiment: Experiment = experiment self.db_computation: db.Computation = None self.database_group = None self.analysis_name = None self.load_data = None self.args = Args() self._queued_data = [] # List of computation attributes that will be added to the database self.db_computation_attributes = []
[docs] def prepare_db_entry(self): """Prepare a database entry based on the attributes defined in the init.""" with self.experiment.project.session as ses: experiment = ( ses.query(db.Experiment) .filter(db.Experiment.name == self.experiment.name) .first() ) self.db_computation = db.Computation(experiment=experiment) self.db_computation.name = self.analysis_name
[docs] def get_computation_data(self) -> db.Computation: """Query the database for computation data. This method used the self.args dataclass to look for matching calculator attributes and returns a db.Computation object if the calculation has already been performed Return: ------ db.Computation Returns the computation object from the database if available, otherwise returns None """ log.debug(f"Getting data for {self.experiment.name} with args {self.args}") with self.experiment.project.session as ses: experiment = ( ses.query(db.Experiment) .filter(db.Experiment.name == self.experiment.name) .first() ) # filter the correct experiment computations = ses.query(db.Computation).filter( db.Computation.experiment == experiment, db.Computation.name == self.analysis_name, ) # filter set args for args_field in fields(self.args): key = args_field.name val = getattr(self.args, key) computations = computations.filter( db.Computation.computation_attributes.any( and_( db.ComputationAttribute.name == key, db.ComputationAttribute.data == conv_to_db(val), ) ) ) # filter the version of the experiment, e.g. run new computation # if the experiment version has changed computations = computations.filter( db.Computation.computation_attributes.any( and_( db.ComputationAttribute.name == "version", db.ComputationAttribute.data == conv_to_db(self.experiment.version), ) ) ) computations = computations.all() if len(computations) > 0: log.debug("Calculation already performed! Loading it up") # loading data_dict to avoid DetachedInstance errors # this can take some time, depending on the size of the data # TODO remove and use lazy call for computation in computations: _ = computation.data_dict _ = computation.data_range if len(computations) > 0: if len(computations) > 1: log.warning( "Something went wrong! Found more than one computation with the" " given arguments!" ) return computations[0] # it should only be one value return None
[docs] def save_computation_args(self): """Store the user args. This method stored the user args from the self.args dataclass into SQLAlchemy objects and adds them to a list which will be written to the database after the calculation was successful. """ for args_field in fields(self.args): key = args_field.name val = getattr(self.args, key) computation_attribute = db.ComputationAttribute( name=key, data=conv_to_db(val) ) self.db_computation_attributes.append(computation_attribute) # save the current experiment version in the ComputationAttributes experiment_version = db.ComputationAttribute( name="version", data=conv_to_db(self.experiment.version) ) self.db_computation_attributes.append(experiment_version)
[docs] def save_db_data(self): """Save all the collected computationattributes and computation data to the database. This will be run after the computation was successful. """ with self.experiment.project.session as ses: ses.add(self.db_computation) for val in self.db_computation_attributes: # I need to set the relation inside the session. val.computation = self.db_computation ses.add(val) for data_obj in self._queued_data: # TODO consider renaming species to e.g., subjects, because species here # can also be molecules data_obj: ComputationResults computation_result = db.ComputationResult( computation=self.db_computation, data=data_obj.data ) species_list = [] for species in data_obj.subjects: # this will collect duplicates that can be counted later, # otherwise I would use .in_ species_list.append( ses.query(db.ExperimentSpecies) .filter(db.ExperimentSpecies.name == species) .first() ) # in case of e.g. `System` species will be [None], which is then removed species_list = [x for x in species_list if x is not None] for species, count in Counter(species_list).items(): associate = db.SpeciesAssociation(count=count) associate.species = species computation_result.species.append(associate) ses.add(computation_result) ses.commit()
[docs] def queue_data(self, data, subjects): """Queue data to be stored in the database. Parameters ---------- data: dict A dictionary containing all the data that was computed by the computation subjects: list A list of strings / subject names that are associated with the data, e.g. the pairs of the RDF """ self._queued_data.append(ComputationResults(data=data, subjects=subjects))
[docs] def update_database(self, parameters, delete_duplicate: bool = True): """ Add data to the database. Parameters ---------- parameters : dict Parameters to be used in the addition, i.e. {"Analysis": "Green_Kubo_Self_Diffusion", "Subject": "Na", "data_range": 500, "data": 1.8e-9} delete_duplicate : bool If true, duplicate entries will be deleted. Returns ------- Updates the sql database """ raise DeprecationWarning("This function has been replaced by `queue_data`")
# REMOVE # TODO rename and potentially move to a RDF based parent class def _get_rdf_data(self) -> List[db.Computation]: """Fill the data_files list with filenames of the rdf tensor_values.""" # TODO replace with exp.load.RDF() raise DeprecationWarning( "Replaced by experiment.run.RadialDistributionFunction(**kwargs)" ) # with self.experiment.project.session as ses: # computations = ( # ses.query(db.Computation) # .filter( # db.Computation.computation_attributes.any( # str_value="Radial_Distribution_Function", name="Property" # ) # ) # .all() # ) # # for computation in computations: # _ = computation.data_dict # _ = computation.data_range # # return computations # TODO rename and potentially move to a RDF based parent class def _load_rdf_from_file(self, computation: db.Computation): """Load the raw rdf tensor_values from a directory.""" raise DeprecationWarning( "Replaced by experiment.run.RadialDistributionFuncion(**kwargs)" )
# self.radii = np.array(computation.data_dict["x"]).astype(float)[1:] # self.rdf = np.array(computation.data_dict["y"]).astype(float)[1:] #####################