リカレントニューラルネットワーク(RNN)とそれらの計算グラフ

このセクションでは以下の書き方を学びます。

  • Full Backpropagationでのリカレントニューラルネットワーク

  • Truncated Backpropagationでのリカレントニューラルネットワーク

  • 少ないメモリでのネットワークの評価

このセクションを読んだら、以下のことができるようになります:

  • 可変長の入力シーケンスの処理

  • Forward計算中にネットワークの上位ストリームを切り捨てる

  • volatile変数を使用してネットワーク構築を防止する

リカレントニューラルネットワーク(RNN)

リカレントニューラルネットワークはループ構造をもつニューラルネットワークです。主にシーケンスデータの入出力から学習させるために使われます。入力ストリーム \(x_1, x_2, \dots, x_t, \dots\) と初期状態 \(h_0\) を与えると、リカレントニューラルネットワークは  \(h_t = f(x_t, h_{t-1})\) によってその状態を繰り返し更新します。そして、ある間隔または毎時刻に \(y_t = g(h_t)\) を出力します。この手順を時間軸に沿って展開すると、同じパラメータがネットワーク内で繰り返し使用されることを除いて、通常のフィードフォワードニューラルネットワークのように見えます。

ここでは単純な一層のリカレントニューラルネットワークの書き方を学びます。課題は言語モデリングです。有限の単語の列が与えられたとき、連続する単語を覗き込むことなく、各位置で次の単語を予測したいと考えてみましょう。1000種類の異なる単語タイプがあり、各単語(Word embedding)を表すために100次元の実ベクトルを使用すると仮定します。

リカレントニューラルネットワークの言語モデル(RNNLM)をChainとして定義することから始めましょう。Fully ConnectedのステートフルなLSTM層を実装する chainer.links.LSTM Linkを使用できます。このLinkは通常のFully Connected Layerのように見えます。構築時には、入力と出力のサイズをコンストラクタに渡します。

>>> l = L.LSTM(100, 50)

次に、このインスタンス l(x) を呼び出すと、 LSTM層の1ステップ が実行されます。

>>> l.reset_state()
>>> x = Variable(np.random.randn(10, 100).astype(np.float32))
>>> y = l(x)

Forward計算の前にLSTM層の内部状態をリセットすることを忘れないでください!すべてのリカレント層は、その内部状態(すなわち、前の呼び出しの出力)を保持する。リカレント層の最初のアプリケーションでは、内部状態をリセットする必要があります。そうすると、次の入力をLSTMインスタンスに直接入力することができます。

>>> x2 = Variable(np.random.randn(10, 100).astype(np.float32))
>>> y2 = l(x2)

このLSTM Linkに基づいて、新しいChainとしてリカレントネットワークを記述しましょう。:

class RNN(Chain):
    def __init__(self):
        super(RNN, self).__init__(
            embed=L.EmbedID(1000, 100),  # word embedding
            mid=L.LSTM(100, 50),  # the first LSTM layer
            out=L.Linear(50, 1000),  # the feed-forward output layer
        )

    def reset_state(self):
        self.mid.reset_state()

    def __call__(self, cur_word):
        # Given the current word ID, predict the next word.
        x = self.embed(cur_word)
        h = self.mid(x)
        y = self.out(h)
        return y

rnn = RNN()
model = L.Classifier(rnn)
optimizer = optimizers.SGD()
optimizer.setup(model)

ここで EmbedID はWord embeddingのLinkです。入力整数を対応する固定次元埋め込みベクトルに変換します。最後の線形Link out は、feed-forwardの出力層を表す。

RNN Chainは one-step-forwardド計算 を実装します。シーケンスはそれ自体では処理されませんが、シーケンス内のアイテムをChainに直接送るだけで、シーケンスを処理することができます。

単語変数 x_list のリストがあるとします。簡単な for ループによって単語シーケンスの損失値を計算することができます。

def compute_loss(x_list):
    loss = 0
    for cur_word, next_word in zip(x_list, x_list[1:]):
        loss += model(cur_word, next_word)
    return loss

もちろん、累積された損失は計算の全ての履歴を持つVariableオブジェクトです。したがって、 backward() メソッドを呼び出すと、モデルパラメータに対応する全ての損失の勾配を計算できます。

# Suppose we have a list of word variables x_list.
rnn.reset_state()
model.cleargrads()
loss = compute_loss(x_list)
loss.backward()
optimizer.update()

同等に、 compute_loss を損失関数として使用することもできます。

rnn.reset_state()
optimizer.update(compute_loss, x_list)

Unchainingでグラフを切り捨てる

非常に長いシーケンスから学習することは、リカレントニューラルネットワークの典型的なユースケースです。入力シーケンスと状態シーケンスが長すぎてメモリに収まらないと仮定します。そのような場合、しばしばバックプロパゲーションを短い時間範囲に切り捨てることをします。このテクニックは truncated backpropagation と呼ばれます。これはヒューリスティックであり、勾配にバイアスがかかります。しかしながら、このテクニックは時間範囲が長過ぎる場合、実際にはうまく動作します。

Chainerでtruncated backpropagationをどのように実装するのでしょうか?Chainerには、backward unchaining と呼ばれるtruncated backpropagationを実現するスマートなメカニズムがあります。これは Variable.unchain_backward() メソッドに実装されています。Backward Unchainingは、Variableオブジェクトから始まり、変数から計算履歴を断ち切ります。断ち切られた変数は自動的に削除されます(他のオブジェクトからの明示的な参照がない場合)。結果として、それらはもはや計算履歴の一部ではなくなり、もはやbackpropagationに関与しなくなります。

truncated backpropagationの例を書いていきましょう。ここでは前のサブセクションでしたものと同じネットワークを使用します。非常に長いシーケンスが与えられていると仮定し、30回のステップごとに切り捨てるバックプロパゲーションを実行したいとします。上記で定義したモデルを使ってtruncated backpropagationを書くことができます。:

loss = 0
count = 0
seqlen = len(x_list[1:])

rnn.reset_state()
for cur_word, next_word in zip(x_list, x_list[1:]):
    loss += model(cur_word, next_word)
    count += 1
    if count % 30 == 0 or count == seqlen:
        model.cleargrads()
        loss.backward()
        loss.unchain_backward()
        optimizer.update()

状態は model() によって更新され、損失は loss 変数に蓄積されます。30ステップごとに、蓄積された損失によってバックプロパゲーションが行われます。それから、 unchain_backward() メソッドが呼び出され、蓄積された損失から計算履歴が削除されます。model の最後の状態は、RNNインスタンスが参照を保持しているので失われないことに注意してください。

truncated backpropagationの実装は単純で、複雑なトリックがないので、このメソッドをさまざまな状況に一般化できます。たとえば、上記のコードを簡単に拡張して、backpropagationのタイミングと切り捨ての長さの間で異なるスケジュールを使用することができます。

計算履歴を保存しないネットワーク評価

リカレントニューラルネットワークの評価では、通常、計算履歴を保存する必要はありません。unchainingが制限のあるメモリにおいて無限の長さのシーケンスを使うことを可能にしますが、これは回避策です。

代わりに、Chainerは計算履歴を保存しないForward計算の評価モードを提供します。これはすべての入力変数に volatile フラグを渡すだけで有効になります。このような変数は volatile変数 と呼ばれます。

Volatile変数は、構築時に volatile=’on’` を渡すことによって作成されます。

x_list = [Variable(..., volatile='on') for _ in range(100)]  # list of 100 words
loss = compute_loss(x_list)

volatile変数は計算履歴を記憶していないので、ここでは勾配を計算するための loss.backward() を呼び出すことができないことに注意して下さい。

Volatile変数は、メモリ使用量を減らすためにfeed-forwardネットワークを評価するのにも役立ちます。

volatile変数は、 Variable.volatile 属性を設定することによって直接変更することができます。これにより、学習済みの特徴抽出ネットワークと学習可能な予測ネットワークを組み合わせることができます。たとえば、別の学習済みネットワーク fixed_func の上位に位置するfeed-forwardネットワーク predictor_func を訓練したいと仮定します。fixed_func の計算履歴を保存せずに predictor_func を訓練したいと思います。これは簡単に次のコードスニペットによって実現出来ます。( x_data と y_data がそれぞれ入力データとラベルを示すと仮定します)。

x = Variable(x_data, volatile='on')
feat = fixed_func(x)
feat.volatile = 'off'
y = predictor_func(feat)
y.backward()

最初は、入力変数 x はvolatileであるため、 fixed_func は、volatileモードで実行されるため、計算履歴を記憶されません。中間変数 feat は手動でvolatileをオフに設定されるので、predictor_func はnon-volatileモードで計算履歴を記憶しながら実行されます。計算の履歴は変数 featy の間にのみ記憶されるため、backward計算は feat 変数で停止します。

警告

同じ関数の引数としてvolatile変数とnon-volatile変数を混在させることはできません。non-volatile変数のように振る舞い、かつvolatile変数と混在する変数を作成したい場合は、 'off' フラグの代わりに 'auto' フラグを使います。

Trainerで作る

上記のコードは、プレーンな関数/変数APIで記述されています。訓練ループを書くときは、拡張機能で機能を簡単に追加できるTrainerを使用する方が良いです。

Trainerに実装する前に、訓練設定を明確にしましょう。ここでは、Penn Tree Bankデータセットを文章のデータセットとして使用します。各文章は単語列として表されます。すべての文章を1つの長い単語列に連結し、各文章は “End of Sequence”を表す特別な単語 <eos> で区切られています。このデータセットは chainer.datasets.get_ptb_words() で簡単に取得できます。この関数はtrain、validation、およびtestデータセットを返します。これらのデータセットはそれぞれ長い整数配列として表されます。各整数は単語IDを表します。

私たちの仕事は、長い単語列からリカレントニューラルネットワークの言語モデルを学習することです。我々は、mini-batchesから異なる場所にある単語を使用する。これは、シーケンス内の異なる位置を指す \(B\) インデックスを維持し、各イテレーションでこれらのインデックスから読み取り、読み取り後にすべてのインデックスをインクリメントすることを意味します。もちろん、1つのインデックスがシーケンス全体の終わりに達すると、インデックスを0に戻します。

この訓練手順を実装するには、Trainerの次のコンポーネントをカスタマイズする必要があります。

  • ビルトインイテレータは、異なる場所からの読み込みとmini-batchへの集約をサポートしていません。

  • デフォルトの更新機能はtruncated BPTTをサポートしていません。

データセット専用のデータセットイテレータを記述する場合、インターフェースが固定されていなくてもデータセットの実装は任意です。一方、イテレータは Iterator インターフェースをサポートしなければなりません。実装する重要なメソッドと属性は、 batch_sizeepochepoch_detailis_new_epochiteration__next__ 、および serialize です。以下は examples/ptb ディレクトリの公式のコードです。

from __future__ import division

class ParallelSequentialIterator(chainer.dataset.Iterator):
    def __init__(self, dataset, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epoch = 0
        self.is_new_epoch = False
        self.repeat = repeat
        self.offsets = [i * len(dataset) // batch_size for i in range(batch_size)]
        self.iteration = 0

    def __next__(self):
        length = len(self.dataset)
        if not self.repeat and self.iteration * self.batch_size >= length:
            raise StopIteration
        cur_words = self.get_words()
        self.iteration += 1
        next_words = self.get_words()

        epoch = self.iteration * self.batch_size // length
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch

        return list(zip(cur_words, next_words))

    @property
    def epoch_detail(self):
        return self.iteration * self.batch_size / len(self.dataset)

    def get_words(self):
        return [self.dataset[(offset + self.iteration) % len(self.dataset)]
                for offset in self.offsets]

    def serialize(self, serializer):
        self.iteration = serializer('iteration', self.iteration)
        self.epoch = serializer('epoch', self.epoch)

train_iter = ParallelSequentialIterator(train, 20)
val_iter = ParallelSequentialIterator(val, 1, repeat=False)

コードはやや長いですが、そのアイデアは簡単です。このイテレータは、シーケンス全体に均等に配置された位置を示す offsets を作成します。mini-batchesのi番目の例は、i番目のオフセットでシーケンスを参照します。イテレータは、現在の単語と次の単語のタプルのリストを返します。各mini-batchは、標準アップデーターの concat_examples 関数(前のチュートリアルを参照)によって整数配列のタプルに変換されます。

Back Propagation Through Time (BPTT)は次のように実装されています。

def update_bptt(updater):
    loss = 0
    for i in range(35):
        batch = train_iter.__next__()
        x, t = chainer.dataset.concat_example(batch)
        loss += model(chainer.Variable(x), chainer.Variable(t))

    model.cleargrads()
    loss.backward()
    loss.unchain_backward()  # truncate
    optimizer.update()

updater = training.StandardUpdater(train_iter, optimizer, update_bptt)

この場合、連続する35単語ごとにパラメータを更新します。unchain_backward の呼び出しは、LSTM Linkに蓄積された計算の履歴を切り捨てます。Trainerを設定するコードの残りの部分は、前のチュートリアルで示したものとほぼ同じです。


このセクションでは、Chainerでリカレントニューラルネットワークを書く方法と、計算の履歴を管理するためのいくつかの基本的なテクニック(つまり計算グラフ)を示しました。examples/ptb ディレクトリの例は、Penn TreebankコーパスからのLSTM言語モデルのTruncated Propagation学習を実装しています。次のセクションでは、ChainerでGPUを使用する方法を見ていきます。