post Image
Affineレイヤの逆伝播を地道に成分計算する

0. 背景

  • ゼロから作る Deep Learning – Pythonで学ぶディープラーニングの理論と実装」を手にニューラルネットワークの勉強をしていたのですが、5章・誤差逆伝播のAffilneレイヤの逆伝播の式変形の理解に時間がかかったので、小さな次元で計算してみました。その過程を自分の備忘録的に記述します。
  • 目標としては、低次元ながらも次の式が成分計算で求められることとします。($T$は転置行列を意味します)
\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= \frac{\partial L}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T  \\
\frac{\partial L}{\partial \boldsymbol{W}} &=
\boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}
  • 初学者ですので、計算ミス、誤字などあると思います。優しくご教授いただけますと幸いです。

1. 低次元で活性化関数を考慮せず地道に成分計算をする

1
入力として$\boldsymbol{x} = (x_1\; x_2)$の2次元だとします。
この入力に対して、第1層目への出力を3つにしたい場合、$(2, 3)$の行列を右からかけます。

\boldsymbol{W} = \begin{pmatrix}
w_{11} & w_{21} & w_{31} \\
w_{12} & w_{22} & w_{32} 
\end{pmatrix}

出力$\boldsymbol{Y}$は

\begin{align}
\boldsymbol{Y} &= \boldsymbol{X} \cdot \boldsymbol{W} \\
&=
\begin{pmatrix}
x_1 & x_2
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{21} & w_{31} \\
w_{12} & w_{22} & w_{32} 
\end{pmatrix} \\
&= 
\begin{pmatrix}
w_{11}x_1+w_{12}x_2 & w_{21}x_1+w_{22}x_2 & w_{31}x_1+w_{32}x_2
\end{pmatrix} \\
&=
\begin{pmatrix}
y_1 & y_2 & y_3
\end{pmatrix} \tag{1.1}
\end{align}

となります。
損失関数$L$の入力$\boldsymbol{X}$による偏微分は$x_1, x_2$が$y_1, y_2, y_3$に出てくることに注意すると、

\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial x_1} & \frac{\partial L}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} & \frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_2}
\end{pmatrix} 
\end{align}

ここで

\begin{align}
\frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} = 
\begin{pmatrix}
\frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \frac{\partial L}{\partial y_3}
\end{pmatrix}
\cdot
\begin{pmatrix}
\frac{\partial y_1}{\partial x_1} \\
\frac{\partial y_2}{\partial x_1} \\
\frac{\partial y_3}{\partial x_1} 
\end{pmatrix}
\end{align}

なので

\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial x_1} +
\frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial x_1} +
\frac{\partial L}{\partial y_3} \frac{\partial y_3}{\partial x_1} &
\frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial x_2} +
\frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial x_2} +
\frac{\partial L}{\partial y_3} \frac{\partial y_3}{\partial x_2} 
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} w_{11} +
\frac{\partial L}{\partial y_2} w_{21} +
\frac{\partial L}{\partial y_3} w_{31} &
\frac{\partial L}{\partial y_1} w_{12} +
\frac{\partial L}{\partial y_2} w_{22} +
\frac{\partial L}{\partial y_3} w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \frac{\partial L}{\partial y_3}
\end{pmatrix} 
\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix} \\
&= \frac{\partial L}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T

\end{align}

一方、損失関数$L$の重み$\boldsymbol{W}$による偏微分は

\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{32}} 
\end{pmatrix} \\
\end{align}

です。
ここで、$(1.1)$式において、$w_{11}$は$y_1$だけに、$w_{12}$は$y_1$だけに、・・・$w_{31}$は$y_3$だけに、$w_{32}$は$y_3$だけに出てくることに注意すると、

\begin{align}
\frac{\partial L}{\partial w_{11}} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} \\
\frac{\partial L}{\partial w_{12}} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} \\
\frac{\partial L}{\partial w_{22}} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} \\
\frac{\partial L}{\partial w_{31}} &= \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{32}} &= \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{align}

となります。
したがって、$\frac{\partial L}{\partial \boldsymbol{W}}$は

\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{32}} 
\end{pmatrix} \\
&= 
\begin{pmatrix}
\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} &
\frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} &
\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} & \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} &
\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1}x_1 &
\frac{\partial L}{\partial y_2}x_1 &
\frac{\partial L}{\partial y_3}x_1 \\
\frac{\partial L}{\partial y_1}x_2 &
\frac{\partial L}{\partial y_2}x_2 &
\frac{\partial L}{\partial y_3}x_2
\end{pmatrix} \\
&= \begin{pmatrix}
x_1 \\
x_2
\end{pmatrix}
\begin{pmatrix}
\frac{\partial L}{\partial y_1} &
\frac{\partial L}{\partial y_2} &
\frac{\partial L}{\partial y_3} 
\end{pmatrix} \\
&= \boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}

これで、

\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &=
\boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}

の導出が(低次元で、活性化関数も考慮していませんが)できました。

2. 活性化関数と2層目を考慮してみる

1
1では、各層の活性化関数を無視したり、2層目以降を考慮していませんでした。ここでは、2層目を導入し、活性化関数も導入してみようと思います。
1層目の活性化関数を$h$、2層目(出力層)の活性化関数を$\sigma$とおきます。
最初に結果をみると、このようになります。

L = \sigma(h(\boldsymbol{X} \cdot \boldsymbol{W}) \cdot \boldsymbol{W}^{(2)})

一つずつ見ていきます。
入力$\boldsymbol{X}$

\boldsymbol{X} = (x_1\; x_2)

1層目の入力

\boldsymbol{Y} = \boldsymbol{X}\cdot \boldsymbol{W}

1層目の出力(活性化関数を作用させる)

\begin{align}
h(\boldsymbol{Y}) &= h(\boldsymbol{X}\cdot \boldsymbol{W}) \\
&= \begin{pmatrix}
h(y_1) & h(y_2) & h(y_3)
\end{pmatrix}
\end{align}

2層目(出力層)の入力

\begin{align}
z &= h(\boldsymbol{Y})\cdot \boldsymbol{W}^{(2)} \\
&=
\begin{pmatrix}
h(y_1) & h(y_2) & h(y_3)
\end{pmatrix}
\begin{pmatrix}
w_1^{(2)} \\
w_2^{(2)} \\
w_3^{(2)}
\end{pmatrix} \\
&= w_1^{(2)}h(y_1) + w_2^{(2)}h(y_2) + w_3^{(2)}h(y_3)

\end{align}

2層目(出力層)の出力

\begin{align}
L &= \sigma (z) \\
&= \sigma (w_1^{(2)}h(y_1) + w_2^{(2)}h(y_2) + w_3^{(2)}h(y_3))
\end{align}

$\boldsymbol{X}$と$\boldsymbol{W}$の偏微分は1と同じですが、再掲します。

\begin{align}
\frac{\partial \sigma}{\partial \boldsymbol{X}} &= 
\begin{pmatrix}
\frac{\partial \sigma}{\partial x_1} & \frac{\partial \sigma}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} & \frac{\partial \sigma}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} \frac{\partial y_1}{\partial x_1} +
\frac{\partial \sigma}{\partial y_2} \frac{\partial y_2}{\partial x_1} +
\frac{\partial \sigma}{\partial y_3} \frac{\partial y_3}{\partial x_1} &
\frac{\partial \sigma}{\partial y_1} \frac{\partial y_1}{\partial x_2} +
\frac{\partial \sigma}{\partial y_2} \frac{\partial y_2}{\partial x_2} +
\frac{\partial \sigma}{\partial y_3} \frac{\partial y_3}{\partial x_2} 
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} w_{11} +
\frac{\partial \sigma}{\partial y_2} w_{21} +
\frac{\partial \sigma}{\partial y_3} w_{31} &
\frac{\partial \sigma}{\partial y_1} w_{12} +
\frac{\partial \sigma}{\partial y_2} w_{22} +
\frac{\partial \sigma}{\partial y_3} w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} & \frac{\partial \sigma}{\partial y_2} & \frac{\partial \sigma}{\partial y_3}
\end{pmatrix} 
\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix} \\
&= \frac{\partial \sigma}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T \\
\frac{\partial L}{\partial \boldsymbol{W}} &= 
\begin{pmatrix}
\frac{\partial \sigma}{\partial w_{11}} & \frac{\partial \sigma}{\partial w_{21}} & \frac{\partial \sigma}{\partial w_{31}} \\
\frac{\partial \sigma}{\partial w_{12}} & \frac{\partial \sigma}{\partial w_{22}} & \frac{\partial \sigma}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} &
\frac{\partial \sigma}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} &
\frac{\partial \sigma}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial \sigma}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} &
\frac{\partial \sigma}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} &
\frac{\partial \sigma}{\partial y_3}\frac{\partial y_3}{\partial w_{32}} 
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1}x_1 &
\frac{\partial \sigma}{\partial y_2}x_1 &
\frac{\partial \sigma}{\partial y_3}x_1 \\
\frac{\partial \sigma}{\partial y_1}x_2 &
\frac{\partial \sigma}{\partial y_2}x_2 &
\frac{\partial \sigma}{\partial y_3}x_2 
\end{pmatrix} \\
&=
\begin{pmatrix}
x_1 \\
x_2 
\end{pmatrix}
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} &
\frac{\partial \sigma}{\partial y_2} &
\frac{\partial \sigma}{\partial y_3}  
\end{pmatrix} \\
&= \boldsymbol{X}^T \cdot \frac{\partial \sigma}{\partial \boldsymbol{Y}}
\end{align}

蛇足ですが、成分表示をすると、

\frac{\partial L}{\partial w_{ji}} = 
\frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial w_{ji}}

となります。

3. より一般的な式にする

1, 2では、層の数も限られていましたし、各層の次元も少ないケースでしたが、より一般的に第$i$層、第$j$層、第$k$層に注目してみたいと思います。
ここでは、行列は成分表示で表します。
1

\begin{align}
a_j^{(j)} &= \sum_i w_{ji}^{(j)}z_i^{(i)} \\
z_j^{(j)} &= h(a_j^{(j)})
\end{align}

重み$w_{ji}^{(j)}$による微分は、$w_{ji}^{(j)}$が$a_j^{(j)}$のみに出現することから、

\begin{align}
\frac{\partial L}{\partial w_{ji}^{(j)}} &= \frac{\partial L}{\partial a_j^{(j)}}\frac{\partial a_j^{(j)}}{\partial w_{ji}^{(j)}} \\
&= \frac{\partial L}{\partial a_j^{(j)}}\frac{\partial}{\partial w_{ji}^{(j)}} ( \sum_i w_{ji}^{(j)}z_i^{(i)} ) \\
&= \frac{\partial L}{\partial a_j^{(j)}}z_i

\end{align}

また、$j$番目の入力$a_j^{(j)}$による微分は、$a_j^{(j)}$が$a_k^{(k)}$の変化を通じてしか誤差関数を変化させないことから、

\begin{align}
\frac{\partial L}{\partial a_j^{(j)}} &=
\sum_k \frac{\partial L}{\partial a_k^{(k)}}\frac{\partial a_k^{(k)}}{\partial a_j^{(j)}} \\
&= \sum_k \frac{\partial L}{\partial a_k^{(k)}}\frac{\partial}{\partial a_j^{(j)}} ( \sum_j w_{kj}^{(k)}z_j^{(j)} ) \\
&= \sum_k \frac{\partial L}{\partial a_k^{(k)}} w_{kj} \frac{\partial h(a_j^{(j)})}{\partial a_j^{(j)}} \\
&= \frac{\partial h(a_j^{(j)})}{\partial a_j^{(j)}}\sum_k w_{kj} \frac{\partial L}{\partial a_k^{(k)}} 
\end{align}

となる。

参考

  1. ゼロから作るDeep Learning
    ――Pythonで学ぶディープラーニングの理論と実装
  2. パターン認識と機械学習
  3. 実装ノート

『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。