post Image
TensorFlow iOS カメラサンプルで自作した画像分類器を動かす

TensorFlow で自作した画像分類器モデルを TensorFlow iOS で動かします。

大抵の記事では分類器モデルの作成までは説明されているのですが(Mac や Ubuntu で動かすにはそれで十分)、iOS カメラサンプルで使われている pb(protocol buffers) ファイルの作成の部分まで触れているものがなかったのでその部分を説明しようかと思います。

ちなみに僕はこの方法で作成した pb ファイルを使って「B’z 分類器を作って iOS で動かす」というのをやってみました。

環境

  • Python 2.7
  • TensorFlow 1.0

やること

  • 学習済みの ckpt から iOS で動作する pb ファイルを作成する
  • pb ファイルを iOS カメラサンプルに入れて動作させる

やらないこと

  • TensorFlow のインストール、環境構築(公式通りで上手く行くと思います)
  • TensorFlow iOS のビルド(堤さんの記事で上手くいくと思います)
  • 学習データの用意、前処理、学習済み ckpt ファイルの作成(kenmazさんの記事が参考になります)
  • iOS での顔認識、物体認識など

準備

TensorFlow 1.0 の環境構築、TensorFlow iOS のビルドが済んでいる前提で話を進めますので、まだの方は上記「今回実装しないこと」記載のリンクを参考に済ませておいてください。また、学習させる部分に関しても色々な方が記事にしてますのでそちらをご覧ください。
僕は kenmazさんの記事を参考に学習済みモデルを作成しましたので、今回はそちらで使われているコードを引用しながら解説していきたいと思います。(kenmazさんありがとうございます!)
(TensorFlow のバージョンが違うので結構書き換えないといけないと思いますが、関数名が変更されている程度の書き換えなので API ドキュメントで調べながら書き換えれば正常に動作します)

学習に使うコードに追加の実装

kenmazさんの記事 のバージョン誤差を修正すれば正常に ckpt ファイルは生成されるのですが、pb ファイル生成のためのコードを追加します。また、TensorFlow iOS では読み込めない関数があるので、それを迂回するような修正を加えます。

mcz_main.py の main に引数を追加して、pb ファイル作成と学習を分けて main を呼べるようにします。
また、 mcz_input.load_data では TensorFlow iOS で読み込めない関数を使っているため以下のように書き換えて、pb ファイル生成時に mcz_input.load_data が呼ばれないようにします。

mcz_main.py

def main(for_input_graph = False):
    with tf.Graph().as_default():
        keep_prob = tf.placeholder("float")

        if for_input_graph:
            images = tf.placeholder(tf.float32, shape=[120, 56, 56, 3], name="x")
            labels = tf.placeholder(tf.float32, shape=[120, 5], name="Reshape")
        else:
            images, labels, _ = mcz_input.load_data([FLAGS.train, FLAGS.test], FLAGS.batch_size, shuffle = True, distored = True)

学習を走らせている sess.run([train_op, loss_value, acc] ... の直前に分岐を加え、input_graph.pb の生成と、学習を分離します。

mac_main.py
if for_input_graph:
    tf.train.write_graph(sess.graph.as_graph_def(), "./", "input_graph.pb")
    return
else:
    for step in range(FLAGS.max_steps):
        start_time = time.time()
        _, loss_result, acc_res = sess.run([train_op, loss_value, acc], feed_dict={keep_prob: 0.99})

最下行の main 呼び出し部分を以下のように書き換えます。

mcz_main.py
if __name__ == '__main__':
    main(for_input_graph = True)
    main()

TensorFlow iOS で tf.nn.dropout が使われたモデルを使うと正しくロードできないので mcz_model.py も以下のように書き換えます。
ちなみに dropout は過学習防止のための関数のようなので精度に影響が出る可能性はありますが、僕の場合は特に問題になりませんでした。

mcz_model.py
def inference_deep(images_placeholder, keep_prob, image_size, num_classes):
    #
    # 中略
    #
    with tf.name_scope('fc1') as scope:
        w = image_size / pow(2,3)
        W_fc1 = weight_variable([w*w*128, 1024])
        b_fc1 = bias_variable([1024])
        h_pool3_flat = tf.reshape(h_pool3, [-1, w*w*128])
        print h_pool3_flat
        h_fc1 = tf.matmul(h_pool3_flat, W_fc1) + b_fc1
        # h_fc1_drop = tf.nn.dropout(tf.nn.relu(h_fc1), keep_prob)
        # print h_fc1_drop

    with tf.name_scope('fc2') as scope:
        W_fc2 = weight_variable([1024, num_classes])
        b_fc2 = bias_variable([num_classes])
        h_fc2 = tf.matmul(h_fc1, W_fc2) + b_fc2
        # h_fc2 = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
        print h_fc2

    with tf.name_scope('softmax') as scope:
        y_conv=tf.nn.softmax(h_fc2)
        print y_conv

    return y_conv

実行します。

python mcz_main.py

これで input_graph.pb が生成されます。
すでに学習済みの ckpt ファイルがある場合はこのまま中断しても大丈夫な場合はありますが、上記コードのように dropout など TensorFlow iOS で読み込めない関数を使用している場合はこのまま学習を進めて、精度を高めます。(iOS で動かしたいだけなら、今回生成される model.ckpt-0 を利用できますので、中断しても大丈夫です)

input_graph.pb に学習済みモデルのデータを保存する

学習済みの ckpt モデルのデータをまだ空の input_graph.pb に保存します。gen_output_graph.py を以下のように作成します。

gen_output_graph.py
from tensorflow.python.tools import freeze_graph

input_checkpoint = "path/to/ckpt/model.ckpt-XXXX"
input_graph = "./input_graph.pb"
output_node_names = "softmax/Softmax"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph = "./output_graph.pb"

freeze_graph.freeze_graph(input_graph=input_graph,
 input_saver="",
 input_binary=False,
 input_checkpoint=input_checkpoint,
 output_node_names=output_node_names,
 restore_op_name=restore_op_name,
 filename_tensor_name=filename_tensor_name,
 output_graph=output_graph,
 clear_devices=True,
 initializer_nodes="",
 variable_names_blacklist="")

input_checkpoint, input_graph には作成済みのファイルのパスを指定します。

output_node_names には今回は mcz_main.pymcz_model.inference_deep の返り値が softmax スコープ内の Softmax 関数だったので softmax/Softmax となっています。

output_graph は新規で生成される pb ファイル名を指定します。

restore_op_name, filename_tensor_name はよくわかっていないのです。すいません。
path はバックスラッシュなどでエスケープする必要がないので注意です。

python gen_output_graph.py

上記コマンドを実行して問題なければ output_graph.pb が生成されます。

pb ファイルを iOS カメラサンプルに入れて動作させる

できあがった output_graph.pbtensorflow/contrib/ios_examples/camera/data/ 以下にに配置してCameraExample プロジェクトを Xcode で起動します。

Screen Shot 2017-02-20 at 5.59.31 PM.png

上記画像のようにリンクされていない(赤くなっている) txt と pb ファイルがあるので削除します。

Screen Shot 2017-02-21 at 4.08.47 PM.png

続いて、 output_graph.pb を上記画像のようにプロジェクトにインポートして、 labels.txt を作成します。

labels.txt
hoge1
hoge2
hoge3
hoge4
hoge5

今回は用意した学習済みモデルは5種類の画像の分類だったので↑のようになりました。

CameraExampleViewController.mm の以下の箇所を書き換えます。

CameraExampleViewController.mm
static NSString* model_file_name = @"output_graph";

static NSString* labels_file_name = @"labels";

const int wanted_input_width = 56;
const int wanted_input_height = 56;

const std::string input_layer_name = "x";
const std::string output_layer_name = "softmax/Softmax";

model_file_name, labels_file_name はその名の通りで、今回作成した output_graphlabels を指定します。
wanted_input_width, wanted_input_height では学習の際に指定した画像のリサイズのサイズを指定します。(kenmazさんの記事を参考にモデルを学習させた場合は mcz_input.pyDST_INPUT_SIZE
input_layer_name には上記 mcz_main.py で images に付けられている name を指定します。 name が付けられていなかった場合は output_graph.pb 作成の際の softmax/Softmax の用にスコープと関数名が必要になるかと思います。
output_layer_name は gen_output_graph.pyoutput_node_names と同じものを指定します。

あとは Xcode の実行ボタンをクリックすればめでたく iOS で自作の学習済みモデルを動かすことが出来るかと思います。

まとめ

環境構築やモデルの学習などの説明は飛ばしているので、この記事だけで iOS で動かせるわけではないですが、同じようにはまっている方の助けになれば幸いです。
今後の課題として、このままだとカメラ画像をフルに使って分類してしまっているので、顔認識や物体認識など、TensorFlow とは関係ない部分も最適化する必要があると思います。
まだまだ TensorFlow のことをよくわかっていないので、何か間違いがあるかもしれないので気になる点があればご指摘お願いします。

参考


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。