post Image
ディープラーニングでイラスト画像分類

Akka Streams を使って Twitter の画像をダウンロードする で、ふぁぼった画像がいい感じに集まってきました。集めた画像を仕分けしたいところですが、1回で約 3000 画像集まりますので、人の手でやるのは大変です。そこでコンピュータに分類させたいと思います。

ちなみに私は機械学習についてはまったくのど素人です。参考資料にあげたリンクの内容以上のことはやってません……

データ

自動ダウンロードを始める前に収集した、分類済み画像が下記のとおり存在しています。これを使って学習させます。画像は基本的にイラストですが、イラストやフィギュアを撮影した写真などもあります。

内容 枚数(学習用) 枚数(検証用)
艦これ 6130 681
ガルパン 4157 462
その他 778 87

「その他」の枚数が少なくてバランスが悪いです。仕方がありません。これしかないので。

画像の内容はこんな感じです。雑多に集めてあり、サイズもばらばらです。顔抽出も行いません。雰囲気で分類してもらいます。面倒くさいので。

Images

実装

実装は、Keras を使います。バックエンドは Tensorflow です。

今回は3種類の実装を試します。

  1. 小さな CNN
  2. 層の深い CNN
  3. VGG16 のファインチューニング

小さな CNN

小規模な畳み込みニューラルネットワーク (Convolutional Neural Network:CNN) で学習させます。以下のような畳み込み層が3つの CNN です。

CNN 構築部分の実装は以下のとおりです。参考資料 そのままの実装です。

img_width, img_height, channels = 150, 150, 3
nb_classes = 3
def build_cnn(img_width, img_height, channels, nb_classes):
    input_shape = (img_width, img_height, channels)

    model = Sequential()

    model.add(Conv2D(32, (3, 3), input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))

    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    return model

画像の読み込みは ImageDataGenerator を使います。入力画像の加工は行いません。学習用画像と検証用画像に同じラベルが生成されるように classes を指定します。

# classes = ['01Garupan', '02Kancolle', '99Other']
def load_images(train_path, test_path, img_width, img_height, classes):
    target_size = (img_width, img_height)
    color_mode = 'rgb'
    batch_size = 32
    class_mode = 'categorical'

    train_datagen = ImageDataGenerator(rescale = 1.0 / 255)
    test_datagen = ImageDataGenerator(rescale = 1.0 / 255)
    train_generator = train_datagen.flow_from_directory(
        train_path,
        classes=classes,
        target_size=target_size,
        color_mode=color_mode,
        batch_size=batch_size,
        class_mode=class_mode)

    test_generator = test_datagen.flow_from_directory(
        test_path,
        classes=classes,
        target_size=target_size,
        color_mode=color_mode,
        batch_size=batch_size,
        class_mode=class_mode)
    return train_generator, test_generator

画像を読み込みつつ学習する部分は以下のようになります。

def fit_model(img_width, img_height, channels, classes):
    train_generator, test_generator = load_images('images/train',
                                                  'images/test',
                                                  img_width,
                                                  img_height,
                                                  classes)
    print('Train indices:', train_generator.class_indices)
    print('Test indices:', test_generator.class_indices)
    model = build_cnn(img_width, img_height, channels, len(classes))
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    history = model.fit_generator(train_generator,
                                  steps_per_epoch=steps_per_epoch,
                                  epochs=epochs,
                                  validation_data=test_generator,
                                  validation_steps=validation_steps,
                                  verbose=1)
    return model, history

images/train 配下に学習用の画像を、images/test 配下に検証用の画像を置きます。

$ find images -type d
images
images/test
images/test/01Garupan
images/test/02Kancolle
images/test/99Other
images/train
images/train/01Garupan
images/train/02Kancolle
images/train/99Other

上記で実装した fit_model を実行すれば学習を実行します。fit_model が返した model を実際の分類用に、history を検証用に保存します。

層の深い CNN

少し層の深い CNN で学習させます。規模の違う CNN で結果がどう変わるか見てみるためです。

CNN 構築部分の実装は以下のとおりです。参考資料 そのままの実装です。

img_width, img_height, channels = 150, 150, 3
nb_classes = 3
def build_cnn(img_width, img_height, channels, nb_classes):
    input_shape = (img_width, img_height, channels)

    model = Sequential()

    model.add(Conv2D(32, (3, 3), padding='same', input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))

    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    return model

VGG16 のファインチューニング

上記では CNN をゼロから学習させましたが、ここでは学習済みのモデルを使ってファインチューニングを行う方法を試します。Keras には VGG16 の学習済みモデルを簡単に利用できるので、それを使います。

VGG16 は ImageNet の画像で学習されており、そのままではイラストデータの分類に使えるものではありません。VGG16 のモデルの深い層だけをイラスト用に再調整 (fine-tuning) させて、分類精度がどうなるか試します。

これも参考資料そのままに実装しました。

結果

損失と精度は以下のようになりました。損失のグラフを見ると、もう少し epoch を大きくとっても良かったかもしれません。
CNN は精度が 0.6 ほどしか出ませんが、VGG16 をファインチューニングした方は 0.7 近くまで出ています。

Validation loss

Validation accuracy

まとめ

  • いい加減な画像でも、ファインチューニングで 7割の精度は出ることがわかりました。
  • 今回のケースでは CNN ではモデルの規模によらず、精度が変わらないことがわかりました。

参考


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。