post Image
ディープラーニングのフレームワークの自作

はじめに

ディープラーニングのプログラムを作る場合、TensorFlowやTheanoなどの既存のフレームワークを使うことが多いと思いますが、既存のフレームワークを使うと内部の動作がブラックボックスになってデバッグがむつかしいという問題があります。

この記事では現在自作をしているディープラーニングのフレームワークについて紹介したいと思います。

以下のニューラルネットワークで基本的な動作はできるようになっていて、現在いろいろな事例のテストをしています。
– 多層パーセプトロン
– 畳み込みニューラルネットワーク(CNN)
– 再帰型ニューラルネットワーク(RNN)
– LSTM (Long short-term memory)

個人でTensorFlowのような高機能のフレームワークを作るのは難しいので、いろいろな人と一緒にフレームワークを作れたらという思いで公開しました。

他の人がコードを読んでも動作を理解できるように、現在ソースコードを整理してコメントを付け、内部の動作を説明した文書を作成する作業をしています。

以下のサイトで補足の説明の文書を作っていますので、このQiitaの記事を読んで興味を持たれたら続きをお読みください。
まだ不完全ですが今後少しずつ内容を充実させていく予定です。
http://lkzf.info/mkfn/

ソースコードはGitHubにあります。
https://github.com/teatime77/mkfn

以下は機械学習論文読み会・懇親会 vol.9 2017/3/25で発表した時のスライドです。
https://www.slideshare.net/KoHamada/ss-73602443

フレームワークの名前はとりあえずMkFnとしています。
ディープラーニングは多変数ベクトル値関数を求める処理と考えられるので、 M a k e F u n ctionの略です。
( あまり評判良くないので、そのうち名前は変えるかも… )

以下はこの記事の続きです。
Excelで作る人工知能

前提となる知識

この記事を読むための前提知識としては「ゼロから作るDeep Learning」などの本で外部のフレームワークを使わずに、ディープラーニングを作る方法を学ばれている方を対象としています。

このフレームワークの基本的な機能は順伝播の数式から誤差逆伝播の数式を自動的に生成することなので、順伝播と誤差逆伝播についておおよそ理解していればOKです。

数学の知識で必要なのは以下の多変数関数の微分くらいで、それ以外は高校の微分の知識で十分です。

    \frac{ \partial f(y_0,...,y_n) }{ \partial x } = \sum_i \frac{ \partial f(y_0,...,y_n) }{ \partial y_i } \frac{ \partial y_i }{ \partial x }

行列の知識は一切不要です。

基本的な動作

フレームワークの基本的な動作は以下の通りです。

  1. ニューラルネットワークの各レイヤーに対して、順伝播の式をC#で書きます。
  2. C#のソースコードを解析して、逆伝播の数式を生成します。
  3. CUDAゃC++のプログラムのソースコードを生成します。
  4. MathJaxの形で順伝播や逆伝播の式を出力します。

多層パーセプトロンの例

以下で単純な多層パーセプトロンの例で説明します。

多層パーセプトロンの順伝播の式は以下になります。

u_{i} = \displaystyle \sum_{j }^{ X } x_{j} \cdot w_{i}^{j} + b_{i}
 \\ 
y_{i} = σ(u_{i})

ここで $x_j$ は入力、$u_i$は入力の重み付き加算、$y_i$は出力、$w_{i}^{j}$は重み、$b_{i}$はバイアスです。
σは活性化関数でシグモイド関数やtanhです。
MkFnではこれを以下のようにC#のコードで書きます。
C#のコードはニューラルネットワークの記述に使うだけで、C#のアプリとして動作させるわけではありません。

// 多層パーセプトロンのレイヤーのクラス
public class FullyConnectedLayer : Layer {
    public int X; // 配列xのサイズ
    public int Y; // 配列yのサイズ

    public double[] x; // 入力
    public double[] u; // 入力の重み付き加算
    public double[] y; // 出力

    public double[,] w; // 重み
    public double[] b; // バイアス
}

xやyなどの配列のサイズはレイヤーのコンストラクターで指定します。

// コンストラクター    
public FullyConnectedLayer(int x_size, int y_size) {
    X = x_size;
    Y = y_size;

    x = new double[X];
    u = new double[Y];
    y = new double[Y];

    w = new double[Y, X];
    b = new double[Y];
}

順伝播の式はForwardという名前のメソッドの中で指定します。

// 順伝播の式
public override void Forward() {
    foreach (int i in Range(Y)) {
        u[i] = (from j in Range(X) select x[j] * w[i, j]).Sum() + b[i];
        y[i] = σ(u[i]);
    }
}

Range(n)は0からn-1までのintの配列を返す関数です。

MkFnはC#で書かれたレイヤーの定義を読み込み、ソースコード解析をしてから微分や数式の簡約化をして以下のような誤差逆伝播の式を生成します。

δu_{i} = δy_{i} \cdot σ'(u_{i})
 \\δx_{j} = \displaystyle \sum_{i }^{ Y } δu_{i} \cdot w_{i}^{j}
 \\δw_{i}^{j} = δu_{i} \cdot x_{j}
 \\δb_{i} = δu_{i}

この数式からC++やCUDAのソースコードを生成します。

プログラムの処理の流れ

プログラムの処理の流れは以下のようになります。

  • C#のソースコードを解析します。以下は通常のコンパイラで行われている処理です。

  • ニューラルネットワークの情報を調べます。

    • レイヤーのコンストラクターから配列のサイズを計算します。
      配列wがコンストラクターの中で以下のように初期化されていれば、配列wのサイズはY×Xになります。
      w = new double[Y, X];
    • 順伝播の式を解析して、それぞれの変数の順伝播先を調べます。
  • 順伝播の式から逆伝播の式を作ります。

    • 逆伝播の方程式を作ります。
    • 微分の数式処理をします。(数値微分ではありません。)
    • 数式の簡約化をします。
  • CUDAやC++のソースコードを作ります。

数式処理

逆伝播の式を計算するには、微分や数式の簡約化などの数式処理が必要です。
数式処理というとMathematicaのような非常に複雑なプログラムを想像するかも知れませんが、ディープラーニングの逆伝播という対象を絞った場合は、以外と簡単な処理だけで可能です。

以下に微分と数式の簡約化に必要な処理を書きます。

微分

変数自身の微分は1です。

    \frac{ \partial x }{ \partial x } = 1

変数を含まない式の微分は0です。

    \frac{ \partial c }{ \partial x } = 0

関数$f$の定数倍の微分は、定数に$f$の微分を掛けた値です。

    \frac{ \partial c \cdot f(x) }{ \partial x } = c \cdot \frac{ \partial f(x) }{ \partial x }

和の微分は、微分の和です。

    \frac{ \partial \sum_i y_i }{ \partial x } = \sum_i \frac{ \partial y_i }{ \partial x }

多変数関数の微分

    \frac{ \partial f(y_0,...,y_n) }{ \partial x } = \sum_i \frac{ \partial f(y_0,...,y_n) }{ \partial y_i } \frac{ \partial y_i }{ \partial x }

数式の簡約化

加算に0が含まれていれば除去します。

        a + 0 + b= a + b

乗算に0が含まれていれば値は0です。

        a \cdot 0 \cdot b = 0

加算の中の定数値を計算します。

        a + 1 + 2 = a + 3

乗算の中の定数値を計算します。

        2 * a * 3 * b = 6 * a * b

加算の中の定数の係数を計算します。

        a + b + 2 \cdot b = a + 3 \cdot b

入れ子の加算は外に出します。

        a + (b + c) = a + b + c

上記の数式処理だけで逆伝播の式が求まるのは私にとって意外でした。
LSTMのメモリセルの逆伝播の式は手計算でするにはかなりややこしいのですが、以下の結果があっさり出たときは、「この程度の数式処理でいいんだ!」とちょっと驚きました。

 \delta y_{t}^{j} \cdot σ(uO_{t}^{j}) \cdot σ'(s_{t}^{j}) + \delta s_{t + 1}^{j} \cdot σ(uF_{t + 1}^{j}) + \delta uO_{t}^{j} \cdot wO_{j} + \delta uF_{t + 1}^{j} \cdot wF_{j} + \delta uI_{t + 1}^{j} \cdot wI_{j}

この数式を求める過程は以下にあります。
LSTMの例
LSTMの逆伝播の式の導出

おわりに

「ディープラーニングのフレームワークの作り方はこれでいいのかな?」という不安があり、いろいろな人の意見を聞きたくて公開しました。
「そもそも基本の設計が間違っている」という厳しい意見を含めて誤りがあれば指摘してもらえればと思います。

よろしくお願いします。


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。