# SPDX-License-Identifier: MPL-2.0
#
# ODAT-SE -- an open framework for data analysis
# Copyright (C) 2020- The University of Tokyo
#
# This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
# If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
import os
import numpy as np
import odatse
import time
# type hints
from pathlib import Path
from typing import Callable, Optional, Dict, Tuple
[docs]
class Solver(odatse.solver.SolverBase):
    """
    Solver class for evaluating functions with given parameters.
    """
    x: np.ndarray
    fx: float
    _func: Optional[Callable[[np.ndarray], float]]
[docs]
    def __init__(self, info: odatse.Info) -> None:
        """
        Initialize the solver.
        Parameters
        ----------
        info: Info
            Information object containing solver configuration.
        """
        super().__init__(info)
        self._name = "function"
        self._func = None
        # for debug purpose
        self.delay = info.solver.get("delay", 0.0) 
[docs]
    def evaluate(self, x: np.ndarray, args: Tuple = (), nprocs: int = 1, nthreads: int = 1) -> float:
        """
        Evaluate the function with given parameters.
        Parameters
        ----------
        x : np.ndarray
            Input array for the function.
        args : Tuple, optional
            Additional arguments for the function.
        nprocs : int, optional
            Number of processes to use.
        nthreads : int, optional
            Number of threads to use.
        Returns
        -------
        float
            Result of the function evaluation.
        """
        self.prepare(x, args)
        cwd = os.getcwd()
        os.chdir(self.work_dir)
        self.run(nprocs, nthreads)
        os.chdir(cwd)
        result = self.get_results()
        return result 
[docs]
    def prepare(self, x: np.ndarray, args = ()) -> None:
        """
        Prepare the solver with the given parameters.
        Parameters
        ----------
        x : np.ndarray
            Input array for the function.
        args : tuple, optional
            Additional arguments for the function.
        """
        self.x = x 
[docs]
    def run(self, nprocs: int = 1, nthreads: int = 1) -> None:
        """
        Run the function evaluation.
        Parameters
        ----------
        nprocs : int, optional
            Number of processes to use.
        nthreads : int, optional
            Number of threads to use.
        Raises
        ------
        RuntimeError
            If the function is not set.
        """
        if self._func is None:
            raise RuntimeError(
                "ERROR: function is not set. Make sure that `set_function` is called."
            )
        self.fx = self._func(self.x)
        # for debug purpose
        if self.delay > 0.0:
            time.sleep(self.delay) 
[docs]
    def get_results(self) -> float:
        """
        Get the results of the function evaluation.
        Returns
        -------
        float
            Result of the function evaluation.
        """
        return self.fx 
[docs]
    def set_function(self, f: Callable[[np.ndarray], float]) -> None:
        """
        Set the function to be evaluated.
        Parameters
        ----------
        f : Callable[[np.ndarray], float]
            Function to be evaluated.
        """
        self._func = f