Source code for odatse.solver.function

# SPDX-License-Identifier: MPL-2.0
#
# ODAT-SE -- an open framework for data analysis
# Copyright (C) 2020- The University of Tokyo
#
# This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
# If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

import os
import numpy as np
import odatse
import time

# type hints
from pathlib import Path
from typing import Callable, Optional, Dict, Tuple


class Solver(odatse.solver.SolverBase):
    """
    Solver class for evaluating functions with given parameters.

    The function to evaluate is a plain Python callable registered with
    :meth:`set_function`. A single evaluation goes through
    :meth:`prepare`, :meth:`run` and :meth:`get_results`;
    :meth:`evaluate` chains the three.
    """

    # Most recent input stored by ``prepare``.
    x: np.ndarray
    # Most recent value produced by ``run``.
    fx: float
    # User-supplied callable; ``None`` until ``set_function`` is called.
    _func: Optional[Callable[[np.ndarray], float]]

    def __init__(self, info: odatse.Info) -> None:
        """
        Initialize the solver.

        Parameters
        ----------
        info: Info
            Information object containing solver configuration.
            The optional ``delay`` entry of ``info.solver`` (default 0.0)
            is read and stored as ``self.delay``.
        """
        super().__init__(info)
        self._name = "function"
        self._func = None

        # for debug purpose: artificial sleep (seconds) applied in ``run``
        self.delay = info.solver.get("delay", 0.0)

    def evaluate(
        self,
        x: np.ndarray,
        args: Tuple = (),
        nprocs: int = 1,
        nthreads: int = 1,
    ) -> float:
        """
        Evaluate the function with given parameters.

        The function is run with ``self.work_dir`` as the current working
        directory. The previous working directory is restored afterwards,
        even if the evaluation raises.

        Parameters
        ----------
        x : np.ndarray
            Input array for the function.
        args : Tuple, optional
            Additional arguments for the function.
        nprocs : int, optional
            Number of processes to use.
        nthreads : int, optional
            Number of threads to use.

        Returns
        -------
        float
            Result of the function evaluation.

        Raises
        ------
        RuntimeError
            If the function is not set (propagated from ``run``).
        """
        self.prepare(x, args)

        # NOTE(review): ``work_dir`` is not defined in this class; it is
        # presumably provided by ``SolverBase`` -- confirm.
        cwd = os.getcwd()
        os.chdir(self.work_dir)
        try:
            self.run(nprocs, nthreads)
        finally:
            # Always return to the caller's directory; otherwise a failing
            # evaluation would leave the whole process inside ``work_dir``.
            os.chdir(cwd)

        return self.get_results()

    def prepare(self, x: np.ndarray, args: Tuple = ()) -> None:
        """
        Prepare the solver with the given parameters.

        Only ``x`` is stored; ``args`` is accepted but not used here.

        Parameters
        ----------
        x : np.ndarray
            Input array for the function.
        args : tuple, optional
            Additional arguments for the function.
        """
        self.x = x

    def run(self, nprocs: int = 1, nthreads: int = 1) -> None:
        """
        Run the function evaluation.

        Calls the registered function on the input stored by ``prepare``
        and keeps the result in ``self.fx``. ``nprocs`` and ``nthreads``
        are accepted but not used by this solver.

        Parameters
        ----------
        nprocs : int, optional
            Number of processes to use.
        nthreads : int, optional
            Number of threads to use.

        Raises
        ------
        RuntimeError
            If the function is not set.
        """
        if self._func is None:
            raise RuntimeError(
                "ERROR: function is not set. Make sure that `set_function` is called."
            )

        self.fx = self._func(self.x)

        # for debug purpose: emulate an expensive evaluation
        if self.delay > 0.0:
            time.sleep(self.delay)

    def get_results(self) -> float:
        """
        Get the results of the function evaluation.

        Returns
        -------
        float
            Result of the function evaluation.
        """
        return self.fx

    def set_function(self, f: Callable[[np.ndarray], float]) -> None:
        """
        Set the function to be evaluated.

        Parameters
        ----------
        f : Callable[[np.ndarray], float]
            Function to be evaluated.
        """
        self._func = f