post Image
「SGDR: Stochastic Gradient Descent with Warm Restarts」をちょっと改良してKerasで実装した

Stocastic Gradient Descent with Warm Restarts(SGDR)は学習率の減衰手法です。Shake-Shakeでこの方法が使われていたので軽く調べてみました。元の論文には含まれていませんが、減衰の発動にトリガーをつけてKerasで実装してみました。

ちなみにPyTorchの場合、torch.optim.lr_scheduler.CosineAnnealingLRで実装済です。Kerasでは組み込みでないので自分で実装しました。Kerasの場合、ReduceLROnPlateauというコールバックがあるので、SGDRにこだわらない場合はこれで足りるかもしれません。

学習率の減衰がなぜ重要なのか

CIFAR-10をCNNで分類していると、精度90%(エラー率10%)あたりに壁があります。これは90%を超えるには学習率が高すぎるためです。最初は学習が上手く進んでいても、あるところからは学習率が高すぎて解の近くをぐるぐる周回してしまうためです1

これを解決させるために、教科書的には指数関数で学習率を減衰させるというのがありますが、CNNの論文を読んでいるとよく出てくるのが「全体の50%で学習率を1/10にし、75%でさらに1/10にする」という手法です。例えばResNetの論文では学習の停滞を「error plateau2」と表現し、初期学習率を0.1とし、6万4000回の繰り返しのうち、3万2000回目で学習率を1/10にし、4万8000回目でさらに1/10にするという実装をしています。確かに32kあたりでガクッとエラー率が落ちてますね(ResNetの論文より)。

sgdr02.png

これは学習が100~300epochのように短い場合は有効ですが、例えばShake-Shakeのように1800epochも訓練させる場合はもう少し難しい関数でゆっくり減衰させる必要があります。90%の壁を超えるのに900epochもいらないわけですから。

ちなみにこのSGDRの「Warm Restarts」というのは学習率を減衰させる一方で、周期的に再起動(restart)つまり上げることを可能にする手法です。減衰させたままだと時間がかかってしまいますからね。ちなみにShake-Shakeの場合はCIFAR-10に対してエラー率2.86%という驚異的な値を出していますが、なぜかこのWarm Restartsを使っていないんですよね(コサイン関数の減衰のみ使っています)。これが面白いです。

Stochastic Gradient Descent with Warm Restarts(SGDR)

以下の式で学習率を決定します。tエポック目で、i回目のRestart(再起動)とします。

$$\eta_t=\eta_{min}^i+\frac{1}{2}(\eta_{max}^i-\eta_{min}^i)(1+\cos(\frac{T_{cur}}{T_i}\pi)) $$

コサイン関数で減衰させるのが特徴。$T_{cur}$は再起動してからのエポック数で、訓練開始時は1とします。$T_i$は再起動までのエポック数です。このような複雑な学習率のコントロールができます(図は論文より)。

sgdr01.png

論文では、最大と最小の学習率を与えるか、$T_0$と、$T_{i+1}=T_{mult}T_i$なる$T_{mult}$を与えることで、再起動の周期を変えることを可能にしています。例えば、図のエメラルドグリーンやピンクのラインはそうですよね。

式の意味がわかりづらいので実際に計算してみましょう。$T_0=5, T_{mult}=2$として書くパラメーターを表に書き出してみます。学習率の最大値は1e-1、最小値は1e-5とします。

epoch T_cur T_i lr log(lr)
1 1 5 0.090451805 -1.043582764
2 2 5 0.065454305 -1.184061787
3 3 5 0.034555695 -1.461480363
4 4 5 0.009558195 -2.019624097
5 5 5 0.00001 -5
6 1 10 0.097553071 -1.010759056
7 2 10 0.090451805 -1.043582764
8 3 10 0.079391324 -1.100226957
9 4 10 0.065454305 -1.184061787
10 5 10 0.050005 -1.300986568
11 6 10 0.034555695 -1.461480363
12 7 10 0.020618676 -1.685739219
13 8 10 0.009558195 -2.019624097
14 9 10 0.002456929 -2.609607311
15 10 10 0.00001 -5
16 1 20 0.099384479 -1.002681436
17 2 20 0.097553071 -1.010759056
18 3 20 0.094550871 -1.024334465

なんとなくわかりましたか?学習率が最小値に達したら$T_i$を倍にしています。これを繰り返すことで先程の図のような複雑な曲線が描けることになります。

とりあえずコード書いてみた

from keras.callbacks import Callback, LearningRateScheduler
import numpy as np

class LearningRateCallback(Callback):
    def __init__(self, lr_max, lr_min, lr_max_compression=5, t0=10, tmult=1, trigger_val_acc=0.0, show_lr=True):
        # Global learning rate max/min
        self.lr_max = lr_max
        self.lr_min = lr_min
        # Max learning rate compression
        self.lr_max_compression = lr_max_compression
        # Warm restarts params
        self.t0 = t0
        self.tmult = tmult
        # Learning rate decay trigger (早い段階で減衰させても訓練が遅くなるだけなので)
        self.trigger_val_acc = trigger_val_acc
        # init parameters
        self.show_lr = show_lr
        self._init_params()        

    def _init_params(self):
        # Decay triggered
        self.triggered = False
        # Learning rate of next warm up
        self.lr_warmup_next = self.lr_max
        self.lr_warmup_current = self.lr_max
        # Current learning rate
        self.lr = self.lr_max
        # Current warm restart interval
        self.ti = self.t0
        # Warm restart count
        self.tcur = 1
        # Best validation accuracy
        self.best_val_acc = 0

    def on_train_begin(self, logs):
        self._init_params()

    def on_epoch_end(self, epoch, logs):
        if not self.triggered and logs["val_acc"] >= self.trigger_val_acc:
            self.triggered = True

        if self.triggered:
            # Update next warmup lr when validation acc surpassed
            if logs["val_acc"] > self.best_val_acc:
                self.best_val_acc = logs["val_acc"]
                # Avoid lr_warmup_next too small
                if self.lr_max_compression > 0:
                    self.lr_warmup_next = max(self.lr_warmup_current / self.lr_max_compression, self.lr)
                else:
                    self.lr_warmup_next = self.lr
        if self.show_lr:
            print(f"epoch = {epoch+1}, sgdr_triggered = {self.triggered}, best_val_acc = {self.best_val_acc}, " + 
                  f"current_lr = {self.lr:f}, next_warmup_lr = {self.lr_warmup_next:f}, next_warmup = {self.ti-self.tcur}")

    # SGDR
    def lr_scheduler(self, epoch):
        if not self.triggered: return self.lr
        # SGDR
        self.tcur += 1
        if self.tcur > self.ti:
            self.ti = int(self.tmult * self.ti)
            self.tcur = 1
            self.lr_warmup_current = self.lr_warmup_next
        self.lr = float(self.lr_min + (self.lr_warmup_current - self.lr_min) * (1 + np.cos(self.tcur/self.ti*np.pi)) / 2.0)
        return self.lr

epochの終わりに読み込ませるコールバックとSGDRの学習率の調整のコールバックを同一クラスにおくことで、val_accの情報を共有するというスタイル。

改良点

  1. SGDRを発動するためのトリガーを設定した
    学習率減衰が必要になるのはある程度学習が進んでからで、最初から減衰させると学習が遅くなるだけなので、発動トリガーを作った。発動条件は「val_accが指定したパーセント以上」。CIFAR-10の場合は85%~90%ぐらいにするといいかも。
  2. Val_accがよくなったときの学習率を記録して次の再起動時に持ち越す
    最大学習率で再起動してしまうと学習率が高すぎて、精度が落ちることが確認されています。そこで、val_accがよくなったときのみ最適な学習率として記録し、その学習率を最大学習率として再起動するという方法を取ります。ただし、これだとすぐ学習率が最低値になってしまうので、lr_max_compressionを設定し学習率のボトムラインを計算します。5とか10とかがいいと思います。

イメージ

全てのエポックにおいてval_accが良くなったと仮定すると次のようになります。

sgdr03.png

ここで$l_{comp}$はlr_max_compressionの値です。つまり、指数関数的減衰(Exponential decay)+コサイン関数による学習率の減衰となります。$l_{comp}=1$とすればオリジナルのSGDRと同じです。

使い方

lr_cbs = LearningRateCallback(0.1, 0.0001, lr_max_compression=5, t0=10, tmult=2, trigger_val_acc=0.85)
sgdr = LearningRateScheduler(lr_cbs.lr_schduler)

これをfitのコールバックで食わせます。

model.fit_generator(traingen.flow(X_train, y_train, batch_size=128), , callbacks=[lr_cbs, sgdr]))

必ずlr_cbs→sgdrの順番で食わせてください。オプティマイザーはデフォルトのSGDなので、モメンタムやAdamとも併用可能です。

テストしてみる

WideResNetっぽい何かを作って実験してみました。データは相変わらずCIFAR-10です。

レイヤー レイヤー数 画素数 チャンネル数
Conv1 9 32×32 80
DownSampling1 1 16×16 160
Conv2 8 16×16 160
DownSampling2 1 8×8 320
Conv3 8 8×8 320
Output 1 320

深さは28です。WideResNetの表記に直すなら、「深さ28、k=5のWideResNet」となりました。深さ28のWideResNetの場合、k=10のほうが一般的ですが、計算時間を短縮するためにk=5にしました。オリジナルのResNetはk=1で、オリジナルの6N+2(実際は6N+4)の公式に直すとN=4です。オリジナルのResNetは縦×横×チャンネル数の表記で、32x32x16→16x16x32→8x8x64でテストしていました。ここらへんの細かい設定が気になる方向けの説明です。

ちなみに最初の3チャンネル→80チャンネルへの拡張、ダウンサンプリングの際のチャンネル拡張に1×1畳み込みを使っています。ダウンサンプリングはPoolingを使わないで、stride=2の1×1畳み込みを使いました。

条件

以下の3条件で実験しました。最大学習率は0.01、最小学習率はその1/100としました。150エポック訓練させました。SGDRの発動条件はval_accが85%としました。発動以前は最大学習率で訓練させます。

  1. $T_0=10, T_{mult}=2, l_{comp}=5$:上のグラフに示したWarmRestartsの方法です。
  2. $T_0=120, T_{mult}=1, l_{comp}=100$:WarmRestartsが実質なしの方法です。120エポック終わると最小学習率に近い値での訓練になる。
  3. 従来の方法。全体の50%の75エポック目で学習率を1/10、全体の75%の112エポック目でさらに1/10とする方法です。

時間の関係上各1回のみやりました。コードはこちらにあります。その他の条件はコード見てください。
https://gist.github.com/koshian2/4a8b27a1368db17d3cd6228ff01c876f

結果

sgdr04.png

最小Validation errorは次のようになりました。

  • ケース1:7.00%
  • ケース2:5.27%
  • ケース3:5.37%

太い線がValidation error、点線がTraining errorです。これを見ると従来の方法(Case 3)がかなり強い。Case 1の場合は90%の壁を超えるのは速いですが、明らかに再起動が悪さをしていています(これでも指数減衰の要素が入っているから副作用はかなり抑えたほうで元の論文はもっと暴れている)。30エポック目あたりの精度が悪くなっているのは再起動の影響で、この悪化を取り戻すのに10エポックぐらいかかって、その間に学習率が下がりきって学習率のスイートスポットを逃しているように見えます。CIFARの90%以降では、学習率1/10の訓練時間を長くするのが精度向上のコツのようです。逆に1/100は学習率が小さすぎて最後の微調整に使うのにはいいのかもしれませんが、ろくに学習が進みません(じゃあ極端な話val_acc見ながらStep Decayさせるだけでよくない?→それなんてReduceLROnPlateau?)。プラトー大量にあるデータの場合は効果あるのかもしれませんが、少なくともCIFARの場合、一度下げた学習率をまた上げるのはあまり良くないようです。

逆に再起動なしの場合(Case 2)の場合はいい感じで、Shake-Shakeのように何千epochの場合は従来のStep decayでコントロールするのが難しくなってくるので、再起動なしのコサインカーブでコントロールするのはかなり良さそうに見えます。ほとんど誤差みたいなものですが、再起動なしのCase 2のほうが従来の方法のCase3よりもエラーは0.1%少ないです。Shake-Shakeが再起動入れなかったのは確かにそれが無難だなという印象です。対数スケールじゃなくて線形スケールのコサインカーブが良いというのはなかなか興味深いですね。

まとめ

  • SGDRはコサインカーブで学習率をコントロールするアルゴリズムである。ただの学習率スケジューラーなので、SGDだけではなくRMSPropやAdamにも応用できる。
  • SGDRの再起動(Warm restarts)はかなり上手くチューニングしないと使いこなすのが難しい。CIFAR-10の場合、再起動を使わずにコサインカーブだけで訓練させるのが無難。

参考文献


  1. ゴルフで始めはカップから遠いため強く打ってよくても、カップに近くなると力を手加減しないとカップを通り越し、なかなか入らないのとイメージ的には近いです。 

  2. プラトー(plateau)というと普通のほうは最適化のほうをいい、勾配ベクトルのいくつかの次元の0に近くなっているため学習が停滞するポイントをいいます。この意味でのプラトーから抜け出すには学習率を上げるほうが正解です。ここではエラー率をグラフに書いたときにあたかもこのプラトーのようになっているからこういう表現をするのだと思います。 


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。