post Image
ロジスティック回帰の数式を分かりやすく解説、、、できたらいいな

Courseraで機械学習を勉強したついでに、pythonでロジスティック回帰を実装しているのですが、数式の理解に苦戦したため一つずつ噛み砕いていきたいと思います。

注)数学をしっかり勉強した訳ではないので間違った箇所があったらスミマセン


やりたいこと

下表のようなデータがあった場合に合格・不合格を予測できるモデルを作成し、新しく5教科の点数データを当てはめ、合格・不合格を予測したい。

結果

国語

数学

英語

理科

社会

合格

40

80

70

60

90

不合格

40

20

40

60

30

合格

70

40

90

60

70

合格

50

90

60

60

80

不合格

20

60

80

30

50

合格

50

70

60

80

90


数式の確認

まず、ロジスティック回帰の数式について確認していきます。

はっきり整理しておきたいのは以下の3つです。

・仮説関数

・目的関数(対数尤度関数)

・最急降下法


仮説関数

仮説関数は先ほどの表に合わせて書くとこのような感じになるかと思います。

w^{T}x=w_0+w_1x_1+w_2x_2+w_3x_3+w_4x_4+w_5x_5

上式は非常に直感的で分かりやすいですが、これだけだと通常の重回帰分析になってしまい$w^{T}x$が1以上の数値だったりマイナスになったりで、合格・不合格のようなゼロイチに分けられるような問題を適切に予測することができなくなってしまいます。

そこで、$w^{T}x$をシグモイド関数に噛ませることで計算結果を0から1の間に調整し確率とみなしていきます。

σ(z)=\frac{1}{1+e^{(-z)}}

figure_1.png

シグモイド関数は上図のようにどのような$z$に対しても0から1に収まる性質を持っています。$z$が大きくなると1に近づき、小さくなると0に近づいていきます。

これを利用して、$z$に先ほどの$w^{T}x$を渡してあげます。ついでに、こいつを仮説関数として定義しておきます。

h(w^{T}x)=σ(w^{T}x)=\frac{1}{1+e^{(-w^{T}x)}}

こうすることで、$w^{T}x$を0から1の間に収めることができ、$h(w^{T}x)$>0.5の時、合格、$h(w^{T}x)<=0.5$の時、不合格というような予測ができるようになります。


目的関数(対数尤度関数)

目的関数とは、仮説関数から予測した値と実際に与えられているデータの差を元に最小化もしくは最大化させる関数になります。

まず、仮説関数を最小化させることに基づいて変形させていきます。

-L(w)=-h(w^{T}x)^t(1-h(w^{T}x))^{1-t} \\

ただし、t=0, 1 \, (0が不合格, 1が合格)

この時、$t$は実際に与えられているデータとなります。この式のイメージとしては下表のようになります。合格の時には右辺の第1項、不合格の時には右辺の第2項が残る形です。

$L(w)$

$t$

$h(w^{T}x)$

-0.9

1

0.9

実際の値と予測値が近い

-0.2

1

0.2

実際の値と予測値が遠い

-0.1

0

0.9

実際の値と予測値が遠い

-0.8

0

0.2

実際の値と予測値が近い

実際の値と予測値の差が少なくなることで$L(w)$が小さくなっていることが分かります。しかし、これだけでは一件分のデータにしか対応していないので、全データに対応するように式を書き換えます。

-L(w)=-\frac{1}{M}\prod_{m=1}^{M} h(w^{T}x^m)^{t^m}(1-h(w^{T}x^m))^{1-{t^m}} 

ここで、$M$はデータセットの総数となり、$m$で各行へのアクセスを示します。例えば、最初の表をデータセットとすると一行目のデータがそれぞれ$t^1$, $x^1$のデータになります。

要するに、データセットの数だけ実際の値と予測値の計算を行い、それら全てを掛けたものを最小化するような$w$を探すというイメージです。

最後に、計算しやすいように対数をとります。その際に下式の対数の性質を利用します。

\log M^p=p \log M \\

\log MN = \log M + \log N \\
E(w)=-\log L(w)=-\frac{1}{M}(\sum_{m=1}^{M} t^m \log h(w^{T}x^m)+ (1-t^m) \log (1-h(w^{T}x^m)) )

対数をとったので指数は$\log$の前へ、かけ算は足し算へ書き換えられました。$\prod$が$\sum$に変形しているのもかけ算を足し算へ変形できる対数の性質によるものです。そして、上式の$E(w)$が目的関数になります。次に最急降下法を使用して目的関数$E(w)$を最小化するような$w$を求めていきます。


最急降下法

まず、目的関数を最小化させたいので$E(w)$を微分していきます。最小化・最大化といった場合にはとりあえず微分です!

\frac{\partial E(w)}{\partial w_j}=-\frac{\partial }{\partial w_j}\frac{1}{M}(\sum_{m=1}^{M} t^m \log h(w^{T}x^m)+ (1-t^m) \log (1-h(w^{T}x^m)) )

ここで、$w_j$は$w$に含まれている$j$番目の要素へのアクセスを示します。右辺を$w_j$で偏微分していきますが、ここで$w_j$を使用している式は$h(w^{T}x^m)$であり、$t^m$は微分には関係ありません。ですので、微分の対象となるのは$\log h(w^{T}x^m)$と$\log (1-h(w^{T}x^m))$になります。ここで、下式の対数を微分する際の性質を利用します。

(\log f(x))'=\frac{f'(x)}{f(x)} \\

分母に$\log$の中身、分子に$\log$の中身を微分したものを持ってきます。これを利用して目的関数の偏微分を進めていきます。

\frac{\partial E(w)}{\partial w_j}= \frac{1}{M}\sum_{m=1}^{M} -\frac {t^m}{h(w^{T}x^m)} 

\frac{\partial h(w^{T}x^m)}{\partial w_j} + \frac{(1-t^m)}{1-h(w^{T}x^m)} \frac{\partial h(w^{T}x^m)}{\partial w_j}

ここで、この後いい感じに式を変形させるために最初のシグモイド関数をいじります。

σ(w^{T}x^m)=\frac{1}{1+e^{(-w^{T}x^m)}} \\

\frac{1}{σ(w^{T}x^m)}=1+e^{(-w^{T}x^m)} \\
e^{(-w^{T}x^m)}=\frac{1}{σ(w^{T}x^m)}-1

逆数をとって$e^{(-w^{T}x^m)}$を$σ(w^{T}x^m)$の式で表しました。次に上式を使用して$\frac{\partial h(w^{T}x^m)}{\partial w}$を求めます。



\frac{\partial h(w^{T}x^m)}{\partial w_j} = \frac{\partial σ(w^{T}x^m)}{\partial w_j} = - \frac{1}{(1+e^{(-w^{T}x^m)})^2} e^{(-w^{T}x^m)} (-x_j^m) \\
= -( h(w^{T}x^m) )^2 ( \frac{1}{h(w^{T}x^m)}-1 ) (-x_j^m) \\
= h(w^{T}x^m)( 1-h(w^{T}x^m) )x_j^m

一行目の微分がちょっとめんどくさいですね。ここでは商関数の微分と$e$の合成関数の微分を利用します。商関数の微分はこんな感じ。



y = \frac{ f(x) }{ g(x) } の時、yをxで微分する \\
y' = \frac{ f'(x)g(x)-f(x)g'(x) }{ (g(x))^2 }

こいつの$g(x)$が今回の$1+e^{(-w^{T}x^m)}$の部分です。$f(x)$は分子の$1$の部分なので$f'(x)g(x)$は$0$になります。残りの$f(x)g'(x)$には$g'(x)$が含まれますが、ここは$e$の合成関数の微分を利用します。



y = e^{x^2} の時、 \\
t=x^2 と置くと y=e^t となり、これをxで微分する\\
y' = \frac{ dy }{ dx } = \frac{ dy }{ dt } \frac{ dt }{ dx } \\
= (e^t)'(x^2)'\\
= 2xe^t \\
t=x^2 なので\\
y' = 2xe^{x^2}

ちょっとややこしいかもしれませんが、簡単に言うと$e^{x^2}$はそのままで$e$の指数部分($x^2$)を微分したものを$e$へくっつける感じです。

これを$1+e^{(-w^{T}x^m)}$の微分に利用します。$1$の部分は定数の微分なので$0$になり$e^{(-w^{T}x^m)}$が残ります。そして、$e^{(-w^{T}x^m)}$の微分ですが、ここで$w^{T}x^m$は$w_0+w_1x_{1}^{m}+w_2x_{2}^{m}+w_3x_{3}^{m}+w_4x_{4}^{m}+w_5x_{5}^{m}$だったことを思い出します。今回は一般化のため$w_j$で微分しているため、$w^{T}x^m$が微分されて残る項も$w_j$の係数部分である$x_j^m$となります。

ごちゃごちゃしましたが以上で$\frac{\partial h(w^{T}x^m)}{\partial w_j}$が求まりました。こいつを$\frac{\partial E(w)}{\partial w_j}$へ戻してあげます。

\frac{\partial E(w)}{\partial w_j}= \frac{1}{M}\sum_{m=1}^{M} -\frac {t^m}{h(w^{T}x^m)} 

h(w^{T}x^m)( 1-h(w^{T}x^m) )x_j^m + \frac{(1-t^m)}{1-h(w^{T}x^m)} h(w^{T}x^m)( 1-h(w^{T}x^m) )x_j^m \\
= \frac{1}{M}\sum_{m=1}^{M} - t^m ( 1-h(w^{T}x^m) )x_j^m + (1-t^m) h(w^{T}x^m) x_j^m \\
= \frac{1}{M}\sum_{m=1}^{M} ( h(w^{T}x^m) - t^m ) x_j^m

だいぶスッキリしました。$h(w^{T}x^m) – t^m$が予測値と実際の値の差で、それに求めたい$w_j$に対応している$x_j^m$を掛けています。$\sum_{m=1}^{M}$で各データセットへ対応しています。最後に$w_j$の更新式を確認します。

w_j : = w_j - \alpha  \frac{\partial E(w)}{\partial w_j} \\

w_j : = w_j - \alpha \frac{1}{M}\sum_{m=1}^{M} ( h(w^{T}x^m) - t^m ) x_j^m

$\alpha$はこちらで与えてあげる学習率です。$\frac{\partial E(w)}{\partial w_j}$は$w_j$が微小増加した場合に$E(w)$がどれくらい増加(減少)するのかを表すので、それに学習率を掛けて$w_j$を更新するということになりますね。


まとめ

実装するにあたって、途中のごちゃごちゃした過程はいらないので、主に必要な数式は仮説関数、目的関数、最急降下法の3つになります。最急降下法で求めた$w$を仮説関数に適用し、新しいデータセット$x$を渡すことで合否を予測できますね。目的関数は最急降下法をループしている最中に毎回計算し、最後に推移をプロットしてあげることで最小化されたかどうか確認できます。次回はpythonでこれらの数式を実装していきたいと思います!


『 Python 』Article List