昨年のうちにTensorFlowのインストールとニューラルネットワークの基礎理論の勉強はしました。
基礎理論の勉強のためには、やっぱり自分で書いてみるのが一番いいので、これまでは特に便利ライブラリはなしで書いていました。 だんだん難しいことをやっていこうと思うと、こういう便利ライブラリも徐々に使いこなす必要があるので、今回はそのお勉強です。
TensorFlowの概要
まずはじめに、TensorFlowについてざっと調べて見ました。(インストールするときにやれよってツッコミは勘弁して下さい、、、)
参考にしたのはこちらです。
TensorFlowはGoogleが提供しているオープンソースのディープラーニングライブラリです。 他にも有名なディープラーニングのライブラリはあるんですが、下記の三拍子がそろうものはTensorFlowだけらしいです。
- GPU・モバイルなどのCPU以外の環境で使用可能
- 分散コンピューティングへの対応
- 宣言的プログラミングによる内部処理の隠蔽
参考にしたサイトにもありますが、TensorFlowではデータフローグラフを定義してプログラミングをします。
Computation using data flow graphs for scalable machine learning
まず処理の全体感ですが、TensorFlowとして想定している処理の流れはこんな感じらしいです。
パイプラインを用意するイメージですね。 次に、データフローグラフです。 データフローグラフは、こんな感じのものがイメージです。
ノードとエッジで構成される有向グラフをデータフローグラフらしいです。 ニューラルネットワークの設計を考える上で、図のようなモデルを設計してプログラミングしていくことになります。
実装
今回はTensorFlowのチュートリアルをなぞって勉強しようと思います。 リンクはこちらです。
(本家、英語) www.tensorflow.org
(日本語訳) qiita.com
もともとの実装
ニューラルネットの勉強は一回やっているので、細かいとことか、TensorFlowを使わないのと比較したい場合にはこちらをご参照ください。
ニューラルネットワークについて勉強してみた - Re:ゼロから始めるML生活
全体像
TensorFlowの概要で使った図です。今回はこれに沿ってコードを眺めます。
①データ読み込み
毎回のことながら、データはMNISTです。 データの読み込み自体は2行書くだけで終わります。
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
②前処理
今回は下準備はそんなに必要なくて、強いて言うならバッチのサイズにまとめて置くことですかね。
batch_xs, batch_ys = mnist.train.next_batch(100)
実行する前に、こんな感じで学習データを100個単位で取得しています。 batch_xsが画像データ、batch_ysが正解のラベルで、そのペアを100個取得しています。
③学習
TensorFlowでは、データフローグラフに対してデータを流し込んで行くんでした。 今回のデータフローグラフはこんな感じになっています。
データフローグラフの定義はこんな感じに書いてあります。
x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
placeholder
TensorFlowはあくまで汎用プログラミング言語から呼び出すライブラリです。 TensorFlowの内部だけで完結するデータはそれ専用の呼び出し(今回はtf.Variable)があります。 また、TensorFlowが想定するニューラルネットワークの処理とそれ以外の汎用言語との間でやり取りするデータを定義する必要があります。
TensorFlowではplaceholderを使用してデータの型を宣言できます。
x = tf.placeholder(tf.float32, [None, 784])
学習データxは32ビットfloat型で長さが784(=28*28)で定義されています。 None は次元が任意の長さをとることを意味しているみたいです。
y_ = tf.placeholder(tf.float32, [None, 10])
こっちも一緒ですね。 出力は0−9の10通りで、型は32ビットfloat型で定義されています。
session
学習の実行はこんな感じに書いてあります。
sess = tf.InteractiveSession() tf.global_variables_initializer().run() for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
IntaractiveSessionとありますが、sessionはネットワークの起動時に必要になるオブジェクトです。 TensorFlowは分散コンピューティングも可能とのことなので、おそらく複数のコンピューティングノードの計算結果を統合する際には注意するんでしょう。 単一のPCで実行するレベルでは、おまじないとして貰えれば。。。
sessionを呼び出すことで学習の準備はほぼ終わりました。 その次に、全ての内部変数を初期化しています。 学習の実行自体は3行で書いてありますね。
for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
データを100個単位で取ってきて、学習開始。1000回繰り返して学習を終了します。
その後、全部終わった後のネットワークの精度を確認して終了です。
④保存
元のコードに保存のプロセスはないので、今回は省略します。
TensorFlowを使った実装
最後に、チュートリアルのサイトに落ちてるコードはこんな感じです。
# coding:utf-8 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) import tensorflow as tf x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession() tf.global_variables_initializer().run() for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
コードが短くて感動です。TensorBoardを使えば内部の可視化とかもできるのですが、それはまた今度。