post Image
KerasモデルをCloud ML Engineで学習してOnline Predictionしてみた

Cloud ML Engine のruntime versionが1.2になったので、Kerasが小細工なしで使えるようになりました。TensorFlowの高レベルAPIもいい感じになって来ていますが、やはりKerasのpretrained modelの多さは魅力的です。とりあえずやり方だけ把握しておこうと、せっかくなので学習だけでなくOnline PredictionもKerasモデルでserveしてみました。

Cloud ML Engineとは
TensorFlowのフルマネージドな実行環境です。分散環境で学習、オートスケールしAPIで推論リクエスト可能なOnline Prediction等、TensorFlowの運用には最高の環境です。

KerasをCloud ML Engine(training)で使う

注意するのは、

  • Kerasのimportをtf.contribからする
  • jobのsubmit時にruntimeVersionを1.2に設定する

位で、後は普通に動きます。こんな感じで書けばOK (こちらを参考にしました)

import tensorflow as tf
from tensorflow.contrib.keras.python import keras
from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.layers.core import Dense, Activation
from sklearn.cross_validation import train_test_split

iris = tf.contrib.learn.datasets.base.load_iris()
train_x, test_x, train_y, test_y = train_test_split(
    iris.data, iris.target, test_size=0.2)

num_classes = 3
train_y = keras.utils.to_categorical(train_y, num_classes)
test_y = keras.utils.to_categorical(test_y, num_classes)

model = Sequential()
model.add(Dense(10, activation='relu', input_shape=(4,)))
model.add(Dense(20, activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(3, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

cb = keras.callbacks.TensorBoard(
    log_dir="gs://BUCKET/keras-mlengine", histogram_freq=1)
# train
model.fit(train_x, train_y,
          batch_size=100,
          epochs=20,
          verbose=2,
          callbacks=[cb],
          validation_data=(test_x, test_y))

# eval
score = model.evaluate(test_x, test_y, verbose=0)
pred = model.predict(test_x)

TensorBoardの出力先はGCSにしておけば後からローカルでもCloud Shellからでも参照できます。
Datalab、Jupyter NotebookからJobを投げたいとき
cloudml-magicをインストールしてください。runtime version 1.2にも対応しておきました。Kerasの利用例はこちら

ログはこんな感じに出力されます。
Untitled.png

SavedModelを作る

Online Predictionはv1からSavedModel形式しか対応していないので、KerasモデルからSavedModelに変換する必要があります。Sessionを取り出してSignature追加すればいいだけですが、ちょっと嵌りました。

from tensorflow.contrib.keras import backend
sess = backend.get_session()
x = sess.graph.get_tensor_by_name('dense_1_input:0')
y = sess.graph.get_tensor_by_name('ArgMax_1:0')
inputs = {"dense_1_input": tf.saved_model.utils.build_tensor_info(x)}
outputs = {"ArgMax_1": tf.saved_model.utils.build_tensor_info(y)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)

# save as SavedModel
b = tf.saved_model.builder.SavedModelBuilder('gs://BUCKET/keras-mlengine/savedmodel')
b.add_meta_graph_and_variables(sess,
                               [tf.saved_model.tag_constants.SERVING],
                               signature_def_map={'serving_default': signature})
b.save()

SavedModelの保存先をGCSに指定すれば、Online Predictionでもそのまま使えます。Consoleから作成したSavedModelのGCSパスを指定すればOnline Prediction環境の出来上がり。

ちなみに入力と出力のTensorを取得するのに、Keras側のGraph生成ルールがわからず、とりあえずTensorBoardで確認してget_tensor_by_nameでTensorを引っ張ってきています。Sequentialplaceholderを突っ込むのが正しいやり方でしょうが、そうするとそこだけローレベルになってしまうので
追記
model.inputsmodel.outputsで入出力Tensorが取得できるようです。
しかしそれでも面倒なのでEstimatorexport_savedmodelみたいなのがKerasにも欲しいですね。

Online Prediction で推論する

ここは従来通りです。Discovery APIを使う場合はこんな感じ。

from oauth2client.client import GoogleCredentials
from googleapiclient import discovery
from googleapiclient import errors

PROJECTID = 'PROJECTID'
projectID = 'projects/{}'.format(PROJECTID)
modelName = 'keras-iris'
modelID = '{}/models/{}'.format(projectID, modelName)

credentials = GoogleCredentials.get_application_default()
ml = discovery.build('ml', 'v1', credentials=credentials)

request_body = {'instances': [{'dense_1_input': [5.4,  3.9,  1.3,  0.4]},
                              {'dense_1_input': [4.4,  3.2,  1.3,  0.2]},
                              {'dense_1_input': [4.3,  3.,  1.1,  0.1]},
                              {'dense_1_input': [5.,  3.5,  1.6,  0.6]},
                              {'dense_1_input': [5.9,  3.,  4.2,  1.5]},
                              {'dense_1_input': [7.7,  3.,  6.1,  2.3]},
                              ]}

request = ml.projects().predict(name=modelID, body=request_body)
try:
    response = request.execute()
except errors.HttpError as err:
    # Something went wrong with the HTTP transaction.
    # To use logging, you need to 'import logging'.
    print('There was an HTTP error during the request:')
    print(err._get_reason())
print(response)

するとこんな感じで帰ってきます。

{u'predictions': [{u'ArgMax_1': 0},
  {u'ArgMax_1': 0},
  {u'ArgMax_1': 0},
  {u'ArgMax_1': 0},
  {u'ArgMax_1': 1},
  {u'ArgMax_1': 2}]}

終わりに

GCPすごい、Cloud ML Engineすごい!・・・のですが、あまり情報がないのが寂しいところ。慣れちゃえばとっても簡単なので、皆さんどんどん使って色々事例を出してくださいね!


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。