post Image
{tensorflow}をirisデータで試してみる

概要

 少し前に{tensorflow}というRからTensorFlowを使うパッケージがRStudio社から公開されたので、みんな大好きirisデータの分類をMNISTの例を参考に試してみました。

事前準備

 実行環境によっては以前にインストールしたバージョンを削除して再度インストールする必要があります(自分はprotobufも再インストールしました)。

NOTE: If you are upgrading from a previous installation of TensorFlow < 0.7.1, you should uninstall the previous TensorFlow and protobuf using pip uninstall first to make sure you get a clean installation of the updated protobuf dependency.
 https://www.tensorflow.org/versions/r0.11/get_started/os_setup.html

必要ライブラリのインストール
# 削除したprotobufを再インストール
$ pip install protobuf

# OS XでCPU利用の場合
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-0.11.0rc1-py2-none-any.whl
$ sudo pip install --upgrade $TF_BINARY_URL

# {tensorflow}のインストール時に{tensorflow}で利用するPythonを指定するため、パスを確認しておく
$ which python
/usr/local/bin/python
export PYTHON_PATH=`which python`
R事前準備
# {tensorflow}で利用するPythonのパスを環境変数に設定してインストールする
Sys.setenv(TENSORFLOW_PYTHON = "/usr/local/bin/python")
devtools::install_github("rstudio/tensorflow")

 インストール後にパッケージを呼び出してエラーが起きなければOK(エラーが起きる場合は使用するライブラリの再インストールや、上記のPythonパスを確認するなど試す)。なお、エラー後に再度パッケージを読み込むとエラーメッセージは表示されないが、使用できないので注意。

定義・設定

 処理で利用するライブラリの読み込みや定数・関数。加えて今回はTensorFlowのモデルの定義をここで行う。

定数定義部
library(tensorflow)
library(dplyr)
library(foreach)
library(caret)

SET_CV_NUM <- 5
SET_DATA_PARAM <- list(
  CLASS_NUM = 3L, FEATURE_NUM = 4L
)
SET_SETP_NUM <- 3000
関数定義部
# confusion matrixからAccuracyを計算
calcAccuracy <- function(confusion_mat) {
  return(sum(diag(x = confusion_mat)) / sum(confusion_mat))
}
モデル定義部
# MNIST For ML Beginnersの例を参考
W <- tensorflow::tf$Variable(
  initial_value = tensorflow::tf$zeros(shape = tensorflow::shape(SET_DATA_PARAM$FEATURE_NUM, SET_DATA_PARAM$CLASS_NUM))
)
x <- tensorflow::tf$placeholder(
  dtype = tensorflow::tf$float32,
  shape = tensorflow::shape(NULL, SET_DATA_PARAM$FEATURE_NUM)
)
b <- tensorflow::tf$Variable(
  initial_value = tensorflow::tf$zeros(shape = tensorflow::shape(SET_DATA_PARAM$CLASS_NUM))
)

y <- tensorflow::tf$nn$softmax(logits = tensorflow::tf$matmul(a = x, b = W) + b)
y_ <- tensorflow::tf$placeholder(dtype = tensorflow::tf$float32, shape = tensorflow::shape(NULL, SET_DATA_PARAM$CLASS_NUM))


# 損失関数とオプティマイザーの設定
cross_entropy <- tensorflow::tf$reduce_mean(input_tensor = - tensorflow::tf$reduce_sum(input_tensor = y_ * tensorflow::tf$log(x = y), reduction_indices = 1L))
optimizer <- tensorflow::tf$train$GradientDescentOptimizer(learning_rate = 0.5)
train_step <- optimizer$minimize(loss = cross_entropy)


# 評価用
correct_prediction <- tensorflow::tf$equal(
  x = tensorflow::tf$argmax(input = y, dimension = 1L), y = tensorflow::tf$argmax(input = y_, dimension = 1L)
)
accuracy <- tensorflow::tf$reduce_mean(
  input_tensor = tensorflow::tf$cast(x = correct_prediction, dtype = tensorflow::tf$float32)
)

実行部

学習とテスト
# irisデータ
data(iris)
> iris %>% 
  head(n = 10)
   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
1           5.1         3.5          1.4         0.2  setosa
2           4.9         3.0          1.4         0.2  setosa
3           4.7         3.2          1.3         0.2  setosa
4           4.6         3.1          1.5         0.2  setosa
5           5.0         3.6          1.4         0.2  setosa
6           5.4         3.9          1.7         0.4  setosa
7           4.6         3.4          1.4         0.3  setosa
8           5.0         3.4          1.5         0.2  setosa
9           4.4         2.9          1.4         0.2  setosa
10          4.9         3.1          1.5         0.1  setosa

> iris %>% 
   dplyr::group_by(Species) %>% 
   dplyr::summarise_all(.fun = mean)
# A tibble: 3 × 5
     Species Sepal.Length Sepal.Width Petal.Length Petal.Width
      <fctr>        <dbl>       <dbl>        <dbl>       <dbl>
1     setosa        5.006       3.428        1.462       0.246
2 versicolor        5.936       2.770        4.260       1.326
3  virginica        6.588       2.974        5.552       2.026


set.seed(seed = 71)
cv_number <- sample(x = seq(from = 1, to = SET_CV_NUM), size = nrow(x = iris), replace = TRUE)


# N分割交差検証
> tf_result <- lapply(
  X = seq(from = 1, to = SET_CV_NUM),
  FUN = function (cv_counter) {

    # 学習データ作成
    trn_d <- iris %>% 
      dplyr::filter(cv_number != cv_counter)
    trn_x <- trn_d %>% 
      dplyr::select(-Species) %>% 
      as.matrix()
    trn_y <- trn_d %>% 
      caret::dummyVars(formula = ~ Species, sep = NULL) %>% 
      predict(object = ., newdata = trn_d) %>% 
      as.matrix()


    # 初期化
    tf_session <- tensorflow::tf$Session()
    tf_session$run(fetches = tensorflow::tf$initialize_all_variables())

    # パラメータが初期化されているか確認
    print(
      stringr::str_c(
        stringr::str_c("CV:", cv_counter),
        stringr::str_c("W:", sum(tf_session$run(W))),
        stringr::str_c("b:", sum(tf_session$run(b))),
        sep = " "
      )
    )

    # 学習
    foreach::times(n = SET_SETP_NUM) %do% {
      step_logic <- sample(x = c(FALSE, TRUE), size = nrow(x = trn_x), replace = TRUE, prob = c(0.5, 0.5))
      tf_session$run(
        fetches = train_step,
        feed_dict = tensorflow::dict(x = trn_x[step_logic, ], y_ = trn_y[step_logic, , drop = FALSE])
      )
    }

    # 当てはめ結果
    tf_fit_accuracy <- accuracy$eval(feed_dict = tensorflow::dict(x = trn_x, y_ = trn_y), session = tf_session)
    fit_confusion_mat <- table(
      predict = tf_session$run(fetches = y, feed_dict = tensorflow::dict(x = trn_x, y_ = trn_y)) %>%
        apply(MARGIN = 1, FUN = which.max),
      true = trn_y %>% 
        apply(MARGIN = 1, FUN = which.max)
    )


    # 評価データ作成
    tst_d <- iris %>% 
      dplyr::filter(cv_number == cv_counter)
    tst_x <- tst_d %>% 
      dplyr::select(-Species) %>% 
      as.matrix()
    tst_y <- tst_d %>% 
      dummyVars(formula = ~ Species, sep = NULL) %>% 
      predict(object = ., newdata = tst_d) %>% 
      as.matrix()

    # 予測結果
    tf_predict_accuracy <- accuracy$eval(feed_dict = tensorflow::dict(x = tst_x, y_ = tst_y), session = tf_session)
    predict_confusion_mat <- table(
      predict = tf_session$run(fetches = y, feed_dict = tensorflow::dict(x = tst_x, y_ = matrix(data = 0, nrow = nrow(x = tst_x), ncol = SET_DATA_PARAM$CLASS_NUM))) %>%
        apply(MARGIN = 1, FUN = which.max),
      true = tst_y %>% 
        apply(MARGIN = 1, FUN = which.max)
    )


    # 比較用
    glmnet_mdl <- glmnet::glmnet(x = trn_x, y = trn_d$Species, family = "multinomial")
    glmnet_confusion_mat <- table(
      predict = predict(object = glmnet_mdl, newx = tst_x, type = "class", s = 0.01)[, 1],
      true = tst_d$Species
    )

    return(
      list(
        tf_fit_accuracy = tf_fit_accuracy,
        fit_confusion_mat = fit_confusion_mat,
        tf_predict_accuracy = tf_predict_accuracy,
        predict_confusion_mat = predict_confusion_mat,
        glmnet_confusion_mat = glmnet_confusion_mat
      )
    )
  }
)
[1] "CV:1 W:0 b:0"
[1] "CV:2 W:0 b:0"
[1] "CV:3 W:0 b:0"
[1] "CV:4 W:0 b:0"
[1] "CV:5 W:0 b:0"

 学習前にパラメータの初期化もされています。

評価結果
# 当てはめ結果を確認
> sapply(X = tf_result, FUN = "[[", "tf_fit_accuracy")
[1] 0.9487180 0.9920000 0.9663866 0.9842520 0.9732143
# TensorFlow上で評価用に算出したaccuracyとほぼ同じ結果
> sapply(
  X = lapply(X = tf_result, FUN = "[[", "fit_confusion_mat"),
  FUN = calcAccuracy
)
[1] 0.9487179 0.9920000 0.9663866 0.9842520 0.9732143

# 当てはめ結果の平均
> mean(x = sapply(X = tf_result, FUN = "[[", "tf_fit_accuracy"))
[1] 0.9729141

# 当てはめ結果のconfusion matrix
> lapply(X = tf_result, FUN = "[[", "fit_confusion_mat")
[[1]]
       true
predict  1  2  3
      1 40  0  0
      2  0 31  0
      3  0  6 40

[[2]]
       true
predict  1  2  3
      1 39  0  0
      2  0 41  0
      3  0  1 44

[[3]]
       true
predict  1  2  3
      1 42  0  0
      2  0 33  0
      3  0  4 40

[[4]]
       true
predict  1  2  3
      1 38  0  0
      2  0 43  1
      3  0  1 44

[[5]]
       true
predict  1  2  3
      1 41  0  0
      2  0 37  0
      3  0  3 31


# 予測結果を確認
> sapply(X = tf_result, FUN = "[[", "tf_predict_accuracy")
[1] 0.9696970 0.9200000 0.9677419 0.9565217 0.9736842
> sapply(
  X = lapply(X = tf_result, FUN = "[[", "predict_confusion_mat"),
  FUN = calcAccuracy
)
[1] 0.9696970 0.9200000 0.9677419 0.9565217 0.9736842

# 予測結果の平均
> mean(x = sapply(X = tf_result, FUN = "[[", "tf_predict_accuracy"))
[1] 0.957529

# 予測結果のconfusion matrix
> lapply(X = tf_result, FUN = "[[", "predict_confusion_mat")
[[1]]
       true
predict  1  2  3
      1 10  0  0
      2  0 12  0
      3  0  1 10

[[2]]
       true
predict  1  2  3
      1 11  0  0
      2  0  6  0
      3  0  2  6

[[3]]
       true
predict  1  2  3
      1  8  0  0
      2  0 12  0
      3  0  1 10

[[4]]
       true
predict  1  2  3
      1 12  0  0
      2  0  6  1
      3  0  0  4

[[5]]
       true
predict  1  2  3
      1  9  0  0
      2  0  9  0
      3  0  1 19

# {glmnet}の予測結果
> glmnet_ev <- sapply(
  X = lapply(X = tf_result, FUN = "[[", "glmnet_confusion_mat"),
  FUN = calcAccuracy
) %>% 
  print
[1] 1.0000000 0.9200000 0.9677419 0.9130435 0.9736842
> mean(x = glmnet_ev)
[1] 0.9548939

glmnetと差がほとんどなかったです。

まとめ

 {tensorflow}を用いて、分析業界の”Hello World”であるirisデータの分類を試しました。以前は{PythonInR}を使ってTensorFlowを呼び出しましたが、RStudio社が公開するパッケージということでこちらの方が安心感があります。今回はとりあえず動かしてみただけですので、モデルの改良などはもう少しお勉強してからにします。
 また、現在のところPreview版ですが、RStudioを1系にすると{tensorflow}のオブジェクトがサジェストされるので、とてもとてもオススメです。

参考

実行環境

実行環境
> devtools::session_info()
Session info ----------------------------------------------------------------------------------------
 setting  value                       
 version  R version 3.3.1 (2016-06-21)
 system   x86_64, darwin13.4.0        
 ui       RStudio (1.0.44)            
 language (EN)                        
 collate  ja_JP.UTF-8                 
 tz       Asia/Tokyo                  
 date     2016-10-27                  

Packages --------------------------------------------------------------------------------------------
 package      * version date       source                             
 assertthat     0.1     2013-12-06 CRAN (R 3.3.1)                     
 broom          0.4.1   2016-06-24 CRAN (R 3.3.0)                     
 car            2.1-2   2016-03-25 CRAN (R 3.3.0)                     
 caret        * 6.0-70  2016-06-13 CRAN (R 3.3.0)                     
 codetools      0.2-14  2015-07-15 CRAN (R 3.3.1)                     
 colorspace     1.2-6   2015-03-11 CRAN (R 3.3.1)                     
 DBI            0.5     2016-08-11 cran (@0.5)                        
 devtools       1.12.0  2016-06-24 CRAN (R 3.3.0)                     
 digest         0.6.9   2016-01-08 CRAN (R 3.3.0)                     
 dplyr        * 0.5.0   2016-06-24 CRAN (R 3.3.1)                     
 foreach      * 1.4.3   2015-10-13 CRAN (R 3.3.1)                     
 ggplot2      * 2.1.0   2016-03-01 CRAN (R 3.3.1)                     
 glmnet         2.0-5   2016-03-17 CRAN (R 3.3.0)                     
 gtable         0.2.0   2016-02-26 CRAN (R 3.3.1)                     
 iterators      1.0.8   2015-10-13 CRAN (R 3.3.1)                     
 janeaustenr    0.1.1   2016-06-20 CRAN (R 3.3.0)                     
 lattice      * 0.20-33 2015-07-14 CRAN (R 3.3.1)                     
 lme4           1.1-12  2016-04-16 CRAN (R 3.3.0)                     
 magrittr       1.5     2014-11-22 CRAN (R 3.3.1)                     
 MASS           7.3-45  2016-04-21 CRAN (R 3.3.1)                     
 Matrix         1.2-6   2016-05-02 CRAN (R 3.3.1)                     
 MatrixModels   0.4-1   2015-08-22 CRAN (R 3.3.1)                     
 memoise        1.0.0   2016-01-29 CRAN (R 3.3.0)                     
 mgcv           1.8-12  2016-03-03 CRAN (R 3.3.1)                     
 minqa          1.2.4   2014-10-09 CRAN (R 3.3.0)                     
 mnormt         1.5-4   2016-03-09 CRAN (R 3.3.0)                     
 munsell        0.4.3   2016-02-13 CRAN (R 3.3.1)                     
 nlme           3.1-128 2016-05-10 CRAN (R 3.3.1)                     
 nloptr         1.0.4   2014-08-04 CRAN (R 3.3.1)                     
 nnet           7.3-12  2016-02-02 CRAN (R 3.3.1)                     
 pbkrtest       0.4-6   2016-01-27 CRAN (R 3.3.0)                     
 plyr           1.8.4   2016-06-08 CRAN (R 3.3.1)                     
 psych          1.6.6   2016-06-28 CRAN (R 3.3.0)                     
 quantreg       5.26    2016-06-07 CRAN (R 3.3.0)                     
 R6             2.1.3   2016-08-19 cran (@2.1.3)                      
 Rcpp           0.12.7  2016-09-05 cran (@0.12.7)                     
 readr          1.0.0   2016-08-03 cran (@1.0.0)                      
 reshape2       1.4.1   2014-12-06 CRAN (R 3.3.1)                     
 scales         0.4.0   2016-02-26 CRAN (R 3.3.1)                     
 SnowballC      0.5.1   2014-08-09 CRAN (R 3.3.1)                     
 SparseM        1.7     2015-08-15 CRAN (R 3.3.0)                     
 stringi        1.1.1   2016-05-27 CRAN (R 3.3.1)                     
 stringr        1.1.0   2016-08-19 cran (@1.1.0)                      
 tensorflow   * 0.3.0   2016-10-25 Github (rstudio/tensorflow@dfe2f1a)
 tibble         1.2     2016-08-26 cran (@1.2)                        
 tidyr          0.6.0   2016-08-12 cran (@0.6.0)                      
 tidytext       0.1.1   2016-06-25 CRAN (R 3.3.0)                     
 tokenizers     0.1.4   2016-08-29 CRAN (R 3.3.0)                     
 withr          1.0.2   2016-06-20 CRAN (R 3.3.0)      

『 機械学習 』Article List
Category List

Eye Catch Image
Read More

Androidに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

AWSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Bitcoinに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

CentOSに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

dockerに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

GitHubに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Goに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Javaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

JavaScriptに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Laravelに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Pythonに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Rubyに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Scalaに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Swiftに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Unityに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Vue.jsに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

Wordpressに関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。

Eye Catch Image
Read More

機械学習に関する現役のエンジニアのノウハウ・トレンドのトピックなど技術的な情報を提供しています。コード・プログラムの丁寧な解説をはじめ、初心者にもわかりやすいように写真や動画を多く使用しています。