post Image
PCA Color Augmentationを使ってみよう

PCA Color AugmentationとはAlexNetの論文で提唱されたData Augmentationの方法の1つです。論文自体は2012年と比較的古いものの、主成分分析(PCA)を使っているため、データの色分布を加味した色の加減ができ、Data Augmentationとしてよく用いられるカラーチャンネルシフトよりも自然な画像が出来上がります。

補足追記:この文章中に出てくる「分散共分散行列」ですが、標準偏差の調整を行っているため、実際は中身が相関行列の分散共分散行列となっています。出てくる画像はあまり大した違いは出ませんが、この点が気になる方は注意してください。追記に改良コードとそのテストを記しました。

追記:GitHubに公開しました https://github.com/koshian2/PCAColorAugmentation

コード

PCA Color Augmentationのコードは以下のとおりです。こちらの実装を参考にしました。

import numpy as np

def pca_color_augmentation(image_array_input):
    assert image_array_input.ndim == 3 and image_array_input.shape[2] == 3
    assert image_array_input.dtype == np.uint8

    img = image_array_input.reshape(-1, 3).astype(np.float32)
    img = (img - np.mean(img, axis=0)) / np.std(img, axis=0)

    cov = np.cov(img, rowvar=False)
    lambd_eigen_value, p_eigen_vector = np.linalg.eig(cov)

    rand = np.random.randn(3) * 0.1
    delta = np.dot(p_eigen_vector, rand*lambd_eigen_value)
    delta = (delta * 255.0).astype(np.int32)[np.newaxis, np.newaxis, :]

    img_out = np.clip(image_array_input + delta, 0, 255).astype(np.uint8)
    return img_out

中身については後ほど解説します。np.uint8の形式の画像(配列)を1枚入力して、PCA Color Augmentation済みのnp.uint8形式の配列を返すというシンプルなものです。

動かしてみる

こちらにあるフリー素材のオウムの画像をサンプルに使います。可愛いですね。

pca_1.jpg

この画像を9回PCA Color Augmentationさせた結果が以下の通りです。

pca_2.png

なかなか良い出来上がりです。写真で露出を調整するとだいたいこんな感じで変わるので、かなり自然なData Augmentationであると言えるでしょう。

理論

AlexNetの論文の「4.1 Data Augmentation」に記されています。ざっくり言うと、入力画像のRGBの強さに応じてData Augmentationをするという手法です。具体的な計算方法はカラーチャンネルあたりの主成分分析(PCA)を使うというのがポイントです。

PCAを使うというのは論文中の表記ですが、よく次元削減の手法として使われる主成分分析とは少し意味合いが違って、ここでのPCAはただ単にカラーチャンネルあたりの固有ベクトルと固有値を求めるという位置づけです。次元削減をどっかでやっているのかな?と思い込むと理解しづらくなります。

いま、画像の座標$(x, y)$にあるピクセル値$I_{xy}$を$[I_{xy}^R, I_{xy}^G, I_{xy}^B]^T$をとしましょう。R, G, Bは各カラーチャンネルのピクセル値とします。ピクセル値が0~255で表される場合、例えば赤のピクセルだったら$[255, 0, 0]^T$となります。

PCA Color Augmentationではこの$I_{xy}$に以下の式で計算されるベクトルを足します。

\begin{bmatrix}\mathbf{p}_1 & \mathbf{p}_2 & \mathbf{p}_3\end{bmatrix} \begin{bmatrix}\alpha_1\lambda_1 & \alpha_2\lambda_2 & \alpha_3\lambda_3 \end{bmatrix}^T

それぞれのカラーチャンネルについて、$\mathbf{p}$は固有ベクトル、$\lambda$は固有値を表します。左側の$\mathbf{p}$の括弧全体は3×3の行列になります。

$\alpha$は乱数で、平均0、標準偏差が0.1の正規分布に従う乱数で、カラーチャンネルごとに独立にサンプリングします。この$\alpha$が各色をどれだけ強くするかという指標になります(後で確認します)。$\alpha\lambda$がともにスカラーなので、転置後の右側の括弧は3×1の行列となります。したがって、左と右の内積を取ると、3×3行列と3×1行列の内積なので、3×1行列となります。元のピクセル$I_{xy}$も3×1行列なのでこれに足せばよいですね。

では具体的にどう主成分分析を計算するのかというのがポイントですが、以下のプロセスになります。

  • $(y, x, c)$というshapeの画像の配列を、$(xy, c)$という2次元の行列に変換する
  • カラー単位で分散共分散行列を計算する。この行列はc×c=3×3の行列になる
  • この分散共分散行列を固有値分解し、固有値と固有ベクトルを求める

実はこの分散共分散行列→固有値分解というプロセスそのものが主成分分析なのです。以下の記事に詳しく書かれているのでご覧ください。

PCAとSVDの関連について
https://qiita.com/horiem/items/71380db4b659fb9307b4

ステップバイステップで確認する

最初に示したコードを、乱数の値を指定できる形にして動かしてみます。少しコメントをつけてみました。

def pca_color_step_by_step(image_array_input, random):
    assert image_array_input.ndim == 3 and len(random) == 3
    assert image_array_input.shape[2] == 3
    assert image_array_input.dtype == np.uint8
    # ピクセル, カラーチャンネルの形式に変換
    img = image_array_input.reshape(-1, 3).astype(np.float32)
    # カラーチャンネル単位で標準化
    img = (img - np.mean(img, axis=0)) / np.std(img, axis=0)
    # 分散共分散行列, 列単位で計算したいのでrowvar=Falseとする
    cov = np.cov(img, rowvar=False)
    # 固有値と固有ベクトルの計算
    lambd_eigen_value, p_eigen_vector = np.linalg.eig(cov)
    # PCA Color Augmentationによる増分
    delta = np.dot(p_eigen_vector, random*lambd_eigen_value)
    delta = (delta * 255.0).astype(np.int32)[np.newaxis, np.newaxis, :]
    # 出力画像
    img_out = np.clip(image_array_input + delta, 0, 255).astype(np.uint8)
    return img_out

まずはimage_array_input.reshape()で3階のテンソルから2階のテンソル(行列)に変換しています。そのあとに標準化して小数のスケールにするので、ここでuint8からfloat32にキャストしています。

次に、主成分分析で重要な標準化をしています。カラーチャンネルで集計するので、チャンネル間の平均、標準偏差で標準化します。

そして分散共分散行列を計算します。これはnp.covで簡単に計算できます。rowvar=Falseを指定しないと行(ピクセル間)で集計されてしまうので、カラーチャンネル間で集計されるように指定します。この結果は3×3行列になります。

固有値分解はnp.linalg.eigでできます。分散共分散行列を固有値分解し、3次元ベクトルの固有値と、3×3行列の固有ベクトルに分解されます。

あとはカラーチャンネルの増分を計算します。ここでint8とキャストしてしまうと、増分が大きいときにアンダーorオーバーフローすることがあるので増分は32ビットintとして定義しました。3階テンソルである入力画像と計算するために次元を拡張しています。

最後に0~255の範囲になるように値をクリッピングして、uint8にキャストして終わりです。これでPCA Color Augmentationを実装できました。あとは増分$\alpha$(random)を変えてどのように出力が変わるかを確認しましょう。

R, G, Bの乱数を-0.2, -0.1, 0, 0.1, 0.2の5パターン、計125パターンを出力して動画にしてみました。

pca_3.gif

実際は標準偏差0.1の正規乱数を使うので、-0.1~0.1が68%, -0.2~0.2が95%の割合で出現します。これが実際の乱数を使った例です(フリー素材の猫画像を使いました)。

pca_4.gif

明るいところはより明るくなり、暗い所は暗いままで確かに色調を維持しているのが確認できます。これはなかなか強そうなData Augmentationですね。

計算量の確認

便利そうなPCA Color Augmentationですが1つだけ気になるポイントがあります。それは主成分分析の内部で逆行列を計算をしているため、計算量の問題があるということです。逆行列の計算量は多いと$O(N^3)$というアルゴリズムとしては高価なオーダーなので、ボトルネックとなる可能性はあります1

次のコードでニューラルネットワークでの画像の読み込みを再現してみましょう。

import time
import numpy as np

def make_image(width):
    return (np.random.rand(width, width, 3) * 255.0).astype(np.uint8)

if __name__ == "__main__":
    start_time = time.time()
    for i in range(1000):
        img = make_image(256)
        img_aug = pca_color_augmentation(img)
    elapsed = time.time() - start_time
    print(elapsed)

乱数で画像を1000枚作成し、PCA Color Augmentationをコメントアウトするかどうかで処理時間を比較します。解像度256×256と512×512で各ケース3回測定しました。Google ColabのCPUインスタンス2で計測しました。単位は秒です。

# 256x256 PCA-Augなし
2.630277156829834 2.5239434242248535 2.574631452560425
# 256x256 PCA-Augあり
12.860137462615967 12.779702186584473 12.894489049911499

# 512x512 PCA-Augなし
11.647570610046387 12.336450576782227 11.551325798034668
# 512x512 PCA-Augあり
54.012206077575684 53.79096794128418 54.03123617172241

ケースごとの平均値を計算してまとめました。

ケース 平均秒 オーバーヘッド 枚 / 秒目安
256×256 なし 2.58
256×256 あり 12.84 10.27 97.39
512×512 なし 11.85
512×512 あり 53.94 42.10 23.75

解像度が大きくなっても2乗3乗のペースで計算量が増えないからまだよいものの、明らかに処理時間は増えています。縦、横の幅をそれぞれ$N$とするなら、おそらくこれは$O(N^2)$のオーダーでしょう。PCA Color Augmentationを入れると処理時間が5~6倍になるのは流石に多いので、もしかしたら訓練速度に影響が出るかもしれません。そこはケースバイケースなので、実際の訓練で計測してみないとわかりません。

追記:TPUで計測したところ、バッチサイズを上げるとPCA Color Augmentationがボトルネックとなってしまいました。CIFAR-10で訓練させてなしだと1エポック7秒ぐらいだったのが、ありだと16秒かかります。軽めのモデルだったのがいけなかったかもしれません。高速化させたいのならGPU/TPUでブーストできるKerasのカスタムレイヤーで定義すべきでしょうね。

追記:ちょっと難しいですがテンソル演算を使えばもう少し速くできるはずです。

PCA Color Augmentationのアルゴリズムを振り返ってよく考えると、逆行列は常に3×3行列で計算してるから、この計算量って逆行列の計算よりかは多分分散共分散行列の計算部分なんですよね。GPUでこれをできればベストですが、ちょっとこの遅さは気持ち悪いなというのが正直な感想です。もちろんCPUとGPUがうまく分業してボトルネックにならななければ何の問題もありません。

ただし、このPCA Color Augmentationはかなり強力で手法で、AlexNetの論文にも「ImageNetのTop1エラー率を1%以上減らした」とあります。また、実際の出力画像を見るとわかるように、Augmentation後の画像は元画像の光線や照明の加減を変化させたものに相当するので、よく用いられる単純なカラーチャンネルの増減よりもかなり自然な印象を与えます。また、これは自分の意見ですが、Data Augmentationはアルゴリズム次第では分類問題では使えても物体検出では使えない(使いづらい)ケースもあるなか、PCA Color Augmentationは画像だけの作用で完結するので物体検出でも簡単に使えます。なので、計算量をためらうことなく使ってみる価値は十分にあると思います。

まとめ

PCA Color Augmentationはカラーチャンネル単位で主成分分析をかけ、元の画像の色の分布を加味して、Data Augmentationをするアルゴリズムです。出てきた画像はかなり自然なので、ぜひ使ってみてください。単に行列の固有値分解(主成分分析)の応用例として見ても面白いです。

追記:もうちょっと改良してみる(標準偏差を同一化しない)

実は分散共分散行列と相関行列の違いを知らなくて3、うっかり標準偏差で割って元画像を標準化していました。標準偏差で割ってしまうとR,G,Bの標準偏差が同一化してしまうので、元の色の分布を損なう可能性があります。特に各チャンネルに色が偏った画像で発生しやすくなります。ちなみに平均で引くのは何も問題ありません。

ただし、スケーリングの都合上、分散を大きさを調整するのは必須で、これをどう両立するか悩みました。この方法がいいのではないでしょうか。スケーリング定数$k$を定義し、

$$ k(X-\mu)$$

という変換をします。$\mu$はチャンネル間の平均です。このときの分散共分散行列には次のような関係があります。

$$Cov(kX) = k^2Cov(X)$$

分散なのでkは2乗になります。ここで、各チャンネルの分散の合計を3(=チャンネル数)にするという制約をおきます。ちなみに標準偏差で割った場合は対角要素が全て1になるので分散の合計は常に3になります。各チャンネルの分散を$\sigma^2_R, \sigma^2_G, \sigma^2_B$とすると

\begin{align}
k^2(\sigma^2_R+\sigma^2_G+\sigma^2_B) &= 3 \\
k &= \sqrt{\frac{3}{\sigma^2_R+\sigma^2_G+\sigma^2_B}}
\end{align}

これでスケーリング定数kが求められます。コードは以下のように変わります。

改良版
def pca_color_augmentation_modify(image_array_input):
    assert image_array_input.ndim == 3 and image_array_input.shape[2] == 3
    assert image_array_input.dtype == np.uint8

    img = image_array_input.reshape(-1, 3).astype(np.float32)
    # 分散を計算
    ch_var = np.var(img, axis=0)
    # 分散の合計が3になるようにスケーリング
    scaling_factor = np.sqrt(3.0 / sum(ch_var))
    # 平均で引いてスケーリング
    img = (img - np.mean(img, axis=0)) * scaling_factor

    cov = np.cov(img, rowvar=False)
    lambd_eigen_value, p_eigen_vector = np.linalg.eig(cov)

    rand = np.random.randn(3) * 0.1
    delta = np.dot(p_eigen_vector, rand*lambd_eigen_value)
    delta = (delta * 255.0).astype(np.int32)[np.newaxis, np.newaxis, :]

    img_out = np.clip(image_array_input + delta, 0, 255).astype(np.uint8)
    return img_out

葉っぱの画像で比較

緑(G)の成分が多めの葉っぱの画像です。
leaf.jpg

上が変更前(標準偏差で割ったケース)、下が変更後(標準偏差で割らないケース)です。同一の乱数を使っています。
pca_leaf.gif

どっちがどうとは一概には言えないですが、変更後のほうが若干青(B)のチャンネルが上がりにくくなっているのがわかります。確かにこの画像青要素ほとんどないので、RGBが均一に上がる(変更前)はちょっとおかしいかなと思います。

夕日の画像で比較

赤(R)の成分が多めの夕日の画像です。
sunset.jpg

上が変更前、下が変更後です。
pca_sunset.gif

青成分がほとんど動かなくなった反面、赤のチャンネルがよく動くようになりました。

猫の画像で比較

元の画像
cat.jpg

比較
pca_cat.gif

普通の画像です。色彩に詳しい人でないと違いがわからないかも。

オウムの画像で比較

元の画像
pca_1.jpg

比較
pca_bird.gif

ぱっと見違いがよくわからない。色彩が偏った画像でないと目に見えて差が出ないのではと。


  1. ニューラルネットワークの計算量と比べると鼻くそみたいなものですが、Data Augmentationは通常インプット前のCPUで計算するため、ニューラルネットワークのようなデバイスによる高速化はできません。 

  2. Intel(R) Xeon(R) CPU @ 2.30GHz ×2(/proc/cpuinfoより) 

  3. 元データを平均で引き、標準偏差で割ると分散共分散行列は、対角要素が全て1である相関行列と等しくなります。詳しくはこちらに書きました。https://blog.shikoan.com/cov-corr-gram-matrix/ 


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。