post Image
KerasでLSGAN書く

KerasでLSGAN書いてみました。

元論文:Least Squares Generative Adversarial Networks

この記事のリポジトリ:https://github.com/t-ae/watch-generator-keras

LSGANとは?

二乗誤差を使うことで普通のDCGANより本物に近い画像が得られるらしい。
GANのバリエーションはいろいろありますがLSGANについて特筆すべきことは実装が簡単ということです。

準備

データセット

Gressiveから集めた腕時計画像1300枚ほど。
正面向きのに限定しています。元のサイズは148×190。

画像の入出力サイズ

コンテンツ的に縦長の画像を作りたいので適切なサイズを考えます。
どのような手段を使うとしてもアップサンプリングの倍率が大きいのはよろしくないので、
散々悩んだ末6×7を元に2倍アップサンプリングを4回かけて得られる96×112を使うことにしました。
正方形の画像を使っているケースをよく見ますがこういうめんどくささを避けるためでもあるのかなぁと思いました。

実装

目的関数

LSGANの目的関数は以下です。

\min_D V_{LSGAN}(D) = \frac{1}{2}\mathbb{E}_{{\bf x} \sim p_{data}({\bf x})}[(D({\bf x}) - b)^2] + \frac{1}{2}\mathbb{E}_{{\bf z} \sim p_z(\bf{z})}[(D(G({\bf z})) - a)^2] 
\min_G V_{LSGAN}(G) = \frac{1}{2}\mathbb{E}_{{\bf z} \sim p_z({\bf z})}[(D(G({\bf z})) - c)^2] 

a, b, cは今回は0, 1, 1にしました(詳細については論文を参照してください)。
GeneratorのほうはMSEを半分にするだけでいいのですが、Discriminatorのほうはラベルごとに平均する処理が必要になります。

def create_lsgan_d_loss(a, b):
    def loss_func(y_true, y_pred):
        a_mask = K.cast(K.equal(y_true, a), K.floatx())
        b_mask = K.cast(K.equal(y_true, b), K.floatx())
        a_loss = K.sum((y_pred * a_mask - a) ** 2) / K.sum(a_mask)
        b_loss = K.sum((y_pred * b_mask - b) ** 2) / K.sum(b_mask)
        return (a_loss + b_loss) / 2
    return loss_func

(他の方の記事にロスはサンプルごとに返すと書いてあるのですが、ソースのこのへんを読んだ感じだとウェイトやマスクをかけるために用意されている関数がそうなっているだけで、逆にそれらを使わないなら今回のようにサンプルをまたいで計算しちゃっても良いんじゃないかと思ってます。)

ネットワーク

Generator

def create_generator():

    normalization = BatchNormalization

    inp = Input([Z_DIMENSION])
    x = Dense(7*6*256, use_bias=False)(inp)
    x = normalization()(x)
    x = ELU()(x)
    x = Reshape([7, 6, 256])(x)
    x = Conv2DTranspose(256, 3, padding="same", strides=2, use_bias=False)(x)  # 14x12
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(256, 3, padding="same", strides=2, use_bias=False)(x)  # 28x24
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(128, 3, padding="same", strides=2, use_bias=False)(x)  # 56x48
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(64, 3, padding="same", strides=2, use_bias=False)(x)  # 112x96
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(3, 5, padding="same", strides=1)(x)
    x = Activation("tanh")(x)
    return Model(inp, x, name="generator")

特にいうことなし。

Discriminator

def create_discriminator(out="linear"):

    normalization = InstanceNormalization

    return Sequential([
        InputLayer([112, 96, 3]),
        Conv2D(32, 7, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Conv2D(64, 5, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Conv2D(128, 3, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Conv2D(256, 3, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Flatten(),
        Dense(1),
        Activation(out)
    ])

こちらではInstance Normalizationを使っています(コード)。
以前も書きましたがBatchNormalizationだと本物と偽の画像を同時に突っ込むと両者間に隔たりがありすぎて学習不可能になってしまうので、それの回避のためです。一応本物でtrain, 偽でtrainと2ステップに分けて学習することはできるのですが、LSGANだと目的関数が両方を同時に要求するのでこうなっています。
ちなみに前の記事ではNormalizationを入れないで学習していましたが、入れたほうが長期的にG-D間の均衡が保てるようでした。

Generatorの画像を保存しておく

Apple Machine Learning JournalによるとGeneratorが生成した画像の履歴を使うとDiscriminatorをより良く学習できる的なことが書いてあったので試しに入れてみてます。とはいえこの記事を読んだだけで、深く考えず適当に実装したせいか効果は正直分からなかったです。
データセットと同じサイズのバッファを用意し、毎エポック100枚をランダムに除去して新しい100枚を入れるというふうになっています。
(追記:論文読んできたらエポックごとじゃなくてステップごとに入れ替えてるみたいでした)

その他

こちらの記事の方法をいくつか取り込んでます。
1. zは正規分布から採取
2. 本物画像にエポックごとに減衰するガウシアンノイズをのせる
3. 説明済みのInstance Normalizationとか

またデータが1300枚だと少ないので縦4ピクセル、横2ピクセルまでランダムに移動するようにしました。

学習結果

Dのロスが1を超え続けてバグってる臭いのですがそれっぽいものは一応出てました。
すべて同じz群から採取した結果で、上の数字がDiscriminatorの出力になります。
GTX1060 6GBにて2日弱ほど回しました。

0エポック
0

100エポック
100

200エポック
100

500エポック
100

1000エポック
1000

3000エポック
1000

このへんが一番綺麗に出力できてます。インデックスの方向もそれっぽいです。

6000エポック
1000

黒に偏ってたり形状が怪しかったり。

6300エポック
1000

完全にぶっ壊れました。

まとめ

ケースの形状はかなり早い段階から取れていて、丸が綺麗に出力されるところなどは素晴らしいのですが、最もこだわりたい文字盤上の表現が全然でした。モデルを変更しまくっていてどれだったか覚えていないのですが、スモセコやオフセンターが出ているようなのも一応出来ました。アスペクト比すら今の値でなく、しかもMode collapse臭いですが画像が残っていたので貼っておきます。
old.png

最初は高画質なのができたら記事にしようと思っていたのですが、まだ満足なところまで辿り着いていません。
いろいろ試しているとKerasでは自由度が低かったり、ソースをたどりに行く必要があったりして結構手間だったので、PyTorchに移ることにしました。
今回の記事はとりあえずKerasでのまとめということで、PyTorchでもう少し高画質化を狙ってみようと思います。


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。