Source code for mdsuite.visualizer.d3_data_visualizer

"""
MDSuite: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the MDSuite 3d visualizer.
"""
import importlib.resources
import json
from typing import Union

import numpy as np
import open3d as o3d
import open3d.visualization.gui as gui
from matplotlib import cm
from PIL.ImageColor import getcolor


[docs]class DataVisualizer3D: """Class for the visualizer of three dimensional data.""" def __init__( self, data: np.ndarray, title: str, center: Union[str, dict] = None, colour_map: np.ndarray = None, ): """ Constructor for the data visualizer. Parameters ---------- data : np.ndarray data to plot. center : str centre molecule to be displayed (optional) title : str title of the plot. colour_map : np.ndarray A colour map to apply to the data. """ self.data = data self.title = title self.center = center self.colour_map = colour_map self.point_cloud = o3d.geometry.PointCloud() self.point_cloud.points = o3d.utility.Vector3dVector(self.data) self.point_cloud.colors = o3d.utility.Vector3dVector(self._build_colour_map()) self._build_app() def _build_colour_map(self): """ Build the colour map Returns -------. """ viridis = cm.get_cmap("binary") dat = viridis(self.colour_map)[:, :-1] return dat def _build_app(self): """ Build the app window and update the class. Returns ------- Updates the class. """ self.app = gui.Application.instance self.app.initialize() self.vis = o3d.visualization.O3DVisualizer("MDSuite Visualizer", 1024, 768) self.vis.show_settings = True self.vis.reset_camera_to_default() self.app.add_window(self.vis) @staticmethod def _get_atom_properties(element: str) -> dict: """ Get atom size and colour based on species. Parameters ---------- element : str Name of the element you want to render. Returns ------- data : dict A dictionary of data to use for the rendering: e.g. {'colour': (0.7, 0.33, 0.0), 'mass': 0.8) """ data = {} data_name = "mdsuite.data" json_name = "PubChemElements_all.json" with importlib.resources.open_text(data_name, json_name) as json_file: pse = json.loads(json_file.read()) try: name_split = element.split("_") element = name_split[0] except ValueError: element = element for entry in pse: if pse[entry][1] == element: data["colour"] = getcolor(f"#{pse[entry][4]}", "RGB") data["mass"] = float(pse[entry][3]) return data def _add_center(self): """ Add a rendering to the (0, 0, 0) point in the plot. Returns ------- Updates the plot. """ if type(self.center) is str: self._add_single_center() else: self._add_group_center() def _add_single_center(self): """ Add a single particle to the center. Returns ------- """ data = self._get_atom_properties(self.center) mesh = o3d.geometry.TriangleMesh.create_sphere( radius=data["mass"] / 25, resolution=15 ) mesh.compute_vertex_normals() mesh.paint_uniform_color(data["colour"]) self.vis.add_geometry("Center", mesh) def _add_group_center(self): """ Add a group of particle to the center. Returns ------- Adds a group of particles to the center. """ translation = np.array([0, 0, 0]) mass = 0.0 for item in self.center: data = self._get_atom_properties(item) mass += data["mass"] translation += (self.center[item] / 10) * data["mass"] translation = -1 * translation / mass global_mesh = None for i, item in enumerate(self.center): data = self._get_atom_properties(item) if i == 0: global_mesh = o3d.geometry.TriangleMesh.create_sphere( radius=data["mass"] / 100, resolution=15 ) global_mesh.compute_vertex_normals() global_mesh.translate(self.center[item] / 10) colour = np.array(data["colour"]) / 255 global_mesh.paint_uniform_color(colour) # self.vis.add_geometry(item, global_mesh) else: mesh = o3d.geometry.TriangleMesh.create_sphere( radius=data["mass"] / 100, resolution=15 ) mesh.compute_vertex_normals() mesh.translate(self.center[item] / 10) colour = np.array(data["colour"]) / 255 mesh.paint_uniform_color(colour) global_mesh += mesh # self.vis.add_geometry(item, mesh) global_mesh.translate(translation) self.vis.add_geometry("Center", global_mesh)
[docs] def plot(self): """ Plot the data. Returns ------- """ self.vis.add_geometry("Points", self.point_cloud) # if self.center is not None: self._add_center() self.vis.reset_camera_to_default() self.app.run()