この前は基本的なRNNの仕組みについて勉強していました。
今回は、現在RNNの中でも代表的なモデルの一つであるLSTMについて勉強します。
今回も参考にしたのはこちらの本です。
ゼロから作るDeep Learning ? ―自然言語処理編
- 作者: 斎藤康毅
- 出版社/メーカー: オライリージャパン
- 発売日: 2018/07/21
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (3件) を見る
毎回のことながら、今回も非常にわかりやすかったです。
今回もpytorchを使って、楽して実装を眺めながら勉強していきます。
LSTM
勾配消失/勾配爆発
単純なRNNの問題点として、勾配消失/勾配爆発があります。 RNNレイヤの中で、時系列方向の逆伝播を考えます。
上の図より、時系列方向の距離が大きい要素ほど、多くの演算を通過することがわかります。 細かい説明は教科書に譲ります。(正直あまり良くわかってないです)
あとは、この辺も参考になるかもしれないです。
この時、時系列的に遠いところの入力は学習にほとんど影響を与えない、あるいは他の要素に対して大きすぎる影響を持つようになってしまいます。
単語の推定を行う場合などには、文脈を考慮するためにある程度の時系列範囲を考慮する必要があります。 しかし、単純なRNNでは、時系列を適切に活用することができないという問題があります。。
LSTMの仕組み
そこで考案されたのがLSTM (Long Short Term Memory) です。 前回使ったRNNレイヤの部分を下のようなLSTMレイヤに置き換えます。
このようにすることで、時系列方向に関して勾配消失/勾配爆発することを回避します。 ざっくりいえば、時系列方向に残す信号と忘れる信号を管理することで、長期記憶を可能にしています。 なんでこうなるのかは、教科書読んでください。どの部分が何を表してあるかまで説明されています。 その他、最適化とかも紹介されています。
何はともあれ、この形にすれば勾配消失/勾配爆発を回避できるLSTMの出来上がりです。
Pytorchの実装
今回はこちらのチュートリアルをなぞってみます。
環境構築
省略。どうせ前回と一緒なので。
チュートリアルを眺める
LSTMのモデルのコードはこんな感じです。
class LSTMTagger(nn.Module): def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size): super(LSTMTagger, self).__init__() self.hidden_dim = hidden_dim self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) # The LSTM takes word embeddings as inputs, and outputs hidden states # with dimensionality hidden_dim. self.lstm = nn.LSTM(embedding_dim, hidden_dim) # The linear layer that maps from hidden state space to tag space self.hidden2tag = nn.Linear(hidden_dim, tagset_size) self.hidden = self.init_hidden() def init_hidden(self): # Before we've done anything, we dont have any hidden state. # Refer to the Pytorch documentation to see exactly # why they have this dimensionality. # The axes semantics are (num_layers, minibatch_size, hidden_dim) return (torch.zeros(1, 1, self.hidden_dim), torch.zeros(1, 1, self.hidden_dim)) def forward(self, sentence): embeds = self.word_embeddings(sentence) lstm_out, self.hidden = self.lstm( embeds.view(len(sentence), 1, -1), self.hidden) tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1)) tag_scores = F.log_softmax(tag_space, dim=1) return tag_scores
なんか上の図のLSTMレイヤはすでに関数が用意されているんですね。 なのでパラメータを設定して後続の活性化関数を仕込むだけで使えるんですね。 超便利じゃないですか!
感想
以前、LSTMを使って異常検知をやっていた方がいたのを思い出しました。 この辺りまで来ると、結構実用的なレベルで使えるようになるみたいですね。