physbo.gp.predictor module

class physbo.gp.predictor.predictor(config, model=None)[source]

Bases: base_predictor

delete_stats()[source]

Default function for deleting the model's cached statistics. Each model must override this function.

Parameters:
  • args

  • kwds

fit(training, num_basis=None)[source]

Fit the model to the training dataset.

Parameters:
  • training (physbo.variable) -- dataset for training

  • num_basis (int) -- number of basis functions (default: self.config.predict.num_basis)
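The num_basis default amounts to "use the argument if given, otherwise read the configured value". The sketch below illustrates only that fallback. PredictConfig, Config and TinyPredictor are hypothetical stand-ins, not physbo classes, and the default of 100 is an arbitrary placeholder rather than physbo's configured value.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PredictConfig:
    # Hypothetical stand-in for config.predict; 100 is an arbitrary placeholder.
    num_basis: int = 100


@dataclass
class Config:
    # Hypothetical stand-in for the predictor's config object.
    predict: PredictConfig = field(default_factory=PredictConfig)


class TinyPredictor:
    """Hypothetical predictor illustrating the num_basis fallback in fit()."""

    def __init__(self, config: Config):
        self.config = config
        self.num_basis: Optional[int] = None

    def fit(self, training, num_basis: Optional[int] = None) -> int:
        # An explicit argument wins; only None falls back to the config value.
        if num_basis is None:
            num_basis = self.config.predict.num_basis
        self.num_basis = num_basis
        # (Model fitting on `training` would happen here; omitted in this sketch.)
        return num_basis
```

Testing `num_basis is None` rather than the argument's truthiness means an explicit `num_basis=0` is kept as given instead of being replaced by the configured default.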

get_basis(*args, **kwds)[source]
Parameters:
  • args

  • kwds

get_post_fcov(training, test, diag=True)[source]

Calculate the posterior variance-covariance matrix of the model.

Parameters:
  • training (physbo.variable) -- training dataset. Not used if the model has already been trained.

  • test (physbo.variable) -- test inputs

  • diag (bool) -- diagonal flag passed to physbo.exact.get_post_fcov. If True, only the diagonal (the posterior variances) is computed. If False, the full covariance matrix is returned.

Return type:

numpy.ndarray
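For reference, the standard exact Gaussian-process expression for this matrix is shown below. The notation is introduced here for illustration, and it is an assumption, not something confirmed by this page, that the exact-inference routine computes exactly this form. X denotes the training inputs, Z the test inputs, K(·,·) the kernel matrix, σ² the noise variance, and k_i = K(X, z_i). With diag=True, only the diagonal entries on the right are needed.

```latex
\Sigma_{*} = K(Z, Z) - K(Z, X)\,\bigl[K(X, X) + \sigma^{2} I\bigr]^{-1} K(X, Z),
\qquad
(\Sigma_{*})_{ii} = k(z_i, z_i) - \mathbf{k}_i^{\top}\bigl[K(X, X) + \sigma^{2} I\bigr]^{-1}\mathbf{k}_i
```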

get_post_fmean(training, test)[source]

Calculate the posterior mean of the model.

Parameters:
  • training (physbo.variable) -- training dataset. Not used if the model has already been trained.

  • test (physbo.variable) -- test inputs

Return type:

numpy.ndarray
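The notation below is introduced here for illustration: X denotes the training inputs, t the training targets, Z the test inputs, K(·,·) the kernel matrix, σ² the noise variance, and μ a constant prior mean (μ = 0 for a zero-mean prior). With that notation, the standard exact-GP posterior mean is shown below. That physbo's exact inference uses exactly this form is an assumption.

```latex
\boldsymbol{\mu}_{*} = \mu\,\mathbf{1} + K(Z, X)\,\bigl[K(X, X) + \sigma^{2} I\bigr]^{-1}(\mathbf{t} - \mu\,\mathbf{1})
```

The bracketed solve depends only on the training data. It can therefore be computed once and reused for any test set, which fits the training argument being ignored once the model has been trained.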

get_post_params(*args, **kwds)[source]
Parameters:
  • args

  • kwds

get_post_samples(training, test, alpha=1)[source]

Draw samples of the model's mean values from the posterior.

Parameters:
  • training (physbo.variable) -- training dataset. Not used if the model has already been trained.

  • test (physbo.variable) -- inputs (not used)

  • alpha (float) -- tuning parameter for the covariance: the covariance passed to np.random.multivariate_normal is multiplied by alpha**2.

Return type:

numpy.ndarray
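The effect of alpha can be illustrated directly with NumPy. Scaling the covariance by alpha**2 scales each draw's deviation from the posterior mean by alpha, so alpha = 0 returns the mean itself and alpha > 1 widens the spread. The helper draw_post_sample is a hypothetical sketch, not physbo's code. It operates on an already-computed posterior mean vector and covariance matrix, such as the outputs of get_post_fmean and get_post_fcov with diag=False.

```python
import numpy as np


def draw_post_sample(post_fmean, post_fcov, alpha=1.0, rng=None):
    """Draw one joint sample f ~ N(post_fmean, alpha**2 * post_fcov).

    Hypothetical helper: post_fmean is a length-n vector and post_fcov
    an n x n covariance matrix, both assumed to be precomputed.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Multiplying the covariance by alpha**2 scales each draw's
    # deviation from the mean by a factor of alpha.
    return rng.multivariate_normal(post_fmean, post_fcov * alpha**2)
```

For example, `draw_post_sample(mean, cov, alpha=2.0, rng=np.random.default_rng(0))` returns one vector with the same length as `mean`, whose per-component variance is four times that of `cov`.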

get_predict_samples(training, test, N=1)[source]

Draw samples of the model's predicted values.

Parameters:
  • training (physbo.variable) -- training dataset. Not used if the model has already been trained.

  • test (physbo.variable) -- test inputs

  • N (int) -- number of samples (default: 1)

Return type:

numpy.ndarray (N x len(test))
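The documented N x len(test) shape corresponds to N joint draws stacked row by row. The sketch below assumes each row is one joint sample from a Gaussian with a precomputed mean and covariance over the test inputs. draw_predict_samples is a hypothetical helper, and this page does not state whether physbo adds observation noise to these samples.

```python
import numpy as np


def draw_predict_samples(mean, cov, N=1, rng=None):
    """Return an (N, len(mean)) array; row i is one joint draw over all test points.

    Hypothetical helper: mean and cov are assumed to be the predictive
    mean vector and covariance matrix over the test inputs.
    """
    rng = np.random.default_rng() if rng is None else rng
    # size=N stacks N independent joint samples along the first axis,
    # so correlations between test points are preserved within each row.
    return rng.multivariate_normal(mean, cov, size=N)
```

Drawing jointly, rather than sampling each test point independently from its own variance, keeps strongly correlated test points moving together within a single sample.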

prepare(training)[source]

Initialize the model using the training dataset.

Parameters:
  • training (physbo.variable) -- dataset for training
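As an illustration of what initializing from the training data can involve, the hypothetical sketch below precomputes two quantities once: the Cholesky factor L of K(X, X) + σ²I, and the weight vector (K(X, X) + σ²I)^{-1}(t - μ). Every posterior query then reuses them, which is why a training argument can be ignored after preparation. The RBF kernel, the fixed hyperparameters, and the class name TinyExactGP are assumptions made for this example; this is not physbo's implementation.

```python
import numpy as np


def rbf(A, B, length=1.0, amp=1.0):
    """Squared-exponential kernel between row sets A (n x d) and B (m x d)."""
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return amp * np.exp(-0.5 * sq / length**2)


class TinyExactGP:
    """Hypothetical exact GP whose prepare() caches training statistics."""

    def __init__(self, noise=1e-2, mean=0.0):
        self.noise = noise  # observation noise variance sigma^2
        self.mean = mean    # constant prior mean mu
        self.stats = None   # filled in by prepare()

    def prepare(self, X, t):
        K = rbf(X, X) + self.noise * np.eye(len(X))
        L = np.linalg.cholesky(K)  # K + sigma^2 I = L L^T
        # weights = (K + sigma^2 I)^{-1} (t - mu): solve L a = t - mu, then L^T w = a.
        # np.linalg.solve is used for brevity; it does not exploit triangularity.
        weights = np.linalg.solve(L.T, np.linalg.solve(L, t - self.mean))
        self.stats = (X, L, weights)

    def get_post_fmean(self, Z):
        X, _, weights = self.stats
        return self.mean + rbf(Z, X) @ weights

    def get_post_fcov(self, Z, diag=True):
        X, L, _ = self.stats
        V = np.linalg.solve(L, rbf(X, Z))  # V = L^{-1} K(X, Z), shape (n, m)
        if diag:
            # Variances only: k(z, z) - ||v_z||^2 per test point
            # (the prior diagonal is taken from the full matrix for simplicity).
            return np.diag(rbf(Z, Z)) - (V**2).sum(axis=0)
        return rbf(Z, Z) - V.T @ V
```

For larger problems, a dedicated triangular solver such as scipy.linalg.solve_triangular would replace the two np.linalg.solve calls. Computing only the diagonal when diag=True also avoids forming the m x m product V.T @ V.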

update(training, test)[source]

Default function for updating the model's variables. Each model must override this function.

Parameters:
  • args

  • kwds
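delete_stats and update are both hooks that each model is expected to override, and a subclass might implement them as in the sketch below. HypotheticalBasePredictor is a stand-in that raises NotImplementedError; this page does not describe physbo's actual default behavior. In CachingPredictor, the cached "stats" are just the mean of a plain list of floats used in place of a physbo.variable, not real GP statistics.

```python
class HypotheticalBasePredictor:
    """Stand-in for base_predictor: both hooks must be overridden."""

    def delete_stats(self, *args, **kwds):
        raise NotImplementedError

    def update(self, training, test):
        raise NotImplementedError


class CachingPredictor(HypotheticalBasePredictor):
    """Toy subclass that caches a statistic of the training targets."""

    def __init__(self):
        self.stats = None

    def prepare(self, training):
        # Placeholder "statistics": the mean of the training targets.
        self.stats = sum(training) / len(training)

    def delete_stats(self, *args, **kwds):
        # Drop cached statistics so the next query must re-prepare.
        self.stats = None

    def update(self, training, test):
        # Simplest correct strategy: discard the cache and rebuild it
        # from the (possibly enlarged) training data. `test` is unused here.
        self.delete_stats()
        self.prepare(training)
```

Rebuilding from scratch is the simplest correct choice. A real GP could instead extend its cached Cholesky factor incrementally when a point is added, which is cheaper but not shown here.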