post Image
PRML第12章 ベイズ的主成分分析 Python実装

今回の記事ではベイズ的主成分分析を実装します。対象とするデータが存在する観測空間(高次元)から潜在空間(低次元)への射影を求めるというのが主成分分析(PCA: Principal Component Analysis)の主な使い方だと思います。可視化が目的ならば潜在空間を2(もしくは3)次元にしますが、データの前処理としてだと潜在空間の次元をいくつに設定すればいいのかわかる状況は稀だと思います。寄与率を計算してという方法もありますが、結局そのときの閾値は私たちが設定しないといけません。そこで、ベイズ的主成分分析では関連度自動決定によって自動的に潜在空間の次元を決定します。

確率的主成分分析

主成分分析を確率的に解釈することで、後々ベイズ的な取り扱いができるようになります。
確率的主成分分析では、私たちが観測したデータ$x$(D次元)は、潜在空間からサンプルされた$z$(M次元)を行列$W$((D,M)行列)で射影してから平行移動させて($+\mu$(D次元))、ノイズを加えた($+\epsilon$(D次元))ものと解釈されます。

{\bf x = Wz + \mu + \epsilon}

ここで$z$と$\epsilon$はガウス分布に従うと仮定します。
推定するパラメータは$W,\mu,\sigma^2$($\epsilon$が従うガウス分布の分散)の3つで、その尤度関数は次のようになります。

p({\bf x|W,\mu},\sigma^2) = \int p({\bf x|Wz+\mu},\sigma^2)p({\bf z}){\rm d}{\bf z}

この尤度関数を最大化する手法がPRMLでは2つ紹介されています。一つ目は単純に特異値分解を利用する手法、もう一つは、EMアルゴリズムを用いて$z$についての事後分布の更新(Eステップ)と完全データ$x,z$が与えられたときの尤度関数の最大化(Mステップ)を繰り返すというものです。

ベイズ的主成分分析

上の例では、潜在変数空間の次元をMに固定していました。ベイズ的主成分分析では関連度自動決定を用いて余分な次元を枝刈りしていきます。(ただし、Mの値が実際に減るわけではありません。)そのためにパラメータ$W$に次のような事前分布を設けます。

p({\bf W}|{\bf \alpha}) = \prod_{i=1}^M\left({\alpha_i\over2\pi}\right)^{D/2}\exp\left\{-{1\over2}\alpha_i{\bf w}_i^\top {\bf w}_i\right\}

ここで、${\bf w}_i$は${\bf W}$の$i$番目の列ベクトルです。そして、$\alpha_i$は個々のガウス分布の精度パラメータの役割をしています。$\alpha$を推定してくと幾つかの成分は非常に大きな値を持ちます。そうなると精度が非常に大きいということなので対応する${\bf }W$の列ベクトルの成分は0ばかり、すなわちその次元が枝刈りされたということになります。

コード

ライブラリ

いつも通りmatplotlibとnumpyだけ使います。

import matplotlib.pyplot as plt
import numpy as np

最尤法による主成分分析

通常通りの固有値分解を用いた手法

# 主成分分析を行うクラス
class PCA(object):

    def __init__(self, n_component):
        # 潜在空間の次元を指定
        self.n_component = n_component

    # 最尤法で主成分分析を行う
    def fit(self, X):
        # PRML式(12.1) muの最尤推定値を計算
        self.mean = np.mean(X, axis=0)

        # PRML式(12.2) データ共分散行列
        cov = np.cov(X, rowvar=False)

        # 固有値分解
        values, vectors = np.linalg.eigh(cov)
        index = np.size(X, 1) - self.n_component

        # PRML式(12.46) sigma^2の最尤推定値
        if index == 0:
            self.var = 0
        else:
            self.var = np.mean(values[:index])

        # PRML式(12.45) 射影行列Wの最尤推定値
        self.W = vectors[:, index:].dot(np.sqrt(np.diag(values[index:]) - self.var * np.eye(self.n_component)))

ベイズ的主成分分析

下のメソッドは全て先ほどのPCAクラス内のメソッドです。今回は比較のために、まずパラメータを最尤推定して、それらを初期値として関連度自動決定で枝刈りをしています。

    # ベイズ的主成分分析を行う
    def fit_bayesian(self, X, iter_max=100):
        # データ空間の次元
        self.ndim = np.size(X, 1)

        # 最尤推定でパラメータを推定して初期値とする
        self.fit(X)

        # 精度パラメータの初期化(1度目の推定)
        self.alpha = self.ndim / np.sum(self.W ** 2, axis=0)

        # データの0平均化
        D = X - self.mean

        # EMアルゴリズムを指定回数だけ繰り返す
        for i in xrange(iter_max):
            # Eステップ zの十分統計量
            Ez, Ezz = self.expectation(D)

            # Mステップ W,sigma^2の推定
            self.maximize(D, Ez, Ezz)

            # PRML式(12.62) 超パラメータの更新
            self.alpha = self.ndim / np.sum(self.W ** 2, axis=0).clip(min=1e-10)

    # Eステップ zの十分統計量 E[z]、E[zz^T]の計算
    def expectation(self, D):
        # PRML式(12.41)
        M = self.W.T.dot(self.W) + self.var * np.eye(self.n_component)
        Minv = np.linalg.inv(M)

        # PRML式(12.54) E[z]
        Ez = D.dot(self.W).dot(Minv)

        # PRML式(12.55) E[zz^T]
        Ezz = self.var * Minv + np.einsum('ni,nj->nij', Ez, Ez)
        return Ez, Ezz

    # Mステップ W,sigma^2の推定
    def maximize(self, D, Ez, Ezz):
        # PRML式(12.63) Wの推定
        self.W = D.T.dot(Ez).dot(np.linalg.inv(np.sum(Ezz, axis=0) + self.var * np.diag(self.alpha)))

        # PRML式(12.57) sigma^2の推定
        self.var = np.mean(
            np.mean(D ** 2, axis=-1)
            - 2 * np.mean(Ez.dot(self.W.T) * D, axis=-1)
            + np.trace(Ezz.dot(self.W.T).dot(self.W).T) / self.ndim)

ヒントン図

今回はPRML図12.14の再現をします。そのためには、行列の各要素を正方形で表すヒントン図を描くための関数が必要になります。matplotlib hintonでググって出てきたページのものを少し改変しています。

def hinton(matrix, max_weight=None, ax=None):
    """Draw Hinton diagram for visualizing a weight matrix."""
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        max_weight = 2 ** np.ceil(np.log(np.abs(matrix).max()) / np.log(2))

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([y - size / 2, x - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()
    plt.xlim(-0.5, np.size(matrix, 1) - 0.5)
    plt.ylim(-0.5, len(matrix) - 0.5)
    plt.show()

メイン関数

def create_toy_data(sample_size=100, ndim_hidden=1, ndim_observe=2, std=1.):
    Z = np.random.normal(size=(sample_size, ndim_hidden))
    mu = np.random.uniform(-5, 5, size=(ndim_observe))
    W = np.random.uniform(-5, 5, (ndim_hidden, ndim_observe))

    # PRML式(12.33)
    X = Z.dot(W) + mu + np.random.normal(scale=std, size=(sample_size, ndim_observe))
    return X


def main():
    # 3次元の潜在空間から10次元空間に射影してできたデータを100点作成
    X = create_toy_data(sample_size=100, ndim_hidden=3, ndim_observe=10, std=1.)

    # 潜在空間を9次元として最尤法によるPCAを行う
    pca = PCA(9)
    pca.fit(X)
    hinton(pca.W)

    # 最大9次元から枝刈りをしてベイズ的PCAを行う
    pca.fit_bayesian(X)
    hinton(pca.W)

全体のコード

pca.py
import matplotlib.pyplot as plt
import numpy as np


class PCA(object):

    def __init__(self, n_component):
        self.n_component = n_component

    def fit(self, X):
        self.mean = np.mean(X, axis=0)
        cov = np.cov(X, rowvar=False)
        values, vectors = np.linalg.eigh(cov)
        index = np.size(X, 1) - self.n_component
        if index == 0:
            self.var = 0
        else:
            self.var = np.mean(values[:index])
        self.W = vectors[:, index:].dot(np.sqrt(np.diag(values[index:]) - self.var * np.eye(self.n_component)))

    def fit_bayesian(self, X, iter_max=100):
        self.ndim = np.size(X, 1)
        self.fit(X)
        self.alpha = self.ndim / np.sum(self.W ** 2, axis=0)
        D = X - self.mean
        for i in xrange(iter_max):
            Ez, Ezz = self.expectation(D)
            self.maximize(D, Ez, Ezz)
            self.alpha = self.ndim / np.sum(self.W ** 2, axis=0).clip(min=1e-10)

    def expectation(self, D):
        M = self.W.T.dot(self.W) + self.var * np.eye(self.n_component)
        Minv = np.linalg.inv(M)
        Ez = D.dot(self.W).dot(Minv)
        Ezz = self.var * Minv + np.einsum('ni,nj->nij', Ez, Ez)
        return Ez, Ezz

    def maximize(self, D, Ez, Ezz):
        self.W = D.T.dot(Ez).dot(np.linalg.inv(np.sum(Ezz, axis=0) + self.var * np.diag(self.alpha)))
        self.var = np.mean(
            np.mean(D ** 2, axis=-1)
            - 2 * np.mean(Ez.dot(self.W.T) * D, axis=-1)
            + np.trace(Ezz.dot(self.W.T).dot(self.W).T) / self.ndim)


def hinton(matrix, max_weight=None, ax=None):
    """Draw Hinton diagram for visualizing a weight matrix."""
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        max_weight = 2 ** np.ceil(np.log(np.abs(matrix).max()) / np.log(2))

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([y - size / 2, x - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()
    plt.xlim(-0.5, np.size(matrix, 1) - 0.5)
    plt.ylim(-0.5, len(matrix) - 0.5)
    plt.show()


def create_toy_data(sample_size=100, ndim_hidden=1, ndim_observe=2, std=1.):
    Z = np.random.normal(size=(sample_size, ndim_hidden))
    mu = np.random.uniform(-5, 5, size=(ndim_observe))
    W = np.random.uniform(-5, 5, (ndim_hidden, ndim_observe))
    X = Z.dot(W) + mu + np.random.normal(scale=std, size=(sample_size, ndim_observe))
    return X


def main():
    X = create_toy_data(sample_size=100, ndim_hidden=3, ndim_observe=10, std=1.)

    pca = PCA(9)
    pca.fit(X)
    hinton(pca.W)

    pca.fit_bayesian(X)
    hinton(pca.W)


if __name__ == '__main__':
    main()

結果

最尤法によるPCAで射影行列Wを推定した結果が下のようになります。
mle.png
そして、ベイズ的主成分分析を用いて潜在空間の次元を枝刈りすると射影行列Wは下の図のようになります。左の6列分が消えていて、それらに対応する次元が枝刈りされました。3次元から射影してきたということを捉えらることができています。
bayesian.png
PRML図12.14のような結果が得られました。

終わりに

PRMLもこのあたりまでくると、今までに習ったことを組み合わせて新しいモデルに適用することが多くなってきました。今回は第6章の関連度自動決定と第9章のEMアルゴリズムを組み合わせて、観測データは実際にはより低次元空間から射影されて生成されているというモデルに適用しています。欠損値のあるデータについても適用できるらしいので、そちらも試してみたいものです。


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。