post Image
ディープじゃないディープラーニングをNumPyのみで超簡単実装してみた

ディープラーニングの本質を理解するためには、ゼロから実装してみることが大切ですが、MNISTはCNNを実装するのが大変ですし、学習に時間もかかります。
そのため、今回はIrisデータセットを使い、3層(中間層で数えると1層)の”ディープじゃない”ディープラーニング、つまりただのニューラルネットワークを超簡単に実装してみました。
ミニバッチではなくバッチ学習ですが、ちゃんと(確率的ではないですが)勾配降下法も、誤差逆伝搬法も入ってます。
なお、ディープラーニングの理論については僕の愛書のゼロから作るDeep Learningを読んでみてください。本当にわかりやすい最高の良書です。

詳細

パソコンで図を描くのが得意ではないので手書きですみません。
メインはソースってことでお願いします。
また、Irisのデータは英語版のWikipediaに載っているものをランダムに並び替えました。
DSC_0457.JPG
DSC_0461.JPG

ソースコード

ソースはgithubにあります。
Python3です。

pythonのコードのみこちらにも載せておきます。
Irisのデータはgithubからダウンロードしてください。

iris.py
# coding: utf-8

import numpy as np

# ハイパーパラメータ
TRAIN_DATA_SIZE = 50  # 150個のデータのうちTRAIN_DATA_SIZE個を訓練データとして使用。残りは教師データとして使用。
HIDDEN_LAYER_SIZE = 6  # 中間層(隠れ層)のサイズ(今回は中間層は1層なのでスカラー)
LEARNING_RATE = 0.1  # 学習率
ITERS_NUM = 1000  # 繰り返し回数

# データを読み込み
# デフォルトで'#'の行をを飛ばすようになっている
x = np.loadtxt('iris.tsv', delimiter='\t', usecols=(0, 1, 2, 3))
raw_t = np.loadtxt('iris.tsv', dtype=int, delimiter='\t', usecols=(4,))
onehot_t = np.zeros([150, 3])
for i in range(150):
    onehot_t[i][raw_t[i]] = 1

train_x = x[:TRAIN_DATA_SIZE]
train_t = onehot_t[:TRAIN_DATA_SIZE]
test_x = x[TRAIN_DATA_SIZE:]
test_t = onehot_t[TRAIN_DATA_SIZE:]

# 重み・バイアス初期化
W1 = np.random.randn(4, HIDDEN_LAYER_SIZE) * np.sqrt(2 / 4)  # Heの初期値(ReLUのときはこれを使う)
W2 = np.random.randn(HIDDEN_LAYER_SIZE, 3) * np.sqrt(2 / HIDDEN_LAYER_SIZE)
b1 = np.zeros(HIDDEN_LAYER_SIZE)  # 初期値ゼロ ※ゼロから作るDeep Learningを見てこうしたので理由はわからない
b2 = np.zeros(3)

# ReLU関数
def relu(x):
    return np.maximum(x, 0)

# Softmax関数 ※この関数だけネットを見たのでどう実装しているかわからない
def softmax(x):
    e = np.exp(x - np.max(x))
    if e.ndim == 1:
        return e / np.sum(e, axis=0)
    elif e.ndim == 2:
        return e / np.array([np.sum(e, axis=1)]).T
    else:
        raise ValueError

# 交差エントロピー誤差
def cross_entropy_error(y, t):
    if y.shape != t.shape:
        raise ValueError
    if y.ndim == 1:
        return - (t * np.log(y)).sum()
    elif y.ndim == 2:
        return - (t * np.log(y)).sum() / y.shape[0]
    else:
        raise ValueError

# 順伝搬
def forward(x):
    global W1, W2, b1, b2
    return softmax(np.dot(relu(np.dot(x, W1) + b1), W2) + b2)

# テストデータの結果
test_y = forward(test_x)
print((test_y.argmax(axis=1) == test_t.argmax(axis=1)).sum(), '/', 150 - TRAIN_DATA_SIZE)

# 学習ループ
for i in range(ITERS_NUM):
    # 順伝搬withデータ保存
    y1 = np.dot(train_x, W1) + b1
    y2 = relu(y1)
    train_y = softmax(np.dot(y2, W2) + b2)

    # 損失関数計算
    L = cross_entropy_error(train_y, train_t)

    if i % 100 == 0:
        print(L)

    # 勾配計算
    # 計算グラフで求めた式を使用
    a1 = (train_y - train_t) / TRAIN_DATA_SIZE
    b2_gradient = a1.sum(axis=0)
    W2_gradient = np.dot(y2.T, a1)
    a2 = np.dot(a1, W2.T)
    a2[y1 <= 0.0] = 0
    b1_gradient = a2.sum(axis=0)
    W1_gradient = np.dot(train_x.T, a2)

    # パラメータ更新
    W1 = W1 - LEARNING_RATE * W1_gradient
    W2 = W2 - LEARNING_RATE * W2_gradient
    b1 = b1 - LEARNING_RATE * b1_gradient
    b2 = b2 - LEARNING_RATE * b2_gradient

# 結果表示

# 最終訓練データのL値
L = cross_entropy_error(forward(train_x), train_t)
print(L)

# テストデータの結果
test_y = forward(test_x)
print((test_y.argmax(axis=1) == test_t.argmax(axis=1)).sum(), '/', 150 - TRAIN_DATA_SIZE)

『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。