post Image
ケモ・バイオインフォで今最もhotな手法 “Graph Convolutional Neural Networks” をChainerで試す。

※ あいていたので Deep Learningやっていき Advent Calendar 2017 の16日目に入れさせてもらいました。

先日Chainerを作っているPreferred Networks(PFN)か、 化学・生物学分野のための深層学習ライブラリ “Chainer Chemistry”が公開されました。
化学、生物学分野のための深層学習ライブラリChainer Chemistry公開

本ライブラリの特徴

・様々なGraph Convolutional Neural Networkのサポート
・データの前処理部分をライブラリ化・研究用データセットのサポート
・学習・推論コードのExample codeを提供
と明記されています。

Graph Convolutional Neural Networks とは

Graph Convolutional Neural Networks (以下GCNN) とは現在ケモインフォマティクス・バイオインフォマティクスなど構造データを扱う分野において広く注目を集めている手法です。
論文としては以下のような内容があげられています。
Convolutional Networks on Graphs
for Learning Molecular Fingerprints

Molecular Graph Convolutions: Moving Beyond Fingerprints
SchNet: A continuous-filter convolutional neural network for modeling quantum interactions
MoleculeNet: A Benchmark for Molecular
Machine Learning

Modeling Relational Data with Graph Convolutional Networks

また先日行われたNIPS2017にもGCNNをケモインフォに用いた論文が一報上がっています。
Protein Interface Prediction using Graph
Convolutional Networks

詳しい手法の説明等は以下の記事を参照していただきたいです。
機は熟した!グラフ構造に対するDeep Learning、Graph Convolutionのご紹介
GRAPH CONVOLUTIONAL NETWORKS THOMAS KIPF, 30
SEPTEMBER 2016

なぜケモインフォ・バイオインフォで注目されているか。

なぜ本手法がケモインフォ・バイオインフォで注目されているかというと、両領域のデータに多い”グラフ構造”を入力として深層学習に適用できるためです。
ケモインフォで扱う化合物の分子構造やバイオインフォで扱うPPINetwork、 Gene-GeneNetwork、 PathwayNetwork…などなど、様々なグラフ構造のデータに適応できるため、今後急速に本メソッドの利用は増えてくると思います。

※ ケモインフォでなぜ注目されているかはこちらも参考にしていただきたいです。
化合物でもDeep Learningがしたい!

個人的に気になっているのは以下の論文で、GCNNに対してDeepmindが提唱するRelation Networkを適用することで乳がんの分類の精度がめっちゃ上がるっていうものです。
これを他のデータに対して行うことで応用研究の裾野はもっと広がるのではないかと思ったり思わなかったり…。
今後大量のデータを集めることができれば、 複雑なメカニズムが多いバイオインフォ領域で、 ある物質の生成量から生成メカニズムの解析などを行えるようになるのではないかと考えています。

スクリーンショット 2017-12-28 23.40.12.png
以下論文より引用
Hybrid Approach of Relation Network and Localized Graph Convolutional
Filtering for Breast Cancer Subtype Classification

というわけで本記事では、上記のChainer Chemistryを用いてGCNNを試してみたいと思います。

※ Chainer自体は今回初めて触ってみたので、 もっとこう出来るよ!とかここ間違ってるからこうすると良いよ!みたいなのを教えていただけるとありがたいです!

環境

Ubuntu 16.04.3
Python 3.5.4
chainer (3.1.0)
chainer-chemistry (0.1.0)
cupy (2.1.0.1)
rdkit (2017.09.2.0)

事前の準備

Chainer Chemistryのリポジトリはこちら
Chainer Chemistry: A Library for Deep Learning in Biology and Chemistry

まず事前に以下のコマンドのでライブラリを入れます。
* rdkitはPython用のケモインフォマティクスでは必須のライブラリです。(これを読んでいる方はご存知かと思いますが・・・

pip install chainer-chemistry
conda install -c rdkit rdkit

モデル選択とデータセット

chainer chemistryに実装されているモデルとしては、

  • NFP: Neural Fingerprint
  • GGNN: Gated-Graph Neural Network
  • WeaveNet: Molecular Graph Convolutions
  • SchNet: A continuous-filter convolutional Neural Network

があります。

正直これを書くまでGGNNとSchNet以外は知らなかったのですが・・・
とりあえずチュートリアルに上がっているNFPを使ってみたいと思います。

扱うデータはExampleに上がっているtox21を使いました。
(結果としてサンプルのコードをただJupyter用に書き換えただけになってしまったので後々別のデータを使って実装したいと思います。)

実装

ここからコードを書いていきますが、個人的に好きという理由でJupyterを使っています。

まずライブラリのインポートを行います。
chainer、chainer_chemistry、rdkitなどをインポートしています。
chainerのインポートの部分は通常の使い方と一緒なので他の記事等もご覧ください。
Chainer v3 ビギナー向けチュートリアル

今回の実装だと予測を行う部分をpredictorで書いているのでそんなにchainer_chemistryのインポートが多くないです。
後ほどインポートしているdata、predictorの方のコードも記載します。

train_tox21.py
from __future__ import print_function

import os

import logging

try:
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass

import argparse
import chainer
from chainer import functions as F
from chainer import iterators as I
from chainer import links as L
from chainer import optimizers as O
from chainer import training
from chainer.training import extensions as E
import json
from rdkit import RDLogger

from chainer_chemistry.dataset.converters import concat_mols
from chainer_chemistry import datasets as D

import data
import predictor

# Disable errors by RDKit occurred in preprocessing Tox21 dataset.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
# show INFO level log from chainer chemistry
logging.basicConfig(level=logging.INFO)

で次に変数を定義していきます。
事前にmethod、conv_layers、batchsizeなどをここで定義しています。
gpu = 0 の部分はcpuを使いたい場合は gpu = -1 にしてください。

train_tox21.py
label_names = D.get_tox21_label_names()

method = 'nfp'
conv_layers = 4
batchsize = 128
gpu = 0
out = 'result'
epoch = 10
unit_num = 16
resume = ''
frequency = -1

labels = None
class_num = len(label_names)

次にデータを準備します。
ここで先程インポートしたdataを用いています。

train_tox21.py
# Dataset preparation
train, val, _ = data.load_dataset(method, labels)

次に予測を行います。
予測自体はpredictorで行っています。
流れとしては通常のChainerと一緒でpredictor、optimizerなどの情報をupdaterに渡し、それをまたtrainerに渡しています。

trainer.extendの部分で各種の便利関数を追加していっています。

train_tox21.py
predictor_ = predictor.build_predictor(
        method, unit_num, conv_layers, class_num)

train_iter = I.SerialIterator(train, batchsize)
val_iter = I.SerialIterator(val, batchsize,
                                repeat=False, shuffle=False)
classifier = L.Classifier(predictor_,
                              lossfun=F.sigmoid_cross_entropy,
                              accfun=F.binary_accuracy)
if gpu >= 0:
    chainer.cuda.get_device_from_id(gpu).use()
    classifier.to_gpu()

optimizer = O.Adam()
optimizer.setup(classifier)

updater = training.StandardUpdater(
        train_iter, optimizer, device=gpu, converter=concat_mols)
trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)

trainer.extend(E.Evaluator(val_iter, classifier,
                               device=gpu, converter=concat_mols))
trainer.extend(E.snapshot(), trigger=(epoch, 'epoch'))
trainer.extend(E.LogReport())
trainer.extend(E.PrintReport(['epoch', 'main/loss', 'main/accuracy',
                                  'validation/main/loss',
                                  'validation/main/accuracy',
                                  'elapsed_time']))
trainer.extend(E.ProgressBar(update_interval=100))
frequency = epoch if frequency == -1 else max(1, frequency)
trainer.extend(E.snapshot(), trigger=(frequency, 'epoch'))

if resume:
    chainer.serializers.load_npz(resume, trainer)

trainer.run()

config = {'method': method,
              'conv_layers': conv_layers,
              'unit_num': unit_num,
              'labels': labels}

with open(os.path.join(out, 'config.json'), 'w') as o:
    o.write(json.dumps(config))

次にdata.pyです。
特に書くことがない気がしたので特に書きません。
あしからず。

data.py

import os

from chainer_chemistry.dataset.preprocessors import preprocess_method_dict
from chainer_chemistry import datasets as D
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset

class _CacheNamePolicy(object):

    train_file_name = 'train.npz'
    val_file_name = 'val.npz'
    test_file_name = 'test.npz'

    def _get_cache_directory_path(self, method, labels, prefix):
        if labels:
            return os.path.join(prefix, '{}_{}'.format(method, labels))
        else:
            return os.path.join(prefix, '{}_all'.format(method))

    def __init__(self, method, labels, prefix='input'):
        self.method = method
        self.labels = labels
        self.prefix = prefix
        self.cache_dir = self._get_cache_directory_path(method, labels, prefix)

    def get_train_file_path(self):
        return os.path.join(self.cache_dir, self.train_file_name)

    def get_val_file_path(self):
        return os.path.join(self.cache_dir, self.val_file_name)

    def get_test_file_path(self):
        return os.path.join(self.cache_dir, self.test_file_name)

    def create_cache_directory(self):
        try:
            os.makedirs(self.cache_dir)
        except OSError:
            if not os.path.isdir(self.cache_dir):
                raise

def load_dataset(method, labels, prefix='input'):
    policy = _CacheNamePolicy(method, labels, prefix)
    train_path = policy.get_train_file_path()
    val_path = policy.get_val_file_path()
    test_path = policy.get_test_file_path()

    train, val, test = None, None, None
    print()
    if os.path.exists(policy.cache_dir):
        print('load from cache {}'.format(policy.cache_dir))
        train = NumpyTupleDataset.load(train_path)
        val = NumpyTupleDataset.load(val_path)
        test = NumpyTupleDataset.load(test_path)
    if train is None or val is None or test is None:
        print('preprocessing dataset...')
        preprocessor = preprocess_method_dict[method]()
        train, val, test = D.get_tox21(preprocessor, labels=labels)
        # Cache dataset
        policy.create_cache_directory()
        NumpyTupleDataset.save(train_path, train)
        NumpyTupleDataset.save(val_path, val)
        NumpyTupleDataset.save(test_path, test)
    return train, val, test

次にpredictor.pyについて.
インポートの部分でchainer_chemistryに実装されている各種GCNNのモデルをインポートしています.

predictor.py

import chainer
from chainer import cuda
from chainer import functions as F
from chainer import iterators as I
import numpy as np

from chainer_chemistry.dataset.converters import concat_mols
from chainer_chemistry.models import GGNN
from chainer_chemistry.models import MLP
from chainer_chemistry.models import NFP
from chainer_chemistry.models import SchNet
from chainer_chemistry.models import WeaveNet

ここでモデル毎の分岐がされています。
今回で言うと”nfp”を指定しているため、
“predictor = GraphConvPredictor(NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers), MLP(out_dim=class_num, hidden_dim=n_unit))”
が呼ばれており、ここにtrain_tox21.pyで渡した引数が入ります。

predictor.py
def build_predictor(method, n_unit, conv_layers, class_num):
    if method == 'nfp':
        print('Use NFP predictor...')
        predictor = GraphConvPredictor(
            NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
            MLP(out_dim=class_num, hidden_dim=n_unit))
    elif method == 'ggnn':
        print('Use GGNN predictor...')
        predictor = GraphConvPredictor(
            GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
            MLP(out_dim=class_num, hidden_dim=n_unit))
    elif method == 'schnet':
        print('Use SchNet predictor...')
        predictor = SchNet(out_dim=class_num, hidden_dim=n_unit,
                           n_layers=conv_layers, readout_hidden_dim=n_unit)
    elif method == 'weavenet':
        print('Use WeaveNet predictor...')
        n_atom = 20
        n_sub_layer = 1
        weave_channels = [50] * conv_layers
        predictor = GraphConvPredictor(
            WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
                     n_sub_layer=n_sub_layer, n_atom=n_atom),
            MLP(out_dim=class_num, hidden_dim=n_unit))
    else:
        raise ValueError('[ERROR] Invalid predictor: method={}'.format(method))
    return predictor

最後にChainとループの部分ですね。
ここもChainerのデフォルトの書き方という感じなので特に言うことはありません。

predictor.py
class GraphConvPredictor(chainer.Chain):

    def __init__(self, graph_conv, mlp):

        super(GraphConvPredictor, self).__init__()
        with self.init_scope():
            self.graph_conv = graph_conv
            self.mlp = mlp

    def __call__(self, atoms, adjs):
        x = self.graph_conv(atoms, adjs)
        x = self.mlp(x)
        return x

    def predict(self, atoms, adjs):
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            x = self.__call__(atoms, adjs)
            return F.sigmoid(x)

class InferenceLoop(object):

    def __init__(self, predictor):
        self.predictor = predictor

    def customized_inference(self, iterator, converter, device):
        iterator.reset()
        ret = []
        for batch in iterator:
            x = converter(batch, device=device)
            y_prob = self.predictor.predict(*x)
            y_prob = cuda.to_cpu(y_prob.data)
            y_pred = np.where(y_prob > .5, 1, 0)
            ret.append(y_pred)
        return np.concatenate(ret, axis=0)

    def inference(self, X):

        batchsize = 128
        data_iter = I.SerialIterator(X, batchsize, repeat=False, shuffle=False)

        if self.predictor.xp is np:
            device_id = -1
        else:
            device_id = cuda.cupy.cuda.get_device_id()

        return self.customized_inference(data_iter,
                                         converter=concat_mols,
                                         device=device_id)

で結果はこんな感じで出てきます。

result.text

    Use NFP predictor...
    epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
    [J1           4.50169     0.612392       0.434216              0.887374                  7.54297       
    [J     total [#####.............................................] 10.89%
    this epoch [####..............................................]  8.86%
           100 iter, 1 epoch / 10 epochs
           inf iters/sec. Estimated time to finish: 0:00:00.
    [4A[J2           0.319232    0.927095       0.28836               0.915983                  15.0646       
    [J     total [##########........................................] 21.77%
    this epoch [########..........................................] 17.72%
           200 iter, 2 epoch / 10 epochs
        12.359 iters/sec. Estimated time to finish: 0:00:58.144770.
    [4A[J3           0.263636    0.9273         0.287894              0.915983                  22.5079       
    [J     total [################..................................] 32.66%
    this epoch [#############.....................................] 26.59%
           300 iter, 3 epoch / 10 epochs
        12.284 iters/sec. Estimated time to finish: 0:00:50.358029.
    [4A[J4           0.263121    0.927045       0.286248              0.915983                  30.0162       
    [J     total [#####################.............................] 43.54%
    this epoch [#################.................................] 35.45%
           400 iter, 4 epoch / 10 epochs
        12.324 iters/sec. Estimated time to finish: 0:00:42.079189.
    [4A[J5           0.259896    0.927481       0.29347               0.915983                  37.4582       
    [J     total [###########################.......................] 54.43%
    this epoch [######################............................] 44.31%
           500 iter, 5 epoch / 10 epochs
         12.27 iters/sec. Estimated time to finish: 0:00:34.114915.
    [4A[J6           0.256632    0.927258       0.280697              0.915983                  45.006        
    [J     total [################################..................] 65.32%
    this epoch [##########################........................] 53.17%
           600 iter, 6 epoch / 10 epochs
        12.307 iters/sec. Estimated time to finish: 0:00:25.886899.
    [4A[J7           0.256001    0.927342       0.300428              0.915489                  52.4633       
    [J     total [######################################............] 76.20%
    this epoch [###############################...................] 62.03%
           700 iter, 7 epoch / 10 epochs
        12.289 iters/sec. Estimated time to finish: 0:00:17.787654.
    [4A[J8           0.25824     0.92711        0.283099              0.915983                  59.8946       
    [J     total [###########################################.......] 87.09%
    this epoch [###################################...............] 70.90%
           800 iter, 8 epoch / 10 epochs
        12.309 iters/sec. Estimated time to finish: 0:00:09.634948.
    [4A[J9           0.258653    0.927209       0.283805              0.915736                  67.3553       
    [J     total [################################################..] 97.98%
    this epoch [#######################################...........] 79.76%
           900 iter, 9 epoch / 10 epochs
         12.31 iters/sec. Estimated time to finish: 0:00:01.510471.
    [4A[J10          0.257073    0.927359       0.298906              0.914913                  74.8144       
    [J

おわりに

コードはここにあげておきます。
inoue0426/chainer-chemistory-notebooks
GCNNといえばDeepChemというライブラリが以前から用いられていますが、選択肢が広がったのは良いことだなーっと思ったりしました。
今後、実装の比較を行えたらと思います。


『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。