# Source code for physbo.predictor

# SPDX-License-Identifier: MPL-2.0
# Copyright (C) 2020- The University of Tokyo
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.

import pickle as pickle
import numpy as np
from physbo import gp


class base_predictor(object):
    """
    Base predictor class.

    The model-specific methods (``fit``, ``prepare``, ``delete_stats``,
    ``get_basis``, ``get_post_fmean``, ``get_post_fcov``,
    ``get_post_params``, ``get_post_samples``, ``get_predict_samples``,
    ``get_post_params_samples`` and ``update``) only raise
    ``NotImplementedError`` here and must be overridden by subclasses.
    ``save`` and ``load`` provide pickle-based persistence.
    """

    def __init__(self, config, model=None):
        """
        Parameters
        ----------
        config: set_config object (physbo.misc.set_config)
            Configuration, stored as ``self.config``.
        model: model object, optional
            Model, stored as ``self.model``. If None (default), a
            ``gp.core.model`` is built from
            ``gp.cov.gauss(num_dim=None, ard=False)`` (covariance),
            ``gp.mean.const()`` (mean) and ``gp.lik.gauss()`` (likelihood).
        """
        self.config = config
        self.model = model
        if self.model is None:
            self.model = gp.core.model(
                cov=gp.cov.gauss(num_dim=None, ard=False),
                mean=gp.mean.const(),
                lik=gp.lik.gauss(),
            )

    def fit(self, *args, **kwds):
        """
        Default fit function. This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def prepare(self, *args, **kwds):
        """
        Default prepare function. This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def delete_stats(self, *args, **kwds):
        """
        Default function to delete stored statistics.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_basis(self, *args, **kwds):
        """
        Default function to get the basis.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_post_fmean(self, *args, **kwds):
        """
        Default function to get the posterior mean of the score.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_post_fcov(self, *args, **kwds):
        """
        Default function to get the posterior covariance of the score.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_post_params(self, *args, **kwds):
        """
        Default function to get posterior parameters.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_post_samples(self, *args, **kwds):
        """
        Default function to get posterior samples.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_predict_samples(self, *args, **kwds):
        """
        Default function to get prediction samples.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def get_post_params_samples(self, *args, **kwds):
        """
        Default function to get samples of posterior parameters.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def update(self, *args, **kwds):
        """
        Default function to update variables.
        This function must be overridden in each model.

        Parameters
        ----------
        args
        kwds

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def save(self, file_name):
        """
        Save ``self.__dict__`` to a file with ``pickle.dump``.

        The pickle protocol version is 4.

        Parameters
        ----------
        file_name: str
            Path of the file to write; an existing file is overwritten.
        """
        with open(file_name, "wb") as f:
            pickle.dump(self.__dict__, f, 4)

    def load(self, file_name):
        """
        Restore ``config`` and ``model`` from a file written by ``save``.

        Only the ``"config"`` and ``"model"`` entries of the unpickled
        dictionary are assigned to ``self``; any other saved attributes
        are ignored, and ``self.update`` is not called.

        Parameters
        ----------
        file_name: str
            Path of the file to read.

        Raises
        ------
        KeyError
            If the unpickled dictionary lacks ``"config"`` or ``"model"``.

        Warnings
        --------
        ``pickle.load`` can execute arbitrary code while unpickling;
        only load files from trusted sources.
        """
        with open(file_name, "rb") as f:
            tmp_dict = pickle.load(f)
        self.config = tmp_dict["config"]
        self.model = tmp_dict["model"]