{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Gaussian Process\n", "\n", "PHYSBO performs Bayesian optimization by running Gaussian process regression internally.\n", "\n", "It can therefore also be used for Gaussian process regression on its own: given training data, you can fit a Gaussian process model and then use the trained model to predict values for test data.\n", "\n", "This tutorial walks through that procedure.\n", "\n", "\n", "## Preparing the search candidate data\n", "\n", "As an example, this tutorial uses the problem of searching for stable interface structures of Cu. In practice, the structural relaxation calculation that evaluates the objective function takes on the order of several hours per evaluation; here we use values that have already been computed. For details of the problem setting, see the following reference:\n", "S. Kiyohara, H. Oda, K. Tsuda and T. Mizoguchi, “Acceleration of stable interface structure searching using a kriging approach”, Jpn. J. Appl. Phys. 55, 045502 (2016).\n", "\n", "Save the dataset file [s5-210.csv](https://raw.githubusercontent.com/issp-center-dev/PHYSBO/master/examples/grain_bound/data/s5-210.csv) and load it as follows." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:11:41.987250Z", "start_time": "2020-12-04T06:11:41.537168Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cythonized version of physbo is used\n" ] } ], "source": [ "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "import physbo\n", "\n", "\n", "def load_data():\n", "    A = np.asarray(np.loadtxt('s5-210.csv', skiprows=1, delimiter=','))\n", "    # columns 0-2: input parameters (features); column 3: objective value\n", "    X = A[:, 0:3]\n", "    # negate the objective: PHYSBO searches for maxima, while stable structures have the lowest values\n", "    t = -A[:, 3]\n", "    return X, t\n", "\n", "X, t = load_data()\n", "# standardize each feature (zero mean, unit standard deviation)\n", "X = physbo.misc.centering(X)\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Defining the training data\n", "\n", "We use a randomly selected 10% of the data as training data and a different, randomly selected 10% (disjoint from the training data) as test data." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:11:51.077070Z", "start_time": "2020-12-04T06:11:51.072211Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ntrain = 1798\n", "Ntest = 1798\n" ] } ], "source": [ "N = len(t)\n", "Ntrain = int(N*0.1)\n", "Ntest = 
min(int(N*0.1), N-Ntrain)\n", "\n", "# random permutation of all indices: the first Ntrain go to training, the next Ntest to testing\n", "id_all = np.random.choice(N, N, replace=False)\n", "id_train = id_all[0:Ntrain]\n", "id_test = id_all[Ntrain:Ntrain+Ntest]\n", "\n", "X_train = X[id_train]\n", "X_test = X[id_test]\n", "\n", "t_train = t[id_train]\n", "t_test = t[id_test]\n", "\n", "print(\"Ntrain =\", Ntrain)\n", "print(\"Ntest =\", Ntest)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Training a Gaussian process and making predictions\n", "\n", "We train a Gaussian process and predict the test data in the following steps:\n", "\n", "1. Create a Gaussian process model.\n", "\n", "2. Train the model using X_train (the parameters of the training data) and t_train (the objective function values of the training data).\n", "\n", "3. Use the trained model to predict the test data (X_test).\n", "\n", "Define the covariance (Gaussian kernel)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:11:55.403677Z", "start_time": "2020-12-04T06:11:55.399915Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# ard=False: a single length scale shared by all input dimensions\n", "cov = physbo.gp.cov.Gauss(X_train.shape[1], ard=False)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Define the mean (constant)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:11:56.279543Z", "start_time": "2020-12-04T06:11:56.277082Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "mean = physbo.gp.mean.Const()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Define the likelihood function (Gaussian)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:11:57.077507Z", "start_time": "2020-12-04T06:11:57.075581Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "lik = physbo.gp.lik.Gauss()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Create the Gaussian process model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:11:57.832602Z", "start_time": "2020-12-04T06:11:57.828902Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------------\n", "WARNING: physbo.set_config is deprecated and will be removed in the future.\n", " Use physbo.SetConfig instead.\n", "--------------------------------------------------------------------------------\n" ] } ], "source": [ "gp = physbo.gp.Model(lik=lik, mean=mean, cov=cov)\n", "# default configuration, passed to gp.fit below (set_config is deprecated, see the warning)\n", "config = physbo.misc.set_config()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Train the Gaussian process model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.218792Z", "start_time": "2020-12-04T06:11:58.261609Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Start the initial hyper parameter searching ...\n", "Done\n", "\n", "Start the hyper parameter learning ...\n", "0 -th epoch marginal likelihood 17286.344814363532\n", "50 -th epoch marginal likelihood 3941.141158356132\n", "100 -th epoch marginal likelihood 1224.4325203178405\n", "150 -th epoch marginal likelihood 78.87839479900322\n", "200 -th epoch marginal likelihood -701.4985461671986\n", "250 -th epoch marginal likelihood -1344.3964878511624\n", "300 -th epoch marginal likelihood -1676.1549909534815\n", "350 -th epoch marginal likelihood -1815.6458851286318\n", "400 -th epoch marginal likelihood -1917.6558525403716\n", "450 -th epoch marginal likelihood -1968.1828578348936\n", "500 -th epoch marginal likelihood -1998.5771581916174\n", "Done\n", "\n" ] } ], "source": [ "gp.fit(X_train, t_train, config)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Print the parameters of the trained Gaussian process" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.227479Z", "start_time": "2020-12-04T06:12:58.221821Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "likelihood parameter = [-2.76166841]\n", "mean parameter in GP prior: [-1.05925247]\n", "covariance parameter in GP prior: [-0.68395227 -2.67155333]\n", "\n", "\n" ] } ], "source": [ "gp.print_params()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Compute the posterior mean (prediction) and variance for the test data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.605713Z", "start_time": "2020-12-04T06:12:58.244883Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# precompute quantities that depend only on the training data\n", "gp.prepare(X_train, t_train)\n", "# posterior mean at the test points\n", "fmean = gp.get_post_fmean(X_train, X_test)\n", "# posterior variance at each test point\n", "fcov = gp.get_post_fcov(X_train, X_test)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Predicted values" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.618218Z", "start_time": "2020-12-04T06:12:58.607794Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([-1.04740175, -1.09605145, -1.0802498 , ..., -0.99629321,\n", " -1.16971609, -1.13126194])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fmean" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Variances" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.628483Z", "start_time": "2020-12-04T06:12:58.622345Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([0.00043518, 0.00045412, 0.00055981, ..., 0.0003406 , 0.00037195,\n", " 0.00053075])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fcov" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Mean squared error of the predictions"
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.636081Z", "start_time": "2020-12-04T06:12:58.631461Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "0.011857078799698457" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean((fmean-t_test)**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "得られた回帰モデルについて、どの特徴量がどのくらい重要なのかを調べるための簡単な指標として permutation importance (PI)があります。\n", "PHYSBOでは `get_permutation_importance` 関数で計算できます。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "pi_mean, pi_std = gp.get_permutation_importance(X_train, t_train, n_perm=100)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxYAAAHqCAYAAACZcdjsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8ekN5oAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnBElEQVR4nO3dCZAV5YEH8I/hGC65VFAMChG8QkQWlVLxikS0qFJqEzUJKrpZlYhxFXUNFYHFrMFADCaG4KoR1qwHmg3B9cBCc2CyqAkRRKOsm0VUBI0aRYmc01tfp97UDIwy8M3wjvn9qp4zr1+/7m++1zb97+94rbIsywIAAECCqpQ3AwAARIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJGsTylhNTU144403wh577BFatWpV7OIAAEBFid+l/cEHH4TevXuHqqqqyg0WMVT06dOn2MUAAICK9tprr4VPfepTlRssYktF4Q/t0qVLsYsDAAAVZd26dfmN/MJ1d8UGi0L3pxgqBAsAAGgejRl2YPA2AACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZG1CBRg4+bFQVd2x2MUAyswrN44sdhEAoGJosQAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQGUEi5kzZ4a+ffuG9u3bh6FDh4Znnnmm2EUCAADKKVjMnTs3jB8/Pkye
PDn84Q9/CIMGDQojRowIb731VrGLBgAAlEuw+N73vhcuuuiicOGFF4bDDjss3HrrraFjx47hzjvvLHbRAACARmoTimjTpk1hyZIlYcKECbXLqqqqwvDhw8PixYuLWTRgGzWbNoRKs379+lBJOnXqVOwiANCCFTVYvP3222Hr1q2hV69e9ZbH5y+99NJ262/cuDF/FKxbt263lBMI4bUZXwyVpvOMUFGyLCt2EQBowYreFWpnTJ06NXTt2rX20adPn2IXCQAAKHaLxV577RVat24d3nzzzXrL4/N99tlnu/Vjl6k40Ltui4VwAbtHnyt/GirNi986rdhFAICKUdRg0a5duzBkyJDwxBNPhFGjRuXLampq8ueXXXbZdutXV1fnD2D3q2rXPlQaYxIAoEKCRRRbIMaMGROOPPLIcPTRR4ebb745H1AZZ4kCAADKQ9GDxTnnnBP+/Oc/h0mTJoW1a9eGI444IixYsGC7Ad0AAEDpKnqwiGK3p4a6PgEAAOWhrGaFAgAASpNgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIFmbUAGenzIidOnSpdjFAACAFkuLBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAsjahAgyc/Fioqu5Y7GIAAOzQKzeOLHYRoFlosQAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAAKJ1g8d577zXVpgAAgJYQLL7zne+EuXPn1j4/++yzw5577hn222+/sGzZsqYsHwAAUKnB4tZbbw19+vTJf1+4cGH+ePTRR8Ppp58errnmmqYuIwAAUOLa7Mqb1q5dWxssHnroobzF4tRTTw19+/YNQ4cObeoyAgAAldhi0b179/Daa6/lvy9YsCAMHz48/z3LsrB169amLSEAAFCZLRZ///d/H77yla+EAQMGhHfeeSfvAhU9++yzoX///k1dRgAAoBKDxYwZM/JuT7HVYtq0aaFz58758jVr1oRLL720qcsIAABUYrBo27ZtuPrqq7dbfuWVVzZFmQAAgJbyPRY/+clPwrBhw0Lv3r3DqlWr8mU333xzmD9/flOWDwAAqNRgMWvWrDB+/Ph8bEX8YrzCgO1u3brl4QIAAGhZdilY3HLLLeH2228P3/zmN0Pr1q1rlx955JFh+fLlTVk+AACgUoPFypUrw+DBg7dbXl1dHdavX98U5QIAACo9WPTr1y8sXbp0u+XxOy0OPfTQpigXAABQ6bNCxfEV48aNCxs2bMi/FO+ZZ54J9957b5g6dWq44447mr6UAABA5QWLf/zHfwwdOnQI1113XfjrX/+af1lenB3q+9//fvjSl77U9KUEAAAqK1hs2bIl3HPP
PWHEiBFh9OjRebD48MMPQ8+ePZunhAAAQOWNsWjTpk0YO3Zs3g0q6tixo1ABAAAt3C4N3j766KPDs88+2/SlAQAAWs4Yi0svvTRcddVV4fXXXw9DhgwJnTp1qvf64Ycf3lTlAwAAKjVYFAZoX3755bXLWrVqlc8QFX8WvokbAABoGdrs6hfkAQAAJAWLAw44YFfeBgAAVKhdChZ33XXXJ75+/vnn72p5AACAlhIs/umf/qne882bN+ffZ9GuXbt8+tnGBotFixaF6dOnhyVLloQ1a9aEefPmhVGjRu1KkQAAgHKbbvYvf/lLvUf8grwVK1aEYcOGhXvvvbfR21m/fn0YNGhQmDlz5q4UAwAAKOcWi4YMGDAg3HjjjeHcc88NL730UqPec/rpp+cPAKC81Wz62xfn0rgbq+zYtl9nQAsKFvnG2rQJb7zxRmguGzduzB8F69ata7Z9AQCN99qMLxa7CGWj84xil6A8xK8xoAUEiwcffHC7Dz6OkfjhD38YjjvuuNBcpk6dGqZMmdJs2wcAAHZNq2wX4mBVVf2hGfFL8fbee+/wuc99Ltx0001h33333fmCtGq1w8HbDbVY9OnTJ/S54v5QVd1xp/cJADQNXaEa78VvnVbsIpQFXaFKQ7ze7tq1a3j//fdDly5dmr7FoqamJhRDdXV1/gAASktVu/bFLkLZcMFMpdqlWaGuv/76fHrZbX300Uf5awAAQMuyS8EijnOIU8xuK4aNnRkDEbexdOnS/BGtXLky//3VV1/dlWIBAABFsktdoeKwjDgmYlvLli0LPXr0aPR2fv/734eTTz659vn48ePzn2PGjAlz5szZlaIBAAClHiy6d++eB4r4OOigg+qFi61bt+YtEGPHjm309k466SRTiQEAQEsLFjfffHMeBP7hH/4h7/IUR4gXtGvXLvTt2zccc8wxzVFOAACgUoJF7KIU9evXLxx77LGhbdu2zVUuAACg0sdYnHjiibW/b9iwIWzatKne6zua4xYAAKgsuzQrVJz96bLLLgs9e/bM52KOYy/qPgAAgJZll4LFNddcE37xi1+EWbNm5V9Yd8cdd+RjLnr37h3uuuuupi8lAABQeV2h/uu//isPEHFWpwsvvDAcf/zxoX///uGAAw4Id999dxg9enTTlxQAAKisFot33303fPrTn64dTxGfR8OGDQuLFi1q2hICAACVGSxiqIjfkh0dcsgh4f77769tyejWrVvTlhAAAKjMYBG7P8Vv2Y6+8Y1vhJkzZ4b27duHK6+8Mh9/AQAAtCy7NMYiBoiC4cOHh5deeiksWbIkH2dx+OGHN2X5AACASg0WdcXvsYiDtuMDAABomXapK9TWrVvDt771rbDffvuFzp07h//7v//Ll0+cODH8+Mc/buoyAgAAlRgsbrjhhjBnzpwwbdq00K5du9rlAwcOzL/TAgAAaFl2KVjE77C47bbb8u+raN26de3yQYMG5eMtAACAlmWXgsXq1avzgdrbqqmpCZs3b26KcgEAAJUeLA477LDw5JNPbrf8pz/9aRg8eHBTlAsAAKj0WaEmTZoUxowZk7dcxFaKn/3sZ2HFihV5F6mHHnqo6UsJAABUTotFnP0py7Jw5pln5t+y/fjjj4dOnTrlQePFF1/Ml33+859vvtICAADl32IxYMCAsGbNmtCzZ89w/PHHhx49eoTly5eHXr16NV8JAQCAymqxiK0VdT366KNh/fr1TV0mAACgJQze/rigAQAAtEw7FSxatWqVP7ZdBgAAtGxtdraF4oILLgjV1dX58w0bNoSxY8fmA7jrirNEAQAALcdOBYs4xWxd5557blOXBwAAqPRgMXv27OYrCQAA0DIHbwMAAESCBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkW
AABAMsECAABIJlgAAADJBAsAACCZYAEAACRrEyrA81NGhC5duhS7GAAA0GJpsQAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQLI2oQIMnPxYqKruWOxiAADwMV65cWSxi0Az02IBAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIBkggUAAJBMsAAAAJIJFgAAQDLBAgAASCZYAAAAyQQLAAAgmWABAAAkEywAAIDyDhZTp04NRx11VNhjjz1Cz549w6hRo8KKFSuKWSQAAKDcgsWvf/3rMG7cuPDUU0+FhQsXhs2bN4dTTz01rF+/vpjFAgAAdlKbUEQLFiyo93zOnDl5y8WSJUvCCSecULRyAQAAZRQstvX+++/nP3v06FHsogAAfKyaTRuKXYSyo0fKzunUqVMoNyUTLGpqasIVV1wRjjvuuDBw4MAG19m4cWP+KFi3bt1uLCEAwN+8NuOLxS5C2ek8o9glKC9ZloVyUzKzQsWxFs8//3y47777PnGwd9euXWsfffr02a1lBAAAGtYqK4E4dNlll4X58+eHRYsWhX79+n3seg21WMRw0eeK+0NVdcfdVFoAoKXTFWrnvfit04pdhLLSqUS6QsXr7XhDPw5Z6NKlS+l2hYqZ5utf/3qYN29e+NWvfvWJoSKqrq7OHwAAxVTVrn2xi1B2SuVCmebTptjdn+655568tSJ+l8XatWvz5TEVdejQoZhFAwAAymWMxaxZs/JmlZNOOinsu+++tY+5c+cWs1gAAMBOKnpXKAAAoPyVzKxQAABA+RIsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwAAIJlgAQAAJGsTKsDzU0aELl26FLsYAADQYmmxAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAsjahjGVZlv9ct25dsYsCAAAVp3CdXbjurthg8c477+Q/+/TpU+yiAABAxfrggw9C165dKzdY9OjRI//56quv7vAPZeeSaQxrr732WujSpUuxi1Mx1GvzUK/NQ702D/XaPNRr81CvTW9dGdZpbKmIoaJ37947XLesg0VV1d+GiMRQUS4fTjmJdapem556bR7qtXmo1+ahXpuHem0e6rXpdSmzOm3sDXyDtwEAgGSCBQAA0LKDRXV1dZg8eXL+k6ajXpuH
em0e6rV5qNfmoV6bh3ptHuq16VVXeJ22yhozdxQAAECltlgAAAClQbAAAACSCRYAAEB5BYuZM2eGvn37hvbt24ehQ4eGZ5555hPXf+CBB8IhhxySr//Zz342PPLII/Vej8NDJk2aFPbdd9/QoUOHMHz48PDyyy/XW+fdd98No0ePzucK7tatW/jqV78aPvzww3rrPPfcc+H444/P9xO/tGTatGmhnJRivb7yyiuhVatW2z2eeuqpUC6KUa833HBDOPbYY0PHjh3zem1I/ELIkSNH5uv07NkzXHPNNWHLli2hXJRqvTZ0vN53332hXOzueo3/j8f/7/v165e/fuCBB+YDEjdt2lRvO86vTV+v5X5+LcY54Iwzzgj7779/vo243nnnnRfeeOONeus4Vpu+Xsv9WC1WvRZs3LgxHHHEEXmdLV26NJTF8ZrtJvfdd1/Wrl277M4778xeeOGF7KKLLsq6deuWvfnmmw2u/9vf/jZr3bp1Nm3atOyPf/xjdt1112Vt27bNli9fXrvOjTfemHXt2jX7+c9/ni1btiw744wzsn79+mUfffRR7TqnnXZaNmjQoOypp57Knnzyyax///7Zl7/85drX33///axXr17Z6NGjs+effz679957sw4dOmT/9m//lpWDUq3XlStXxkkBsscffzxbs2ZN7WPTpk1ZOShWvU6aNCn73ve+l40fPz5fd1tbtmzJBg4cmA0fPjx79tlns0ceeSTba6+9sgkTJmTloFTrNYrH6+zZs+sdr3W3UcqKUa+PPvpodsEFF2SPPfZY9qc//SmbP39+1rNnz+yqq66q3Ybza/PUazmfX4t1Doj//y9evDh75ZVX8m0ec8wx+aPAsdo89VrOx2ox67Xg8ssvz04//fS8DuO/+eVwvO62YHH00Udn48aNq32+devWrHfv3tnUqVMbXP/ss8/ORo4cWW/Z0KFDs0suuST/vaamJttnn32y6dOn177+3nvvZdXV1XkFR/FDjR/G7373u9p14km7VatW2erVq/PnP/rRj7Lu3btnGzdurF3n2muvzQ4++OCsHJRqvRZOJnX/RygnxajXuuIFbkMXwDFIVFVVZWvXrq1dNmvWrKxLly71juFSVar1GsXjdd68eVk5Kna9FsR/TOM/kAXOr81Tr+V8fi2VOo2BLf6bVbjAdaw2T72W87Fa7Hp95JFHskMOOSQPNNvWYSkfr7ulK1Rswl2yZEne3FNQVVWVP1+8eHGD74nL664fjRgxonb9lStXhrVr19ZbJ37deGymKqwTf8ZuD0ceeWTtOnH9uO+nn366dp0TTjghtGvXrt5+VqxYEf7yl7+EUlbK9Vq3mTR21xk2bFh48MEHQzkoVr02Rlw3Nq326tWr3n7WrVsXXnjhhVDKSrleC8aNGxf22muvcPTRR4c777wzb7IudaVUr++//37o0aNHvf04vzZ9vZbr+bVU6jR25b377rvz7pFt27at3Y9jtenrtVyP1WLX65tvvhkuuuii8JOf/CTvwtvQfkr1eN0tweLtt98OW7durXcxFMXnsYIbEpd/0vqFnztaJx7IdbVp0yY/Qdddp6Ft1N1HqSrleu3cuXO46aab8r6GDz/8cH4yGTVqVFmcUIpVr43heG2eeo2uv/76cP/994eFCxeGL3zhC+HSSy8Nt9xySyh1pVKv//u//5vX1yWXXLLD/dTdR6kq5Xot1/Nrsev02muvDZ06dQp77rlnPlZt/vz5O9xP3X2UqlKu13I9VotZr1mWhQsuuCCMHTu23g3cxuyn7j6KpU1R907Find9x48fX/v8qKOOygd0TZ8+Pb9zAaVm4sSJtb8PHjw4rF+/Pj9eL7/88qKWqxysXr06nHbaaeGss87K77LRvPXq/Lpr4kQXcWD8qlWrwpQpU8L5558fHnrooXxgLM1Tr47VnRdvJHzwwQdhwoQJoRztlhaLeGC1bt06b9qpKz7fZ599GnxPXP5J6xd+7midt956q97rcfac
2FxXd52GtlF3H6WqlOu1IbGpL959K3XFqtfGcLw2T71+3PH6+uuv57NylLJi12u8SDj55JPz7g+33XZbo/ZTdx+lqpTrtVzPr8Wu07j/gw46KHz+85/PZ3yLs/UUZidyrDZPvZbrsVrMev3FL36Rd3Wqrq7Oe4P0798/Xx5bL8aMGfOJ+6m7j4oOFrEP2JAhQ8ITTzxRu6ympiZ/fswxxzT4nri87vpR7KJQWD9OxRcrr+46sZ957ONfWCf+fO+99/I+cgXxA4v7jgd2YZ1FixaFzZs319vPwQcfHLp37x5KWSnXa0PiVGlxerVSV6x6bYy47vLly+sFu7ifOO3vYYcdFkpZKdfrxx2v8RwQT+6lrJj1Gu+on3TSSfn+Z8+enfc/3nY/zq9NX6/len4tpXNA3G9UuHHgWG2eei3XY7WY9fqDH/wgLFu2LK+n+ChMVzt37tx86vSSP15355RdcdT7nDlz8lmFLr744nzKrsLsNuedd172jW98o96UXW3atMm++93vZi+++GI2efLkBqfsituIsxA899xz2ZlnntngtKiDBw/Onn766ew3v/lNNmDAgHrTosbR+HHKrrj/OGVXLGfHjh1LYsqucq7XWJ577rkn30d83HDDDflsRnHKtnJQrHpdtWpVPvPDlClTss6dO+e/x8cHH3xQb7rZU089NVu6dGm2YMGCbO+99y6r6WZLsV4ffPDB7Pbbb8+3+/LLL+czbsTzQJymthwUo15ff/31fJrpU045Jf+97lSSBc6vzVOv5Xx+LUadxmnRb7nllvz/+Tgt6hNPPJEde+yx2YEHHpht2LAhX8ex2jz1Ws7HajH/zaqroZm1Svl43W3BIooH4P7775/PCRyn8IoHZcGJJ56YjRkzpt76999/f3bQQQfl63/mM5/JHn744Xqvx2m7Jk6cmFdu/ODjiXjFihX11nnnnXfyC954MRGn5LzwwgtrLyYK4jzCw4YNy7ex33775R96OSnFeo3/Ex566KH5gR5fj+V64IEHsnJSjHqN24wnkG0fv/zlL2vXiSfwOK91nLM6fodFnN9+8+bNWbkoxXqN0yUfccQR+fHcqVOn/Dtabr311nxqwXKxu+s1Tt3bUJ1ue7/K+bXp67Xcz6+7u07jxdvJJ5+c9ejRI3+9b9++2dixY/PgVpdjtenrtdyP1WL9m1XXx03ZW6rHa6v4n+K2mQAAAOVut4yxAAAAKptgAQAAJBMsAACAZIIFAACQTLAAAACSCRYAAEAywQIAAEgmWAAAAMkECwB2qFWrVuHnP/95sYsBQAkTLAAa4YILLsgvruOjXbt2oX///uH6668PW7ZsCaVa3lGjRu30+/7lX/4lHHHEEdstX7NmTTj99NNDc5ozZ07o1q1bKGV9+/YNN998c7GLAVCS2hS7AADl4rTTTguzZ88OGzduDI888kgYN25caNu2bZgwYcJOb2vr1q15SKmqKo/7O/vss09oyTZt2pQHSgA+Xnn8iwZQAqqrq/ML7AMOOCB87WtfC8OHDw8PPvhg/loMG1dffXXYb7/9QqdOncLQoUPDr371q+3uxsf1DzvssHxbr776an4H/F//9V/D+eefHzp37pxvO67z5z//OZx55pn5ssMPPzz8/ve//8RWhXgXPW6r8Pq///u/h/nz59e2shTKcu2114aDDjoodOzYMXz6058OEydODJs3b64t45QpU8KyZctq3xeXNdQVavny5eFzn/tc6NChQ9hzzz3DxRdfHD788MPtWky++93vhn333TdfJwaxwr4ao/B33nnnnWH//ffP6+LSSy/NQ9m0adPyz6Jnz57hhhtuqPe+WNZZs2blLSyxfPHv/OlPf1pvncaWP267d+/e4eCDDw4nnXRSWLVqVbjyyitr6yd65513wpe//OX8s4/1+tnPfjbce++99fYX33v55ZeHf/7nfw49evTIyx7/vrree++9cMkll4RevXqF9u3bh4EDB4aHHnqo9vXf/OY34fjj
j8/L3KdPn3x769evb3R9AjQ3wQJgF8ULvHgnO7rsssvC4sWLw3333Reee+65cNZZZ+UtHC+//HLt+n/961/Dd77znXDHHXeEF154Ib8ojmbMmBGOO+648Oyzz4aRI0eG8847Lw8a5557bvjDH/4QDjzwwPx5lmWNKlcMOGeffXa+/9iFKT6OPfbY/LU99tgjDwt//OMfw/e///1w++235/uPzjnnnHDVVVeFz3zmM7Xvi8u2FS9mR4wYEbp37x5+97vfhQceeCA8/vjjeR3U9ctf/jL86U9/yn/GoBP3WwgqjRXf/+ijj4YFCxbkF+s//vGP8zp6/fXXw69//eu8Pq+77rrw9NNP13tfDExf+MIX8pA0evTo8KUvfSm8+OKLO1X+J554IqxYsSIsXLgwv8D/2c9+Fj71qU/lXeAK9RNt2LAhDBkyJDz88MPh+eefz0NK/AyfeeaZetuLdRBDZyxrDEZxO3HbUU1NTR6Efvvb34b/+I//yD+fG2+8MbRu3bq2HuLnGf+meHzNnTs3DxrblhmgqDIAdmjMmDHZmWeemf9eU1OTLVy4MKuurs6uvvrqbNWqVVnr1q2z1atX13vPKaeckk2YMCH/ffbs2TEVZEuXLq23zgEHHJCde+65tc/XrFmTrzdx4sTaZYsXL86XxdeiyZMnZ4MGDaq3nRkzZuTbaqi8n2T69OnZkCFDap83tO0o7n/evHn577fddlvWvXv37MMPP6x9/eGHH86qqqqytWvX1u4/lmfLli2165x11lnZOeec87FliXXUtWvXemXp2LFjtm7dutplI0aMyPr27Ztt3bq1dtnBBx+cTZ06tV5Zx44dW2/bQ4cOzb72ta/tVPl79eqVbdy4sd524t8U63pHRo4cmV111VW1z0888cRs2LBh9dY56qijsmuvvTb//bHHHsv3v2LFiga399WvfjW7+OKL6y178skn8/d89NFHOywPwO5gjAVAI8W71rE7TuzOE+8wf+UrX8m7s8RuRrF7TuxiVFfsHhW72RTEPvqxW9O26i6L3WCi2J1m22VvvfVW8liHeKf7Bz/4QX4HPHb9iYPPu3TpslPbiHf+Bw0alN99L4gtLrFO4h3+Qnljy0fhjnsUu0TFLkg7I3bviq0sBXHbcZt1x6bEZbFu6jrmmGO2e7506dKdKn/8DBozriJ+9t/+9rfD/fffH1avXp23YsXPPnaLqmvbzz7WR6HcsWyxNWTbY6ggtrzEloq77767dlnMULHMK1euDIceeugOywnQ3AQLgEY6+eST87778WIz9rtv0+Zvp9B4gR4vdpcsWVLvQjqKQaRu16lCv/y64gDwgsLrDS2LF5FRvKjetltUY8YuxK5asVtQHEcRuwJ17do177p10003heZQ928o/B2FvyFlG02x3caoGzw+yfTp0/NuZXGcSwwj8X1XXHFFbTe5gk8qdzw2Pkk8xuL4iziuYltx/AlAKRAsABopXjDGaWa3NXjw4Pyudbz7HAfXNre99947rF27Ng8XhdBRuBtfEMNPLFNd//3f/50PDv/mN79ZuywORt7R+7YV747HsRJxrELh4juODYiBJw5yLgVPPfVUPi6l7vP4OaWWv6H6ie+NA+3jmJgohoX/+Z//yQfpN1ZszYjjRuL7Gmq1+Lu/+7t83EVDxx9AqTB4GyBRvBCMLQHxQjYO8I1dU+LA3alTp+YDeptanGEozhoVBwDHLk0zZ87MBzhv24Uodp2JXXvefvvtvEVjwIAB+UxUsZUivi92iZo3b95274vlj0Elvi926dlW/FvjrEVjxozJByvHwdlf//rX8wHLhW5ExRYHZMfZpOKF+uTJk/PPozDQOaX8sX4WLVqUd3mK9RPFeo2DsGNwi92sYsvCm2++uVPlPfHEE8MJJ5yQD86O24qfQWHQemE2r7j9+DfEzyZOChBn/TJ4GyglggVAE4jfbxGDRZxVKd71jlOVxhmHmqObSrzj/qMf/SgPFHGsQLxojjNB
1XXRRRfl5TjyyCPzFo54V/2MM87Ip0qNF6NxGtd4oRpnT6orXtjG2Ydit6/4vm2nTY3i2IHHHnssvPvuu+Goo44KX/ziF8Mpp5wSfvjDH4ZSEbt7xQAVWwLuuuuu/O8otCCklD/O5PTKK6/kM3XF+onirFSxRSF2L4uhL46D2ZUvJ/zP//zPvDxx6tpY1jg1baF1JP4dcRasGJRiq1hsfZk0aVLeJQ+gVLSKI7iLXQgAaCqxe1hsidmVi3sAdp0WCwAAIJlgAQAAJDMrFAAVRQ9fgOLQYgEAACQTLAAAgGSCBQAAkEywAAAAkgkWAABAMsECAABIJlgAAADJBAsAACCZYAEAAIRU/w9H9BLxmptUeAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "features = list(range(len(pi_mean)))\n", "\n", "plt.figure(figsize=(8, 5))\n", "plt.barh(\n", " features,\n", " pi_mean,\n", " xerr=pi_std,\n", ")\n", "plt.gca().invert_yaxis()\n", "plt.yticks(features)\n", "plt.xlabel(\"Permutation Importance\")\n", "plt.ylabel(\"Features\")\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this plot, each bar shows the mean PI and each error bar shows its standard deviation. The plot suggests that features 1 and 2 are more important than feature 0.\n", "\n", "Besides the Gaussian process model (`gp` here, an instance of `physbo.gp.Model`), `policy` also provides a `get_permutation_importance` function.\n", "It is called as `policy.get_permutation_importance(n_perm)`.\n", "Because `policy` computes the PI on the training data it has already stored, you do not need to pass the data again, unlike with the model." ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Prediction with a trained model\n", "\n", "We read out the parameters of the trained model as gp_params and use them to make predictions.\n", "\n", "By storing gp_params together with the training data (X_train, t_train), you can make predictions with the trained model later without retraining it." ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Collect the learned parameters (do this right after training, before `gp` is redefined below)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.645968Z", "start_time": "2020-12-04T06:12:58.639012Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([-2.76166841, -1.05925247, -0.68395227, -2.67155333])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Collect the learned parameters into a 1D array (likelihood, mean, covariance)\n", "gp_params = np.append(np.append(gp.lik.params, gp.prior.mean.params), gp.prior.cov.params)\n", "\n", "gp_params" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Create a model gp with the same structure as the one used for training" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:58.666019Z", "start_time": "2020-12-04T06:12:58.653259Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Define the covariance (Gaussian kernel)\n", "cov = physbo.gp.cov.Gauss(X_train.shape[1], ard=False)\n", "\n", "# Define the mean (constant)\n", "mean = physbo.gp.mean.Const()\n", "\n", "# Define the likelihood function (Gaussian)\n", "lik = physbo.gp.lik.Gauss()\n", "\n", "# Create the Gaussian process model\n", "gp = physbo.gp.Model(lik=lik, mean=mean, cov=cov)\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Set the trained parameters in the model and run the prediction" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:59.016429Z", "start_time": "2020-12-04T06:12:58.673034Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Set the trained parameters in the Gaussian process\n", "gp.set_params(gp_params)\n", "\n", "# Compute the posterior mean (prediction) and variance for the test data\n", "gp.prepare(X_train, t_train)\n", "fmean = gp.get_post_fmean(X_train, X_test)\n", "fcov = gp.get_post_fcov(X_train, X_test)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Predicted values" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:59.020795Z", "start_time": "2020-12-04T06:12:59.017606Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([-1.04740175, -1.09605145, -1.0802498 , ..., -0.99629321,\n", " -1.16971609, -1.13126194])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fmean" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Variances" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:59.026523Z", "start_time": "2020-12-04T06:12:59.023035Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([0.00043518, 0.00045412, 0.00055981, ..., 0.0003406 , 0.00037195,\n", " 0.00053075])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fcov" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Mean squared error of the predictions" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2020-12-04T06:12:59.033497Z", "start_time": "2020-12-04T06:12:59.027871Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "0.011857078799698457" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean((fmean-t_test)**2)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%% md\n" } }, "source": [ "(Note) In the example above, the test points were taken from the same X that was loaded and standardized beforehand.\n", "To use the trained model to predict at a parameter X_new that is not contained in X,\n", "first compute the mean (`X_mean`) and standard deviation (`X_std`) of the original data X used to train the model (that is, before `physbo.misc.centering` was applied), then apply the same transformation\n", "`X_new = (X_new - X_mean) / X_std`\n", "before making the prediction.\n", "Also, the data must be passed as an ndarray of shape (number of points, number of features).\n", "Therefore, if X_new is a single data point, it must be converted beforehand,\n", "for example with\n", "`X_new = np.array(X_new).reshape(1, -1)`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 4 }