post Image
pythonでベイズ線形回帰実装

はじめに

pythonでベイズ線形回帰を実装しました.
教科書として『パターン認識と機械学習 上』を使いました.

本記事の構成

  • はじめに
  • 線形回帰
    • 最小二乗法
    • 正則化
    • 最尤推定
    • ベイズ線形回帰
  • 実装
  • 結果
  • おわりに

線形回帰

入力変数 $\boldsymbol x$ から目標変数 $t$ を予測することを回帰と言います.
$\boldsymbol x$ と $t$ に以下の関係があると仮定します.

t = y(\boldsymbol x, \boldsymbol w) = \sum_{j = 0}^{M - 1} w_{j} \phi_{j}(\boldsymbol x) = \boldsymbol w^{T} \boldsymbol \phi(\boldsymbol x) \tag{1}

式$(1)$は線形結合の形をしているため,線形回帰 と呼ばれます.
$\boldsymbol \phi$ は基底関数で,多項式基底やガウス基底がよく使用されます.
以降の章では,訓練集合 $(\boldsymbol x_{n}, t_{n}) \ (n = 1, \cdots, N)$ から $\boldsymbol w$ を求めることを考えます.

最小二乗法

データ点 $\boldsymbol x_{n}$ における予測値と目標値 $t_{n}$ のずれを測る 誤差関数 を以下のように定義します.

E_{D}(\boldsymbol w) = \cfrac{1}{2} \sum_{n = 1}^{N} \bigl( t_{n} - \boldsymbol w^{T} \boldsymbol \phi(\boldsymbol x_{n}) \bigr)^{2} \tag{2}

この誤差関数 $E_{D}(\boldsymbol w)$ を最小化するため微分方程式を解きます.

\begin{align}
&\sum_{n = 1}^{N} t_{n} {\boldsymbol \phi(\boldsymbol x_{n})}^{T} - \boldsymbol w^{T} \left( \sum_{n = 1}^{N} \boldsymbol \phi(\boldsymbol x_{n}) {\boldsymbol \phi(\boldsymbol x_{n})}^{T} \right) =\boldsymbol 0 \\
& \boldsymbol w = \bigl( \boldsymbol \Phi^{T} \boldsymbol \Phi \bigr)^{-1}\boldsymbol \Phi^{T} \boldsymbol t \tag{3} \\
\end{align}

ここで $\boldsymbol \Phi, \boldsymbol t$ は以下のように定義しています.

\begin{align}
\boldsymbol \Phi &=
\left(
  \begin{array}{ccc}
    \phi_{0}(\boldsymbol x_{1}) & \cdots & \phi_{M - 1}(\boldsymbol x_{1}) \\
    \vdots & \ddots & \vdots \\
    \phi_{0}(\boldsymbol x_{N}) & \cdots & \phi_{M - 1}(\boldsymbol x_{N}) \\
  \end{array}
\right) \tag{4} \\
\boldsymbol t &= \bigl( t_{1}, \cdots, t_{N} \bigr)^{T} \tag{5} \\
\end{align}

正則化

一般に式$(1)$における $M$ を大きくするほど,訓練集合に対する誤差関数 $E_{D}(\boldsymbol w)$ は小さくなります.
このとき,自由度が高くなるが故に過度なフィッティングが起こる場合があります.(下図参考)
これを 過学習 と呼びます.

over_fitting.png

過学習を防ぐために,誤差関数を以下のように定義します.

E(\boldsymbol w) = E_{D}(\boldsymbol w) + \lambda E_{W}(\boldsymbol w) \tag{6}

$E_{W}(\boldsymbol w)$ を 正則化項 と呼び,代表として重みベクトルの二乗和 $\cfrac{1}{2} \boldsymbol w^{T} \boldsymbol w$ が挙げられます.
このとき,重みパラメータは以下のように求められます.

\boldsymbol w = \bigl( \lambda \boldsymbol {\rm I} + \boldsymbol \Phi^{T} \boldsymbol \Phi \bigr)^{-1}\boldsymbol \Phi^{T} \boldsymbol t \tag{7}

正則化の効果を示す図を掲載します.
下図は,左が正則化なし / 右が正則化ありの結果になります.
過学習が緩和されていることが確認できます.

regularization.png

最尤推定

対数尤度を最大化することで回帰問題を解きます.
目標変数 $t_{n}$ が $y(\boldsymbol x_{n}, \boldsymbol w)$ とガウスノイズ $\epsilon$ の和で表されるとします.
ガウスノイズ $\epsilon$ は,期待値が $0$ で精度(分散の逆数)が $\beta$ のガウス確率変数です.

t_{n} = y(\boldsymbol x_{n}, \boldsymbol w) + \epsilon \tag{8}

このとき,対数尤度は以下のように表されます.

\begin{align}
\ln p(\boldsymbol t \mid \boldsymbol X, \boldsymbol w, \beta)
&= \ln \prod_{n = 1}^{N} \mathcal{N} \bigl( t_{n} \mid y(\boldsymbol x_{n}, \boldsymbol w), \beta^{-1} \bigr) \\
&= \sum_{n = 1}^{N} \ln \mathcal{N} \bigl( t_{n} \mid y(\boldsymbol x_{n}, \boldsymbol w), \beta^{-1} \bigr) \\
&= \cfrac{N}{2} \ln \beta - \cfrac{N}{2} \ln (2 \pi) - \cfrac{\beta}{2} \sum_{n = 1}^{N} \bigl( t_{n} - \boldsymbol w^{T} \boldsymbol \phi(\boldsymbol x_{n}) \bigr)^{2} \tag{9} \\
\end{align}

式$(9)$より,対数尤度の最大化は二乗誤差の最小化と等価であることが分かります.
したがって,重みパラメータは式$(3)$で表されます.

ベイズ線形回帰

事前分布を導入し,ベイズの定理を用いることで回帰問題を解きます.
尤度関数は以下のように表されます.

\begin{align}
p(\boldsymbol t \mid \boldsymbol X, \boldsymbol w, \beta)
&= \prod_{n = 1}^{N} \mathcal{N} \bigl( t_{n} \mid y(\boldsymbol x_{n}, \boldsymbol w), \beta^{-1} \bigr) \\
&\propto \prod_{n = 1}^{N} \exp \left( - \cfrac{\beta}{2} \bigl( t_{n} - \boldsymbol w^{T} \boldsymbol \phi(\boldsymbol x_{n}) \bigr)^{2} \right) \\
&= \exp \left( - \cfrac{\beta}{2} \sum_{n = 1}^{N} \bigl( t_{n} - \boldsymbol w^{T} \boldsymbol \phi(\boldsymbol x_{n}) \bigr)^{2} \right) \\
&= \exp \left( - \cfrac{\beta}{2} \bigl( \boldsymbol t - \boldsymbol \Phi \boldsymbol w \bigr)^{T} \bigl( \boldsymbol t - \boldsymbol \Phi \boldsymbol w \bigr) \right) \tag{10} \\
\end{align}

事前分布 $p(\boldsymbol w)$ を以下のように定義します.

\begin{align}
p(\boldsymbol w)
&= \mathcal{N} (\boldsymbol w \mid \boldsymbol m_{0}, \boldsymbol S_{0}) \\
&\propto \exp \left( - \cfrac{1}{2} \bigl( \boldsymbol w - \boldsymbol m_{0} \bigr)^{T} \boldsymbol S_{0}^{-1} \bigl( \boldsymbol w - \boldsymbol m_{0} \bigr) \right) \tag{11} \\
\end{align}

式$(10)$,式$(11)$,ベイズの定理を用いて式を展開します.

\begin{align}
p(\boldsymbol w \mid \boldsymbol t)
&\propto p(\boldsymbol t \mid \boldsymbol w) p(\boldsymbol w) \\
&\propto \exp \left( - \cfrac{\beta}{2} \bigl( \boldsymbol t - \boldsymbol \Phi \boldsymbol w \bigr)^{T} \bigl( \boldsymbol t - \boldsymbol \Phi \boldsymbol w \bigr) \right) \exp \left( - \cfrac{1}{2} \bigl( \boldsymbol w - \boldsymbol m_{0} \bigr)^{T} \boldsymbol S_{0}^{-1} \bigl( \boldsymbol w - \boldsymbol m_{0} \bigr) \right) \\
&= \exp \left( - \cfrac{1}{2} \Bigl( \boldsymbol w^{T} \bigl( \boldsymbol S_{0}^{-1} + \beta \boldsymbol \Phi^{T} \boldsymbol \Phi \bigr) \boldsymbol w - \bigl( \boldsymbol S_{0}^{-1} \boldsymbol m_{0} + \beta \boldsymbol \Phi^{T} \boldsymbol t \bigr)^{T} \boldsymbol w - \boldsymbol w^{T} \bigl( \boldsymbol S_{0}^{-1} \boldsymbol m_{0} + \beta \boldsymbol \Phi^{T} \boldsymbol t \bigr) + \beta \boldsymbol t^{T} \boldsymbol t + \boldsymbol m_{0}^{T} \boldsymbol S_{0}^{-1} \boldsymbol m_{0} \Bigr) \right) \tag{12} \\
\end{align}

式$(12)$より,以下の結果が得られます.

\begin{align}
& p(\boldsymbol w \mid \boldsymbol t) = \mathcal{N} \bigl( \boldsymbol w \mid \boldsymbol m_{N}, \boldsymbol S_{N} \bigr) \tag{13} \\
& \boldsymbol m_{N} = \boldsymbol S_{N} \bigl( \boldsymbol S_{0}^{-1} \boldsymbol m_{0} + \beta \boldsymbol \Phi^{T} \boldsymbol t \bigr) \tag{14} \\
& \boldsymbol S_{N} = \boldsymbol S_{0}^{-1} + \beta \boldsymbol \Phi^{T} \boldsymbol \Phi \tag{15} \\
\end{align}

ベイズ線形回帰では,得られた事後分布のパラメータ $\boldsymbol m_{N}, \boldsymbol S_{N}$ を,
次のデータを観測したときの事前分布のパラメータとして扱うことができます.
この性質により,データを観測するたびにパラメータを更新することができます.
これを逐次学習またはオンライン学習と呼びます.

実装

ベイズ線形回帰で直線フィッティングを解いてみました.
まず,$f(x, \boldsymbol a) = a_{0} + a_{1} x \ (a_{0} = -0.3, a_{1} = 0.5)$ に対して,
一様分布 $U(x \mid -1, 1)$ から $x_{n}$ を選び,$f(x_{n}, \boldsymbol a)$ を評価します.
次に,標準偏差 $0.2$ のガウスノイズを $f(x_{n}, \boldsymbol a)$ に加え,目標値 $t_{n}$ を生成します.
実装の目標は,訓練集合から真のパラメータ $a_{0}, a_{1}$ を求めることになります.
python3で実装したプログラムを掲載します.

import numpy as np
from matplotlib import pyplot as plt
import matplotlib.mlab as mlab

def bivariate_normal(x_seq, y_seq, mu, S):
    return mlab.bivariate_normal(x_seq, y_seq, S[0,0] ** (0.5), S[1,1] ** (0.5), mu[0], mu[1], S[0, 1])

def plot_data(ax, X, T):
    ax.set_xlim(-1, 1), ax.set_ylim(-1, 1)
    ax.scatter(X, T, color = 'b')

def plot_heatmap(ax, Z, x_seq, y_seq):
    ax.set_xlim(-1, 1), ax.set_ylim(-1, 1)
    ax.pcolor(x_seq, y_seq, Z, cmap=plt.cm.jet)

def plot_line(ax, W, seq):
    ax.set_xlim(-1, 1), ax.set_ylim(-1, 1)
    for _w in W:
        ax.plot(seq, _w.dot(np.vstack((np.ones(seq.size), seq))), color = 'r')


if __name__ == "__main__":
    # init
    w = np.array([-0.3, 0.5])
    sigma = 0.2
    alpha, beta = 2.0, 1.0 / (sigma ** 2)
    mu = np.zeros(2)
    S = np.identity(2) / alpha
    N = 15

    # plot
    fig = plt.figure(figsize = (15, 5 * N))
    seq = np.linspace(-1.0, 1.0, 51)
    x_seq, y_seq = np.meshgrid(seq, seq)
    Z = bivariate_normal(x_seq, y_seq, mu, S)
    W = np.random.multivariate_normal(mu, S, 6)
    axes = [fig.add_subplot(N + 1, 3, j) for j in range(1, 4)]
    plot_heatmap(axes[1], Z, x_seq, y_seq)
    plot_line(axes[2], W, seq)

    # generate data
    X = 2 * (np.random.rand(N) - 0.5)
    T = w.dot(np.vstack((np.ones(N), X))) + np.random.normal(0, sigma, N)

    # fit
    for n in range(N):
        x_n, t_n = X[n], T[n]
        Phi = np.array([1.0, x_n])

        # estimate parameters
        S_inv = np.linalg.inv(S)
        S = np.linalg.inv(S_inv + beta * Phi.reshape(-1, 1) * Phi) # eq(12)
        mu = S.dot(S_inv.dot(mu) + beta *  Phi * t_n) # eq(11)

        # plot
        Z = bivariate_normal(x_seq, y_seq, mu, S)
        W = np.random.multivariate_normal(mu, S, 6)
        axes = [fig.add_subplot(N + 1, 3, (n + 1) * 3 + j) for j in range(1, 4)]
        plot_data(axes[0], X[:n+1], T[:n+1])
        plot_heatmap(axes[1], Z, x_seq, y_seq)
        plot_line(axes[2], W, seq)

    plt.savefig('bayes.png')
    plt.clf()

結果

得られた結果を下図に示します.
1列目は,観測したデータ $(x_{n}, t_{n})$ をプロットした図を表しています.
2列目は,事後分布をヒートマップで表しています.
データを観測するたびに $(-0.3, 0.5)$ に集中していくのが分かります.
3列目は,事後分布からランダムに選んだパラメータ持つ直線を描画しています.

今回の推定するのは傾きと切片の$2$パラメータであるため,データを2つ観測した時点で

bayes_regression.png

おわりに

pythonでベイズ線形回帰による直線フィッティングを実装できました.


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。