post Image
【 self attention 】簡単に予測理由を可視化できる文書分類モデルを実装する

はじめに

Deep Learning モデルの予測理由を可視化する手法がたくさん研究されています。
今回はその中でも最もシンプルな(しかし何故かあまり知られていない)self attentionを用いた文書分類モデルを実装したので実験結果を紹介します。
この手法では、RNNモデルが文書中のどの単語に注目して分類を行ったか可視化することが可能になります。

attentionの復習

attentionとは(正確な定義ではないですが)予測モデルに入力データのどの部分に注目するか知らせる機構のことです。
attention技術は機械翻訳への応用が特に有名です。
例えば、日英翻訳モデルを考えます。翻訳モデルは”これはペンです”という文字列を入力として”This is a pen”という英文を出力しますが、「pen」という文字を出力する際、モデルは入力文の「ペン」という文字に注目するはずです。このように入力データのある部分に「注目する=attention」という機構を予測モデルに組み込むことで、種々のタスクにおいいて精度が向上することが報告されてきました。
また、このattentionを可視化することで「入力データのどの部分に注目して予測を行ったか」という形で予測理由の提示を行うことができます。
attentionについての説明と実装は

がとても参考になります。

self attention を利用した分類

今回は、attentionの技術を利用して、予測理由が可視化できる文書分類モデルを実装しました。
self-attentive sentence embedding という論文の手法を単純化したものになります。
この手法は次のような手順で予測を行います。

  1. bidirectional LSTMで文書を変換
  2. 各単語に対応する隠れ層(下図$h_i$)を入力とし、予測の際その単語に注目すべき確率(self attention 下図$A_i$)をNeural Networkで予測
  3. self attention の重み付で各単語に対応する隠れ層を足し合わせたものを入力とし、Neural Networkで文書のラベルを予測

この$A_i$を可視化してやれば、モデルが予測の際どの単語に注目したかを知ることができます。
(オリジナル論文では複数個のself attentionを利用する方法が提案されているのですが、今回は簡易のためattentionは1種類としています。)

image.png

実装

上記手法をpytorchで実装してみました。
bidirectional LSTMの部分は次のような感じになります。

class EncoderRNN(nn.Module):
    def __init__(self, emb_dim, h_dim, v_size, gpu=True, v_vec=None, batch_first=True):
        super(EncoderRNN, self).__init__()
        self.gpu = gpu
        self.h_dim = h_dim
        self.embed = nn.Embedding(v_size, emb_dim)
        if v_vec is not None:
            self.embed.weight.data.copy_(v_vec)
        self.lstm = nn.LSTM(emb_dim, h_dim, batch_first=batch_first,
                            bidirectional=True)

    def init_hidden(self, b_size):
        h0 = Variable(torch.zeros(1*2, b_size, self.h_dim))
        c0 = Variable(torch.zeros(1*2, b_size, self.h_dim))
        if self.gpu:
            h0 = h0.cuda()
            c0 = c0.cuda()
        return (h0, c0)

    def forward(self, sentence, lengths=None):
        self.hidden = self.init_hidden(sentence.size(0))
        emb = self.embed(sentence)
        packed_emb = emb

        if lengths is not None:
            lengths = lengths.view(-1).tolist()
            packed_emb = nn.utils.rnn.pack_padded_sequence(emb, lengths)

        out, hidden = self.lstm(packed_emb, self.hidden)

        if lengths is not None:
            out = nn.utils.rnn.pad_packed_sequence(output)[0]

        out = out[:, :, :self.h_dim] + out[:, :, self.h_dim:]

        return out

attentionクラスです。
LSTMの隠れ層を入力として、各単語へのattentionを出力します。

class Attn(nn.Module):
    def __init__(self, h_dim):
        super(Attn, self).__init__()
        self.h_dim = h_dim
        self.main = nn.Sequential(
            nn.Linear(h_dim, 24),
            nn.ReLU(True),
            nn.Linear(24,1)
        )

    def forward(self, encoder_outputs):
        b_size = encoder_outputs.size(0)
        attn_ene = self.main(encoder_outputs.view(-1, self.h_dim)) # (b, s, h) -> (b * s, 1)
        return F.softmax(attn_ene.view(b_size, -1), dim=1).unsqueeze(2) # (b*s, 1) -> (b, s, 1)

最後にattentionを利用して実際に文書分類を行う部分です。

class AttnClassifier(nn.Module):
    def __init__(self, h_dim, c_num):
        super(AttnClassifier, self).__init__()
        self.attn = Attn(h_dim)
        self.main = nn.Linear(h_dim, c_num)


    def forward(self, encoder_outputs):
        attns = self.attn(encoder_outputs) #(b, s, 1)
        feats = (encoder_outputs * attns).sum(dim=1) # (b, s, h) -> (b, h)
        return F.log_softmax(self.main(feats)), attns

これらのNeural Networkを同時に学習させます。

実験

今回はIMDB映画レビューのネガポジ判別を行ってみます。
このデータは映画のレビューに対して、positiveかnegativeかをタグ付けしたデータセットで、torchtextなどから簡単に利用することができます。

  • 単語の分散表現の次元は100
  • LSTMの隠れ層の次元は32

と比較的小さなNetworkを利用しましたが、90%程の精度を達成できました。

予測理由(attention) の可視化

検証用データを用いてattentionを可視化してみました。
赤いハイライトの濃さがattentionの強さを表しています。
(Qiitaってspanタグ使えないんだ。。)

正解:POSITIVE 予測:POSITIVE なデータ

pospos.png
good, brilliant などの、単語の意味自体がpositiveなものや、highly recommendといった映画のレビューの文脈ではpositiveといえるものがハイライトされているのが観察できます。

正:NEGATIVE 予測:NEGATIVE なデータ

negneg.png

negativeなレビューの場合も、worstやhateといった単語の意味自体がnegativeなものが強くattentionされています。

正解:POSITIVE 予測:NEGATIVE なデータ

posneg.png

これはpositiveなレビューなのにnegativeと予測してしまった例です。attentionを観察するかぎりpassという単語に注目してnegativeと予測してしまったようです。
たしかに、”pass”という単語は「こんな映画みても無駄だからpassしろ」といった感じでnegativeに使われることも多いですが、今回は”pass it on”で「他の人にもこの映画を広めてほしい」というニュアンスで使われていると思います。たぶん。(英検3級並感)
このイディオムをとらえきれなかったようですね。

正解:NEGATIVE 予測:POSITIVE なデータ

negpos.png
これは逆にnegativeなレビューなのにpositiveと判断してしまった例です。「この映画の製作陣が二度と再結成しないことを祈るよ。」というかなり皮肉に富んだレビューです。はっきりとnegative / positiveを表している単語が少なく、この言い回しの真意を読み取れなかったようです。

まとめ

今回はself attentionを使用して、予測理由が簡単に可視化できる文書分類モデルを実装しました。どの単語に注目したか可視化するだけでも結構説得力のあるモデルになっていると思います。
予測を間違えたデータの分析も予測理由の可視化ができるとわかりやすいですね。

コード

https://github.com/nn116003/self-attention-classification

参考文献

[1]Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu,Bing Xiang, Bowen Zhou & Yoshua Bengio A STRUCTURED SELF-ATTENTIVE SENTENCE EMBEDDING, ICLR 2017


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。