この前はword2vecの勉強をしていました。
今回はもうちょっと突っ込んだRNN (Recurrent Neural Networks)について勉強してみます。
参考にしたのはこちら。

ゼロから作るDeep Learning ? ―自然言語処理編
- 作者: 斎藤康毅
- 出版社/メーカー: オライリージャパン
- 発売日: 2018/07/21
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (3件) を見る
細かい内容は上の本をご参照下さい。 ホントにわかりやすいので。
今回の内容辺りからややこしい話になってくるので、実装もいれつつ勉強していきます。 現場でライブラリなしで使うことなんてありえないですし、理論は上の本の写経で勉強すればよいかと。 今回はpytorchの実装を見ながらやってみます。
RNN (Recurrent Neural Networks)
世の中の大半のニューラルネットワークの大半はfeed forward型のニューラルネットワークです。 イメージはこんな感じです。
それに対してRNNとはfeed back型のニューラルネットワークです。 イメージはこんな感じです。
RNNではニューラルネットに加えて時系列の概念が考慮されています。 時刻tのときのNNの出力の一部を時刻t+1のときのNNの入力として使用しています。 こうすることで、時系列に関する考慮を可能にしています。
時間軸方向に展開するとこんな感じで時系列方向に入力信号の影響が伝播するようになっています。
時間軸方向の関連の分割
大まかなイメージは上の図で十分だと思います。 しかし、このままでは時系列データが大きくなるとすぐメモリ不足になります。
そのため、大きな時系列データを使用できるようにBPTT (Backpropagation Through Time) と呼ばれる手法でネットワークのつながりを適当な大きさで分割します。
これを実現するには、バックプロパゲーションのやり方に一工夫必要です。 順伝播するときには、直前のブロックの最後の出力を使用しましたが、逆伝播のときにはブロック間のやり取りを断ち切ります。
これによって、
- 直前の出力と組み合せて順伝播
- 最後の出力を一時的に保存
- ブロック内の逆伝播
- 次のブロックへ(1. へ戻る)
のように小さい単位で学習を行うことが可能になります。 基本的な仕組みはこんな感じです。 ミニバッチとかの話は上の本をご参照ください。
Pytorchでやってみる
上で紹介したPytorchのチュートリアルにいい感じのやつがあるので、それを写経してやってみます。
下準備
そんなに複雑な処理でもないので、ipythonでやりたいと思います。
データセット
データセットはこちら。
(リンク)https://download.pytorch.org/tutorial/data.zip
ディレクトリ構成
そんでもって、こんな感じでデータを配置していただければ準備OKです。
tree <マウント先パス> <マウント先パス>/ ├── Dockerfile ├── char_rnn_classification_tutorial.ipynb ├── char_rnn_generation_tutorial.ipynb └── data ├── eng-fra.txt └── names ├── Arabic.txt ├── Chinese.txt ├── Czech.txt ├── Dutch.txt ├── English.txt ├── French.txt ├── German.txt ├── Greek.txt ├── Irish.txt ├── Italian.txt ├── Japanese.txt ├── Korean.txt ├── Polish.txt ├── Portuguese.txt ├── Russian.txt ├── Scottish.txt ├── Spanish.txt └── Vietnamese.txt
Docker
環境構築にはこちらを使用しました。
FROM jupyter/scipy-notebook:latest RUN conda install --quiet --yes pytorch torchvision -c soumith RUN pip install --upgrade torch
※pytorchのアップデートを書き加えました。
起動はこんな感じ。
docker run -it --rm -v <マウント先パス>:/home/jovyan/work -p 8888:8888 stepankuzmin/pytorch-notebook
CLASSIFYING NAMES WITH A CHARACTER-LEVEL RNN
基本的にチュートリアルをそのままなぞります。
このチュートリアルでやりたいこととしてはこんな感じです。
そんでもって、肝心のRNNはこんな感じです。
class RNN(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(RNN, self).__init__() self.hidden_size = hidden_size self.i2h = nn.Linear(input_size + hidden_size, hidden_size) self.i2o = nn.Linear(input_size + hidden_size, output_size) self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden): combined = torch.cat((input, hidden), 1) hidden = self.i2h(combined) output = self.i2o(combined) output = self.softmax(output) return output, hidden def initHidden(self): return torch.zeros(1, self.hidden_size)
このコードをすべて動かすとこんな感じに出力されていきます。
入力に対する出力の類推結果はこんな感じ。
コンソール的にはこんな感じ。
> Dovesky (-0.90) Russian (-1.30) Czech (-2.06) Polish > Jackson (-0.70) Scottish (-1.43) English (-2.54) Russian > Satoshi (-1.35) Japanese (-1.51) Polish (-1.72) Italian
ちゃんと類推できているっぽいですね。
GENERATING NAMES WITH A CHARACTER-LEVEL RNN
次はちょっと難しくなって、文字の推定を行います。
やりたいこととしてはこんな感じです。
肝になるRNNのコードはこんな感じです。
class RNN(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(RNN, self).__init__() self.hidden_size = hidden_size self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size) self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size) self.o2o = nn.Linear(hidden_size + output_size, output_size) self.dropout = nn.Dropout(0.1) self.softmax = nn.LogSoftmax(dim=1) def forward(self, category, input, hidden): input_combined = torch.cat((category, input, hidden), 1) hidden = self.i2h(input_combined) output = self.i2o(input_combined) output_combined = torch.cat((hidden, output), 1) output = self.o2o(output_combined) output = self.dropout(output) output = self.softmax(output) return output, hidden def initHidden(self): return torch.zeros(1, self.hidden_size)
そんでもって動かすとこんな感じです。
samples('Russian', 'RUS') samples('German', 'GER') samples('Spanish', 'SPA') samples('Chinese', 'CHI')
Rakiskin Uovako Shakovek Gerten Eerten Roune Santan Paran Allan Chan Han Iuan
ちゃんと動いていますね。 pytorchだとこんな感じに使えるみたいです。
感想
やっぱり難しいですね。 RNN自体は今後なにかに使うかもしれないんで、少しずつ使えるようにしたいものです。
あと全然関係無いんですけど、pytorch 1.0版が出てたんですね。 全然知らなかったっす。