post Image
KerasのモデルとTensorFlowの最適化でWasserstein GANを学習する

はじめに


  • この記事でやること:Kerasのモデル,TensorFlowの最適化によってWasserstein GANを学習する.
  • 前提知識:GANの基本的な学習則
  • この記事が必要ない方:いずれかの深層学習ライブラリ,またはフルスクラッチで自由自在にコーディングできる方

最近深層生成系,いわゆるGANがホットですね.僕もいろいろ試して遊んでます.ライブラリはやっぱりKerasが直感的でわかりやすいですね.

さて,ふつうのDCGANやWasserstein GANなどはKerasでもサクッと実装・学習できてしまうわけですが,誤差の最適化などで細かい改善がされてるやつになると「これ,Kerasでどう書けばええんや,,,?」となってしまうこともあるかと思います.僕はImproved Training of Wasserstein GANsで提案されているGradient Penaltyの実装でそうなりました.Gradient Penaltyの説明は割愛しますが,これにあたりKerasでモデルを定義,学習はTensorFlowでやればイケるのでは?と考え,通常のWassersteinGANではなんとかこれがうまく学習するようになったので記事にしたいと思います.参考になれば幸いです.コードはこちらに置いておきます.

Requirements


  • python 3.5
  • Keras 2.0.5
  • TensorFlow 1.1.0
  • その他基本的なライブラリ
  • 学習の際には学習画像データ(CelebAなど)をどこかのディレクトリにダウンロードしておきましょう

モデルの構築(Keras)


GeneratorとDiscriminatorはKerasによって記述します.簡単ですね.(model.py参照)
GeneratorはDeconvolution(フィルターの大きさ4*4,ストライド2*2)によって乱数の入力を画像にしていきます.
DiscriminatorはConvolution(フィルターの大きさ4*4,ストライド2*2)によって入力画像から関数値を出力します.
KerasのFunctional API Modelで書いていますがSequential Model(model.addで連ねていく書き方)でも問題ないはずです.

学習(TensorFlow)


Wasserstein GANではDiscriminatorは入力画像の真偽を判定するのではなく関数値の出力を扱います.

\max_{D} \min_{G} \mathbb{E}_{x\sim p(x)}D(x) - \mathbb{E}_{z \sim p(z)}D(G(z))

これによってJensen-Shannon divergenceを最小化するオリジナルのGANよりも勾配学習が安定するわけですが,詳しい解説は元論文や他の解説をご覧ください.

さて,通常Kerasの場合用意されている損失関数か自分で定義したものを指定,model.compileするわけですが本記事ではこの最適化をTensorFlowで行います.

main.py
class WassersteinGAN:

    def __init__(self, ...):

        self.image_size = 64
        self.gen = generator # model.pyで定義したGenerator
        self.disc = discriminator # model.pyで定義したDiscriminator

        # placeholderによって入力画像,入力乱数を定義します.
        self.x = tf.placeholder(tf.float32, (None, self.image_size, self.image_size, 3), name = 'x') # 本物画像
        self.z = tf.placeholder(tf.float32, (None, 100), name = 'z') # 入力乱数
        self.x_ = self.gen(self.z) # 偽画像 <- 入力乱数

        self.d = self.disc(self.x) # 本物画像を入力とした出力
        self.d_ = self.disc(self.x_) # 偽物画像を入力とした出力

        self.d_loss = -(tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)) # Discriminatorの目的関数
        self.g_loss = -tf.reduce_mean(self.d_) # Generatorの目的関数 

        # Optimizerを設定. Generatorの学習率を少し小さめに設定しています.
        self.d_opt = tf.train.RMSPropOptimizer(learning_rate = 5e-5)\
                     .minimize(self.d_loss, var_list = self.disc.trainable_weights)
        self.g_opt = tf.train.RMSPropOptimizer(learning_rate = 1e-5)\
                     .minimize(self.g_loss, var_list = self.gen.trainable_weights)

        # TensorFlowのsessionを設定.
        self.sess = tf.Session()
        K.set_session(self.sess) # ← Kerasと併用時には必要っぽいです

    def train(self, ...):

これらの設定のもとで実際にデータを流し,学習していきます.

僕のコードではさらに入力画像・入力乱数を取ってくるためのクラス(misc/dataIO.py, InputSampler)を作りました,
以降のコード中に現れるsamplerはこのInputSamplerのインスタンスであり,image_sampleメソッドによって本物画像のミニバッチを,noise_sampleメソッドによって入力乱数のミニバッチを返します.
reloadメソッドは学習データの画像枚数が多い場合,これを分割して保持するためのメソッドです.(よく使われる顔画像データセットCelebAだと20万枚以上の画像が含まれるので,メモリの都合上このような仕様にしました.)

またKerasで実装したモデルにBatchNormalizationが含まれている場合,データを流す際にK.learning_phase()を設定する必要があります.それも含めて次のコードに記載します.

main.py
class WassersteinGAN:

    def __init__(self, ...):

    # --略--

    def train(self, ...):

        for e in range(epochs):
            for batch in range(num_batches):

                # Discriminatorは多く学習
                for _ in range(5):
                    # リプシッツ連続を保証するためのweight clipping
                    d_weights = [np.clip(w, -0.01, 0.01) for w in self.disc.get_weights()]
                    self.disc.set_weights(d_weights)

                    # 本物画像のミニバッチ
                    bx = sampler.image_sample(batch_size)
                    # 入力乱数のミニバッチ
                    bz = sampler.noise_sample(batch_size)
                    # プレースホルダーに流す入力とK.learning_phase()を指定して, Discriminatorを学習
                    self.sess.run(self.d_opt, feed_dict = {self.x: bx, self.z: bz,
                                                           K.learning_phase(): 1})

                bz = sampler.noise_sample(batch_size, self.z_dim)
                # プレースホルダーに流す入力とK.learning_phase()を指定して, Generatorを学習
                self.sess.run(self.g_opt, feed_dict = {self.z: bz,
                                                       K.learning_phase(): 1})

                # ロスを出力する場合
                d_loss, g_loss = self.sess.run([self.d_loss, self.g_loss],
                                               feed_dict = {self.x: bx, self.z: bz,
                                                            K.learning_phase(): 1})
                print('epoch : {}, batch : {}, d_loss : {}, g_loss : {}'\
                      .format(e, batch, d_loss, g_loss))

モデルはKerasで記述しているのでパラメータの保存もKeras式で可能です.僕はエポックごとに保存することが多いです.
※TensorFlowのセッションの中でKerasのやり方でモデルを保存しようとするとうまくできないかもしれません.確認します.
モデルパラメータはKeras式model.save_weights()で保存できます.保存したパラメータのロードについてもmodel.load_weights()で可能です.
※ただしtensorflowで学習したパラメータについてはtf.Session()で動かさないとダメっぽいです.

main.py
class WassersteinGAN:

    def __init__(self, ...):

    # --略--

    def train(self, ...):

        for e in range(epochs):
            for batch in range(num_batches):

            # --略--

            # Kerasのやり方でパラメータが保存できる
            self.gen.save_weights('path/to/g_{}epoch.h5'.format(e))
            self.disc.save_weights('path/to/d_{}epoch.h5'.format(e))

これでそこそこうまく学習してくれましたが生成画像は少しぼやけていました.
一般にWasserstein GANはDCGANよりも少しぼやけるらしいのですが,学習率などのチューニングによってはもうちょい綺麗な生成ができる気もします.

最後に


こうしてKerasでモデルを構築,TensorFlowで学習ができました.
最近は最適化を工夫した手法が多く提案されているように感じるので,そういったものの追試に少しでも役立てば幸いです.
ご指摘や改善点などもあればお気軽にお願いします.

追記


簡単なCNNで同じようにKerasとtensorflowで学習するコードを書きました.GANとかいいからって方はこちらのがわかりやすいと思います.


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。