# SPDX-License-Identifier: MPL-2.0
#
# ODAT-SE -- an open framework for data analysis
# Copyright (C) 2020- The University of Tokyo
#
# This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
# If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
from abc import ABCMeta, abstractmethod
from enum import IntEnum
import time
import os
import pathlib
import pickle
import shutil
import sys
import copy

import numpy as np

import odatse
import odatse.util.limitation
from odatse import exception, mpi

# for type hints
from pathlib import Path
from typing import List, Optional, TYPE_CHECKING, Dict, Tuple

if TYPE_CHECKING:
    from mpi4py import MPI
[docs]
class AlgorithmStatus(IntEnum):
    """
    Progress marker of an algorithm instance.

    The members are ordered integers, so they can be compared with ``<`` to
    test whether a given stage has been reached.
    """

    # the instance has been constructed
    INIT = 1
    # preparation has finished
    PREPARE = 2
    # the run has finished
    RUN = 3
[docs]
class AlgorithmBase(metaclass=ABCMeta):
    """
    Base class for algorithms, providing common functionality and structure.

    Subclasses must implement ``__init__``, ``_prepare``, ``_run`` and
    ``_post``. The public methods ``prepare``, ``run`` and ``post`` wrap these
    hooks and keep track of the progress in ``status``; ``main`` calls the
    three stages in order and records their wall-clock time.
    """

    mpicomm: Optional["MPI.Comm"]
    mpisize: int
    mpirank: int
    rng: np.random.RandomState
    dimension: int
    label_list: List[str]
    # Defaults to None so that prepare() can report a missing runner with a
    # RuntimeError instead of failing with an AttributeError.
    runner: Optional[odatse.Runner] = None
    root_dir: Path
    output_dir: Path
    proc_dir: Path
    timer: Dict[str, Dict]
    status: AlgorithmStatus = AlgorithmStatus.INIT
    mode: Optional[str] = None

    @abstractmethod
    def __init__(
        self,
        info: odatse.Info,
        runner: Optional[odatse.Runner] = None,
        run_mode: str = "initial"
    ) -> None:
        """
        Initialize the algorithm with the given information and runner.

        Although this method is abstract, it has a body that subclasses can
        delegate to from their own ``__init__``.

        Parameters
        ----------
        info : Info
            Information object containing algorithm and base parameters.
        runner : Runner (optional)
            Optional runner object to execute the algorithm.
        run_mode : str
            Mode in which the algorithm should run. It is stored
            lower-cased in ``self.mode``.

        Raises
        ------
        ValueError
            If the dimension is given neither in ``info.algorithm`` nor in
            ``info.base``, or if the length of ``label_list`` does not match
            the dimension.
        """
        self.mpicomm = mpi.comm()
        self.mpisize = mpi.size()
        self.mpirank = mpi.rank()

        self.timer = {"init": {}, "prepare": {}, "run": {}, "post": {}}
        self.timer["init"]["total"] = 0.0
        self.status = AlgorithmStatus.INIT
        self.mode = run_mode.lower()

        # keep copy of parameters
        self.info = copy.deepcopy(info.algorithm)

        # the algorithm section takes precedence over the base section
        self.dimension = info.algorithm.get("dimension") or info.base.get("dimension")
        if not self.dimension:
            raise ValueError("ERROR: dimension is not defined")

        if "label_list" in info.algorithm:
            label = info.algorithm["label_list"]
            if len(label) != self.dimension:
                raise ValueError(f"ERROR: length of label_list and dimension do not match ({len(label)} != {self.dimension})")
            self.label_list = label
        else:
            # default labels: x1, x2, ...
            self.label_list = [f"x{d+1}" for d in range(self.dimension)]

        # initialize random number generator
        self.__init_rng(info)

        # directories
        self.root_dir = info.base["root_dir"]
        self.output_dir = info.base["output_dir"]
        # every MPI rank works in its own subdirectory of the output directory
        self.proc_dir = self.output_dir / str(self.mpirank)
        self.proc_dir.mkdir(parents=True, exist_ok=True)
        # Some cache of the filesystem may delay making a directory
        # especially when mkdir just after removing the old one
        while not self.proc_dir.is_dir():
            time.sleep(0.1)
        if self.mpisize > 1:
            self.mpicomm.Barrier()

        # checkpointing
        self.checkpoint = info.algorithm.get("checkpoint", False)
        self.checkpoint_file = info.algorithm.get("checkpoint_file", "status.pickle")
        self.checkpoint_steps = info.algorithm.get("checkpoint_steps", 65536*256)  # large number
        self.checkpoint_interval = info.algorithm.get("checkpoint_interval", 86400*360)  # long enough

        # runner
        if runner is not None:
            self.set_runner(runner)

    def __init_rng(self, info: odatse.Info) -> None:
        """
        Initialize the random number generator.

        Without a seed the generator is created unseeded. With a seed, each
        MPI rank uses ``seed + mpirank * seed_delta`` so that the ranks get
        different seeds.

        Parameters
        ----------
        info : Info
            Information object containing algorithm parameters
            (``seed`` and ``seed_delta`` are read from ``info.algorithm``).
        """
        seed = info.algorithm.get("seed", None)
        seed_delta = info.algorithm.get("seed_delta", 314159)

        if seed is None:
            self.rng = np.random.RandomState()
        else:
            self.rng = np.random.RandomState(seed + self.mpirank * seed_delta)

    def set_runner(self, runner: odatse.Runner) -> None:
        """
        Set the runner for the algorithm.

        Parameters
        ----------
        runner : Runner
            Runner object to execute the algorithm.
        """
        self.runner = runner

    def prepare(self) -> None:
        """
        Prepare the algorithm for execution.

        Calls ``_prepare`` and then sets the status to
        ``AlgorithmStatus.PREPARE``.

        Raises
        ------
        RuntimeError
            If no runner has been assigned.
        """
        if self.runner is None:
            msg = "Runner is not assigned"
            raise RuntimeError(msg)
        self._prepare()
        self.status = AlgorithmStatus.PREPARE

    @abstractmethod
    def _prepare(self) -> None:
        """Abstract method to be implemented by subclasses for preparation steps."""
        pass

    def run(self) -> None:
        """
        Run the algorithm.

        The work is done with ``proc_dir`` as the current working directory.
        The previous working directory is restored afterwards, even when the
        run raises an exception. The status is set to ``AlgorithmStatus.RUN``
        only when the run completes.

        Raises
        ------
        RuntimeError
            If ``prepare`` has not been called yet.
        """
        if self.status < AlgorithmStatus.PREPARE:
            msg = "algorithm has not prepared yet"
            raise RuntimeError(msg)
        original_dir = os.getcwd()
        os.chdir(self.proc_dir)
        try:
            self.runner.prepare(self.proc_dir)
            self._run()
            self.runner.post()
        finally:
            # do not leave the process in proc_dir when the run fails
            os.chdir(original_dir)
        self.status = AlgorithmStatus.RUN

    @abstractmethod
    def _run(self) -> None:
        """Abstract method to be implemented by subclasses for running steps."""
        pass

    def post(self) -> Dict:
        """
        Perform post-processing after the algorithm has run.

        ``_post`` is called with ``output_dir`` as the current working
        directory. The previous working directory is restored afterwards,
        even when ``_post`` raises an exception.

        Returns
        -------
        Dict
            Dictionary containing post-processing results.

        Raises
        ------
        RuntimeError
            If ``run`` has not been completed yet.
        """
        if self.status < AlgorithmStatus.RUN:
            msg = "algorithm has not run yet"
            raise RuntimeError(msg)
        original_dir = os.getcwd()
        os.chdir(self.output_dir)
        try:
            result = self._post()
        finally:
            os.chdir(original_dir)
        return result

    @abstractmethod
    def _post(self) -> Dict:
        """Abstract method to be implemented by subclasses for post-processing steps."""
        pass

    def main(self):
        """
        Main method to execute the algorithm.

        Runs ``prepare``, ``run`` and ``post`` in this order, stores the
        elapsed time of each stage in ``self.timer`` and writes the timings
        to ``time.log`` in ``proc_dir``. When more than one MPI process is
        used, the processes are synchronized with a barrier before ``run``
        and before ``post``.

        Returns
        -------
        Dict
            The result of ``post``.
        """
        time_sta = time.perf_counter()
        self.prepare()
        time_end = time.perf_counter()
        self.timer["prepare"]["total"] = time_end - time_sta
        if self.mpisize > 1:
            self.mpicomm.Barrier()

        time_sta = time.perf_counter()
        self.run()
        time_end = time.perf_counter()
        self.timer["run"]["total"] = time_end - time_sta
        print("end of run")
        if self.mpisize > 1:
            self.mpicomm.Barrier()

        time_sta = time.perf_counter()
        result = self.post()
        time_end = time.perf_counter()
        self.timer["post"]["total"] = time_end - time_sta

        self.write_timer(self.proc_dir / "time.log")

        return result

    def write_timer(self, filename: Path):
        """
        Write the timing information to a file.

        For each of the sections ``init``, ``prepare``, ``run`` and ``post``
        the total is written first, followed by the other entries of that
        section.

        Parameters
        ----------
        filename : Path
            Path to the file where timing information will be written.
        """
        with open(filename, "w") as fw:
            fw.write("#in units of seconds\n")

            def output_file(section):
                entries = self.timer[section]
                fw.write("#{}\n total = {}\n".format(section, entries["total"]))
                for key, t in entries.items():
                    if key == "total":
                        continue
                    fw.write(" - {} = {}\n".format(key, t))

            for section in ("init", "prepare", "run", "post"):
                output_file(section)

    def _save_data(self, data, filename="state.pickle", ngen=3) -> None:
        """
        Save data to a file with versioning.

        The data is pickled to ``<filename>.tmp`` first. Then the existing
        generations are shifted (``<filename>.1`` -> ``<filename>.2``, ...),
        the current ``<filename>`` becomes ``<filename>.1``, and finally the
        temporary file is moved to ``<filename>``. If writing the temporary
        file fails, the error is printed and the process exits with status 1.

        Parameters
        ----------
        data
            Data to be saved.
        filename
            Name of the file to save the data (a string or a path-like object).
        ngen : int, default: 3
            Number of generations for versioning. If it is not positive,
            no previous generation is kept.
        """
        # accept path-like objects as well as plain strings
        filename = os.fspath(filename)
        tmp_path = Path(filename + ".tmp")

        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(data, f)
        except Exception as e:
            print("ERROR: {}".format(e))
            sys.exit(1)

        # shift the older generations, starting from the oldest one
        for idx in range(ngen-1, 0, -1):
            fn_from = Path(filename + "." + str(idx))
            fn_to = Path(filename + "." + str(idx+1))
            if fn_from.exists():
                shutil.move(fn_from, fn_to)
        if ngen > 0:
            if Path(filename).exists():
                fn_to = Path(filename + "." + str(1))
                shutil.move(Path(filename), fn_to)
        shutil.move(tmp_path, Path(filename))
        print("save_state: write to {}".format(filename))

    def _load_data(self, filename="state.pickle") -> Dict:
        """
        Load data from a file.

        If the file does not exist, a message is printed and an empty
        dictionary is returned. If the file exists but cannot be read,
        the error is printed and the process exits with status 1.

        Parameters
        ----------
        filename
            Name of the file to load the data from.

        Returns
        -------
        Dict
            Dictionary containing the loaded data.
        """
        path = Path(filename)
        if not path.exists():
            print("ERROR: file {} not exist.".format(filename))
            return {}

        # NOTE: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        try:
            with open(path, "rb") as f:
                data = pickle.load(f)
        except Exception as e:
            print("ERROR: {}".format(e))
            sys.exit(1)
        print("load_state: load from {}".format(filename))
        return data

    def _show_parameters(self):
        """
        Show the parameters of the algorithm.

        The stored copy of the algorithm parameters is flattened and printed
        as ``key: value`` lines on MPI rank 0 only.
        """
        if self.mpirank == 0:
            info = flatten_dict(self.info)
            for k, v in info.items():
                print("{:16s}: {}".format(k, v))

    def _check_parameters(self, param=None):
        """
        Check the parameters of the algorithm against previous parameters.

        Every current parameter is compared with the value of the same key in
        ``param``. On MPI rank 0 a warning is printed for each parameter whose
        value differs, and each current parameter is printed.

        Parameters
        ----------
        param (optional)
            Previous parameters to check against.
        """
        info = flatten_dict(self.info)
        info_prev = flatten_dict(param)

        for k, v in info.items():
            w = info_prev.get(k, None)
            if v != w:
                if self.mpirank == 0:
                    print("WARNING: parameter {} changed from {} to {}".format(k, w, v))
            if self.mpirank == 0:
                print("{:16s}: {}".format(k, v))
# utility
[docs]
def flatten_dict(d, parent_key="", separator="."):
    """
    Turn a nested dictionary into a single-level dictionary.

    The key of a nested entry is built by joining the keys along its path
    with ``separator``. Entries keep the order in which they are visited,
    and an empty or ``None`` input gives an empty dictionary.

    Parameters
    ----------
    d
        Dictionary to flatten.
    parent_key : str, default : ""
        Prefix that is put in front of every key of ``d``.
    separator : str, default : "."
        String placed between the parts of a joined key.

    Returns
    -------
    dict
        Flattened dictionary.
    """
    flat = {}
    if not d:
        return flat

    for name, value in d.items():
        # only prepend the prefix when there is one
        full_key = parent_key + separator + name if parent_key else name
        if isinstance(value, dict):
            # descend into the sub-dictionary with the joined key as prefix
            flat.update(flatten_dict(value, full_key, separator=separator))
        else:
            flat[full_key] = value
    return flat