post Image
Chainer で手書き数字認識(MNIST)

概要

定番の手書き数字データセット MNIST を Chainer を使用して3層のニューラルネットワークで学習してみました。

3層ニューラルネットワーク

Chainer では、MNIST のデータを自動的にダウンロードして使用できる便利なメソッドが用意されています。それを使って、3層ニューラルネットワーク(2つある隠れ層のノードはそれぞれ 50)で学習を行い、正答率(accuracy)を測定しました。下のコードは、Chainer の MNIST サンプル を下敷きにしていますが、大幅に変更を加えてあります。

コード

neural_net.py
import numpy as np
import chainer
from chainer import Chain, Variable
import chainer.functions as F
import chainer.links as L

class NeuralNet(chainer.Chain):
    def __init__(self, n_units, n_out):
        super().__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(n_units, n_units),
            l3=L.Linear(n_units, n_out),
        )

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

def check_accuracy(model, xs, ts):
    ys = model(xs)
    loss = F.softmax_cross_entropy(ys, ts)
    ys = np.argmax(ys.data, axis=1)
    cors = (ys == ts)
    num_cors = sum(cors)
    accuracy = num_cors / ts.shape[0]
    return accuracy, loss

def main():
    model = NeuralNet(50, 10)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist()
    xs, ts = train._datasets
    txs, tts = test._datasets

    bm = 100

    for i in range(100):

        for j in range(600):
            model.zerograds()
            x = xs[(j * bm):((j + 1) * bm)]
            t = ts[(j * bm):((j + 1) * bm)]
            t = Variable(np.array(t, "i"))
            y = model(x)
            loss = F.softmax_cross_entropy(y, t)
            loss.backward()
            optimizer.update()

        accuracy_train, loss_train = check_accuracy(model, xs, ts)
        accuracy_test, _           = check_accuracy(model, txs, tts)

        print("Epoch %d loss(train) = %f, accuracy(train) = %f, accuracy(test) = %f" % (i + 1, loss_train.data, accuracy_train, accuracy_test))

if __name__ == '__main__':
    main()


実行結果

Epoch 1 loss(train) = 0.242552, accuracy(train) = 0.926683, accuracy(test) = 0.924800
Epoch 2 loss(train) = 0.175040, accuracy(train) = 0.946167, accuracy(test) = 0.942000
Epoch 3 loss(train) = 0.133406, accuracy(train) = 0.959050, accuracy(test) = 0.954400
...
Epoch 98 loss(train) = 0.002298, accuracy(train) = 0.999267, accuracy(test) = 0.971300
Epoch 99 loss(train) = 0.002876, accuracy(train) = 0.998917, accuracy(test) = 0.972900
Epoch 100 loss(train) = 0.003336, accuracy(train) = 0.998917, accuracy(test) = 0.973500

解説

1エポック(epoch)は、600回のイタレーション(iteration)で構成されています。100エポック後に、訓練(train)データに対する正答率は、99.9%、テストデータに対する正答率は97.3%となっています。Optimizer は SGD も試しましたが、Adam のほうがはるかに学習速度が速かったです。上記のコードの他に、各層のノードの数を少し増やしたり、層を1つ増やして4層にしたりもしてみましたが、正答率に大きな違いはありませんでした。

感想

「ゼロから作る Deep Learning」を読んで5ヶ月が経ちましたが、ようやく Chainer で実装するところまでたどり着きました。当たり前ですが、やはりフレームワークを作ると非常にすっきりと記述できますね。今回のコードを記述するにあたり、Chainer の Link, Chain, Optimizer クラスあたりのソースコードを読み込みました。とても簡潔にかかれており、読めばちゃんと理解できるのには感動しました。徐々に Chainer がわかってきた感があります。

参考文献

Chainer 本家ドキュメント


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。