import numpy as np
import pickle as pickle
from . import gp
class base_predictor(object):
    """
    Base class that defines the common predictor interface.

    Apart from ``__init__``, ``save`` and ``load``, every method here is a
    placeholder that raises ``NotImplementedError``; concrete predictors
    are expected to override them.
    """

    def __init__(self, config, model=None):
        """
        Store the configuration and the model used by the predictor.

        Parameters
        ----------
        config: set_config object (physbo.misc.set_config)
            Configuration object kept as ``self.config``.
        model: model object
            Model kept as ``self.model``. When ``None``, a default
            ``gp.core.model`` is built with a Gaussian covariance
            (``num_dim=None``, ``ard=False``), a constant mean and a
            Gaussian likelihood.
        """
        self.config = config
        self.model = model
        if self.model is None:
            self.model = gp.core.model(
                cov=gp.cov.gauss(num_dim=None, ard=False),
                mean=gp.mean.const(),
                lik=gp.lik.gauss(),
            )

    def fit(self, *args, **kwds):
        """
        Default fit function.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def prepare(self, *args, **kwds):
        """
        Default prepare function.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def delete_stats(self, *args, **kwds):
        """
        Default function to delete status.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_basis(self, *args, **kwds):
        """
        Default function to get basis.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_post_fmean(self, *args, **kwds):
        """
        Default function to get a mean value of the score.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_post_fcov(self, *args, **kwds):
        """
        Default function to get a covariance of the score.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_post_params(self, *args, **kwds):
        """
        Default function to get parameters.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_post_samples(self, *args, **kwds):
        """
        Default function to get samples.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_predict_samples(self, *args, **kwds):
        """
        Default function to get prediction variables of samples.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def get_post_params_samples(self, *args, **kwds):
        """
        Default function to get parameters of samples.

        This function must be overwritten in each model.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def update(self, *args, **kwds):
        """
        Default function to update variables.

        This function must be overwritten in each model. ``load`` calls it
        with the dictionary read from disk.

        Parameters
        ----------
        args
            Positional arguments (unused in the base class).
        kwds
            Keyword arguments (unused in the base class).

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def save(self, file_name):
        """
        Save ``self.__dict__`` to a file with ``pickle.dump``.

        The protocol version is set as 2.

        Parameters
        ----------
        file_name: str
            A file name to save self.__dict__ object.

        Returns
        -------
        None
        """
        # Pickle streams are bytes: the file must be opened in binary mode,
        # otherwise pickle.dump raises TypeError on Python 3.
        with open(file_name, "wb") as f:
            pickle.dump(self.__dict__, f, 2)

    def load(self, file_name):
        """
        Load variables from a pickle file.

        The unpickled object is passed to ``self.update``.

        Parameters
        ----------
        file_name: str
            A file name to load variables from the file.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
            If ``update`` has not been overridden by the subclass.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # files that come from a trusted source.
        # Binary mode is required to read a pickle stream on Python 3.
        with open(file_name, "rb") as f:
            tmp_dict = pickle.load(f)
        self.update(tmp_dict)