# Source code for odatse.domain.meshgrid

# SPDX-License-Identifier: MPL-2.0
# ODAT-SE -- an open framework for data analysis
# Copyright (C) 2020- The University of Tokyo
# This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
# If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.

from typing import List, Dict, Union, Any

from pathlib import Path
import numpy as np

import odatse
from ._domain import DomainBase

class MeshGrid(DomainBase):
    """
    MeshGrid class for handling grid data for quantum beam diffraction experiments.

    Each grid point is stored as a list ``[index, x1, x2, ...]`` where ``index``
    is the 0-based position assigned (via ``enumerate``) when the grid was built.

    Attributes
    ----------
    grid : list of list
        All grid points, each ``[index, x1, x2, ...]``.
    grid_local : list of list
        The points assigned to this MPI rank by ``do_split``.
    ncandicates : int
        Number of points in ``grid``; set by ``_setup``.
    """

    grid: List[List[Union[int, float]]]
    grid_local: List[List[Union[int, float]]]
    ncandicates: int  # name (and spelling) of the attribute _setup assigns

    def __init__(self, info: odatse.Info = None, *, param: Dict[str, Any] = None):
        """
        Initialize the MeshGrid object.

        Parameters
        ----------
        info : Info, optional
            Information object; the grid is built from ``info.algorithm["param"]``.
        param : dict, optional
            Dictionary containing parameters for setting up the grid. Used only
            when ``info`` is not given.

        Raises
        ------
        ValueError
            If ``info`` is given but ``info.algorithm`` has no ``"param"`` entry.

        Notes
        -----
        If neither ``info`` nor a non-empty ``param`` is given, the grid is left empty.
        """
        super().__init__(info)
        # Fresh per-instance containers: list defaults on the class would be
        # shared (and mutated) by every instance that never runs _setup.
        self.grid = []
        self.grid_local = []
        if info:
            if "param" in info.algorithm:
                self._setup(info.algorithm["param"])
            else:
                raise ValueError("ERROR: algorithm.param not defined")
        elif param:
            self._setup(param)

    def do_split(self):
        """
        Split the grid data among MPI processes.

        With more than one process, the point indices are divided into
        ``mpisize`` contiguous chunks with ``np.array_split``. This rank keeps
        (copies of) the points whose index falls in chunk ``mpirank``.
        With a single process, ``grid_local`` is the ``grid`` list itself.
        """
        if self.mpisize > 1:
            index = [idx for idx, *v in self.grid]
            # Set membership is O(1); `idx in ndarray` is a linear scan, which
            # made the filter below quadratic in the number of points.
            index_local = set(np.array_split(index, self.mpisize)[self.mpirank].tolist())
            self.grid_local = [[idx, *v] for idx, *v in self.grid if idx in index_local]
        else:
            self.grid_local = self.grid

    def _setup(self, info_param):
        """
        Setup the grid based on provided parameters.

        Reads from a file when ``"mesh_path"`` is present, otherwise builds a
        regular grid from ``min_list``/``max_list``/``num_list``. Afterwards
        ``ncandicates`` holds the number of grid points.

        Parameters
        ----------
        info_param
            Dictionary containing parameters for setting up the grid.
        """
        if "mesh_path" in info_param:
            self._setup_from_file(info_param)
        else:
            self._setup_grid(info_param)
        self.ncandicates = len(self.grid)

    def _setup_from_file(self, info_param):
        """
        Setup the grid from a file.

        The file is read with ``np.loadtxt`` on rank 0. The first column is
        treated as an index and discarded. The remaining columns are the
        coordinates, and points are re-indexed from 0 in file order. When
        running with more than one process, the data is broadcast from rank 0.

        Parameters
        ----------
        info_param
            Dictionary with keys ``mesh_path`` (required, resolved relative to
            ``self.root_dir``), ``comments`` (default ``"#"``), ``delimiter``
            (default ``None``) and ``skiprows`` (default ``0``).

        Raises
        ------
        ValueError
            If ``mesh_path`` is not given.
        FileNotFoundError
            If the resolved mesh file does not exist.
        """
        if "mesh_path" not in info_param:
            raise ValueError("ERROR: mesh_path not defined")
        mesh_path = self.root_dir / Path(info_param["mesh_path"]).expanduser()

        if not mesh_path.exists():
            raise FileNotFoundError("mesh_path not found: {}".format(mesh_path))

        comments = info_param.get("comments", "#")
        delimiter = info_param.get("delimiter", None)
        skiprows = info_param.get("skiprows", 0)

        if self.mpirank == 0:
            # ndmin=2 keeps a (rows, columns) shape even for a single-row or
            # single-column file, where loadtxt would otherwise squeeze an axis.
            data = np.loadtxt(mesh_path, comments=comments, delimiter=delimiter,
                              skiprows=skiprows, ndmin=2)
            # old format: index x1 x2 ... -> omit index
            data = data[:, 1:]
        else:
            data = None

        if self.mpisize > 1:
            data = odatse.mpi.comm().bcast(data, root=0)

        self.grid = [[idx, *v] for idx, v in enumerate(data)]

    def _setup_grid(self, info_param):
        """
        Setup the grid based on min, max, and num lists.

        Axis ``i`` gets ``num_list[i]`` evenly spaced values from ``min_list[i]``
        to ``max_list[i]`` (inclusive, via ``np.linspace``). The grid is their
        Cartesian product.

        Parameters
        ----------
        info_param
            Dictionary containing ``min_list``, ``max_list`` and ``num_list``.

        Raises
        ------
        ValueError
            If a list is missing, the lengths differ, or the lists are empty.
        """
        if "min_list" not in info_param:
            raise ValueError("ERROR: algorithm.param.min_list is not defined in the input")
        min_list = np.array(info_param["min_list"], dtype=float)

        if "max_list" not in info_param:
            raise ValueError("ERROR: algorithm.param.max_list is not defined in the input")
        max_list = np.array(info_param["max_list"], dtype=float)

        if "num_list" not in info_param:
            raise ValueError("ERROR: algorithm.param.num_list is not defined in the input")
        num_list = np.array(info_param["num_list"], dtype=int)

        if len(min_list) != len(max_list) or len(min_list) != len(num_list):
            raise ValueError("ERROR: lengths of min_list, max_list, num_list do not match")
        if len(min_list) == 0:
            # np.meshgrid() with no axes cannot be reshaped below; fail clearly instead.
            raise ValueError("ERROR: min_list, max_list, num_list must not be empty")

        xs = [
            np.linspace(mn, mx, num=nm)
            for mn, mx, nm in zip(min_list, max_list, num_list)
        ]

        # indexing='xy' swaps the first two axes of the meshgrid output. After the
        # C-order reshape, x2 varies slowest, then x1, then x3, ..., and the last
        # coordinate varies fastest (2-D: x2 slowest, x1 fastest). This order
        # defines the point indices, so it is kept unchanged.
        points = np.array(np.meshgrid(*xs, indexing='xy')).reshape(len(xs), -1).transpose()
        self.grid = [[idx, *v] for idx, v in enumerate(points)]

    def store_file(self, store_path, *, header=""):
        """
        Store the grid data to a file (written on rank 0 only).

        Each row is written as ``index x1 x2 ...``. This is the layout that
        ``_setup_from_file`` expects (it discards the first column), so a
        stored file can be loaded back with ``from_file`` without losing a
        coordinate.

        Parameters
        ----------
        store_path
            Path to the file where the grid data will be stored.
        header
            Header to be included in the file.
        """
        if self.mpirank == 0:
            np.savetxt(store_path, [[idx, *v] for idx, *v in self.grid], header=header)

    @classmethod
    def from_file(cls, mesh_path):
        """
        Create a MeshGrid object from a file.

        Parameters
        ----------
        mesh_path
            Path to the file containing the grid data (first column: index).

        Returns
        -------
        MeshGrid
            a MeshGrid object.
        """
        return cls(param={"mesh_path": mesh_path})

    @classmethod
    def from_dict(cls, param):
        """
        Create a MeshGrid object from a dictionary of parameters.

        Parameters
        ----------
        param
            Dictionary containing parameters for setting up the grid.

        Returns
        -------
        MeshGrid
            a MeshGrid object.
        """
        return cls(param=param)
if __name__ == "__main__":
    # Demo: build a 5x5x5 grid on the unit cube, write it out, read it back,
    # split it across MPI ranks, and write the reloaded grid again.
    cube_params = {
        'min_list': [0, 0, 0],
        'max_list': [1, 1, 1],
        'num_list': [5, 5, 5],
    }
    generated = MeshGrid.from_dict(cube_params)
    generated.store_file("meshfile.dat", header="sample mesh data")

    reloaded = MeshGrid.from_file("meshfile.dat")
    reloaded.do_split()

    if odatse.mpi.rank() == 0:
        print(reloaded.grid)
    print(odatse.mpi.rank(), reloaded.grid_local)

    reloaded.store_file("meshfile2.dat", header="store again")