physbo.predictor module

class physbo.predictor.BasePredictor(config, model=None)

Bases: object

Base class for predictors. It defines the common interface that each model-specific predictor implements.
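Most methods below are placeholders that a concrete model is expected to override. The sketch illustrates that pattern with a hypothetical stand-in class `BasePredictorSketch` and a toy subclass `MeanPredictor`. Neither is PHYSBO code, and the assumption that unimplemented defaults raise `NotImplementedError` is ours.

```python
import numpy as np


class BasePredictorSketch:
    """Hypothetical stand-in mirroring the BasePredictor interface (not PHYSBO code)."""

    def __init__(self, config, model=None):
        self.config = config
        self.model = model

    # Assumption: unimplemented defaults raise NotImplementedError.
    def fit(self, *args, **kwds):
        raise NotImplementedError

    def prepare(self, *args, **kwds):
        raise NotImplementedError

    def get_post_fmean(self, *args, **kwds):
        raise NotImplementedError


class MeanPredictor(BasePredictorSketch):
    """Toy model: predicts the mean of the training targets everywhere."""

    def fit(self, X, t):
        # The only "parameter" of this model is the target mean.
        self.mean_ = float(np.mean(t))

    def prepare(self, X, t):
        self.fit(X, t)

    def get_post_fmean(self, X):
        # One prediction per row of X.
        return np.full(X.shape[0], self.mean_)


base = BasePredictorSketch(config=None)
try:
    base.fit()
except NotImplementedError:
    print("BasePredictorSketch.fit must be overridden")

model = MeanPredictor(config=None)
X = np.zeros((3, 2))
model.fit(X, np.array([1.0, 2.0, 3.0]))
print(model.get_post_fmean(X))  # [2. 2. 2.]
```

Keeping the defaults as stubs means calling an unimplemented method fails loudly instead of silently returning nothing.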

delete_stats(*args, **kwds)

Default function to delete the stored statistics (stats). This function must be overridden in each model.

Parameters:
  • args

  • kwds

fit(*args, **kwds)

Default fit function. This function must be overridden in each model.

Parameters:
  • args

  • kwds

get_basis(*args, **kwds)

Default function to get the basis. This function must be overridden in each model.

Parameters:
  • args

  • kwds
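What a "basis" means depends on the model. As one illustration, and an assumption rather than a description of PHYSBO's implementation, a Bayesian linear model can approximate a Gaussian (RBF) kernel with random Fourier features. A hypothetical `get_basis` for such a model would return the feature matrix Φ(X) sketched below; the function name and its parameters are invented for the example.

```python
import numpy as np


def random_fourier_basis(X, num_basis, lengthscale, seed=0):
    """Hypothetical get_basis: random Fourier features approximating an RBF kernel.

    Returns Phi with shape (n_samples, num_basis) such that Phi @ Phi.T
    approximates exp(-||x - x'||^2 / (2 * lengthscale^2)).
    """
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    # Frequencies drawn from the kernel's spectral density N(0, lengthscale^-2 I).
    W = rng.normal(scale=1.0 / lengthscale, size=(num_basis, d))
    # Random phases in [0, 2*pi).
    b = rng.uniform(0.0, 2.0 * np.pi, size=num_basis)
    return np.sqrt(2.0 / num_basis) * np.cos(X @ W.T + b)


X = np.array([[0.0], [0.5], [3.0]])
Phi = random_fourier_basis(X, num_basis=5000, lengthscale=1.0)
approx = Phi @ Phi.T
exact = np.exp(-((X - X.T) ** 2) / 2.0)  # 1-D pairwise RBF kernel
print(np.max(np.abs(approx - exact)))  # small; shrinks as num_basis grows
```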

get_permutation_importance(training, num_permutations=10, comm=None, split_features_parallel=False)

Calculate the permutation importance of each feature for the predictor.

Parameters:
  • training (physbo.variable) -- training dataset. Not used if the model has already been trained.

  • num_permutations (int) -- number of permutations

  • comm (MPI.Comm) -- MPI communicator

  • split_features_parallel (bool) -- If True, split the features among the parallel processes.

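Returns:

  • numpy.ndarray -- importance_mean: mean importance of each feature over the permutations

  • numpy.ndarray -- importance_std: standard deviation of the importance of each feature over the permutations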
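Permutation importance measures how much a model's error grows when one feature column is shuffled, which breaks that feature's link to the target. The sketch computes it serially with NumPy for any `predict` callable, using mean squared error as the score. The function name, the choice of MSE, and the omission of MPI handling are assumptions for illustration; the `comm` and `split_features_parallel` options presumably distribute a loop like this across processes.

```python
import numpy as np


def permutation_importance(predict, X, t, num_permutations=10, seed=0):
    """Return (importance_mean, importance_std), one entry per feature.

    Importance = MSE after shuffling one feature column minus the baseline MSE.
    """
    rng = np.random.default_rng(seed)
    baseline = np.mean((predict(X) - t) ** 2)
    n_features = X.shape[1]
    scores = np.empty((n_features, num_permutations))
    for j in range(n_features):
        for k in range(num_permutations):
            X_perm = X.copy()
            # Shuffle only column j; the other features stay aligned with t.
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            scores[j, k] = np.mean((predict(X_perm) - t) ** 2) - baseline
    return scores.mean(axis=1), scores.std(axis=1)


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
t = 3.0 * X[:, 0]  # the target depends only on feature 0


def predict(X):
    return 3.0 * X[:, 0]


mean, std = permutation_importance(predict, X, t)
print(mean)  # feature 0 is large; feature 1 is exactly 0.0
```

Shuffling an irrelevant feature leaves the predictions unchanged, so its importance is zero.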

get_post_fcov(*args, **kwds)

Default function to get the posterior covariance of the score. This function must be overridden in each model.

Parameters:
  • args

  • kwds

get_post_fmean(*args, **kwds)

Default function to get the posterior mean of the score. This function must be overridden in each model.

Parameters:
  • args

  • kwds

get_post_params(*args, **kwds)

Default function to get the posterior parameters. This function must be overridden in each model.

Parameters:
  • args

  • kwds

get_post_params_samples(*args, **kwds)

Default function to get samples of the posterior parameters. This function must be overridden in each model.

Parameters:
  • args

  • kwds
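In a Bayesian linear model, the posterior "parameters" can be taken to be the mean and covariance of the weight vector, and parameter samples are draws from that Gaussian. The sketch assumes a prior w ~ N(0, α⁻¹I) and Gaussian observation noise with precision β, which gives a posterior covariance (αI + βΦᵀΦ)⁻¹ and mean β·cov·Φᵀt. The priors, the default values, and the function names are illustrative assumptions, not PHYSBO's API.

```python
import numpy as np


def blm_posterior_params(Phi, t, alpha=1.0, beta=100.0):
    """Posterior mean and covariance of the weights for prior N(0, alpha^-1 I)."""
    A = alpha * np.eye(Phi.shape[1]) + beta * Phi.T @ Phi
    cov = np.linalg.inv(A)
    mean = beta * cov @ Phi.T @ t
    return mean, cov


def blm_posterior_params_samples(mean, cov, num_samples, seed=0):
    """Draw weight vectors from N(mean, cov); returns shape (num_samples, dim)."""
    rng = np.random.default_rng(seed)
    return rng.multivariate_normal(mean, cov, size=num_samples)


rng = np.random.default_rng(0)
Phi = rng.normal(size=(50, 3))
w_true = np.array([1.0, -2.0, 0.5])
# Noise standard deviation 0.1 matches beta = 1 / 0.1**2 = 100.
t = Phi @ w_true + rng.normal(scale=0.1, size=50)

mean, cov = blm_posterior_params(Phi, t)
W = blm_posterior_params_samples(mean, cov, num_samples=1000)
print(mean)  # close to w_true
```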

get_post_samples(*args, **kwds)

Default function to get samples from the posterior distribution. This function must be overridden in each model.

Parameters:
  • args

  • kwds

get_predict_samples(*args, **kwds)

Default function to get samples of the predicted values. This function must be overridden in each model.

Parameters:
  • args

  • kwds

load(file_name)

Default function to load variables from a file. The loaded information is applied using the self.update function.

Parameters:

file_name (str) -- Name of the file to load variables from.

prepare(*args, **kwds)

Default prepare function. This function must be overridden in each model.

Parameters:
  • args

  • kwds

save(file_name)

Default function to save information using the pickle.dump function with protocol version 3.

Parameters:

file_name (str) -- Name of the file to which the self.__dict__ object is saved.
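Together, save and load amount to pickling the object's `__dict__` with protocol 3 and handing the unpickled result to `update`. The sketch shows that round trip with the standard `pickle` module. The class `PicklablePredictorSketch` is hypothetical, and the assumption that `update` merges the loaded dictionary into `__dict__` is ours, since the base `update` is meant to be overridden.

```python
import os
import pickle
import tempfile


class PicklablePredictorSketch:
    """Hypothetical class showing a save/load round trip via __dict__."""

    def __init__(self, config=None, model=None):
        self.config = config
        self.model = model

    def save(self, file_name):
        # Pickle all instance attributes with protocol version 3.
        with open(file_name, "wb") as f:
            pickle.dump(self.__dict__, f, protocol=3)

    def load(self, file_name):
        with open(file_name, "rb") as f:
            loaded = pickle.load(f)
        self.update(loaded)

    def update(self, state):
        # Assumption: a concrete update merges the loaded attributes.
        self.__dict__.update(state)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "predictor.dump")
    original = PicklablePredictorSketch(config={"seed": 1}, model=[0.5, 1.5])
    original.save(path)

    restored = PicklablePredictorSketch()
    restored.load(path)
    print(restored.config, restored.model)  # {'seed': 1} [0.5, 1.5]
```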

update(*args, **kwds)

Default function to update variables. This function must be overridden in each model.

Parameters:
  • args

  • kwds