Dataset abstraction

Chainer supports a common interface for training and validation datasets. The dataset support consists of three components: datasets, iterators, and batch conversion functions.

Dataset represents a set of examples. The interface is determined only in combination with the iterators you want to use on it. The built-in iterators of Chainer require the dataset to support the __getitem__ and __len__ methods. In particular, the __getitem__ method should support indexing by both an integer and a slice. We can easily support slice indexing by inheriting DatasetMixin, in which case users only have to implement the get_example() method for indexing. Some iterators also restrict the type of each example. Basically, datasets are considered stateless objects, so we do not need to save the dataset as part of a checkpoint of the training procedure.
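
For illustration, here is a minimal sketch of a dataset that satisfies this interface without any Chainer base class; the class name and data are made up for this example:

    import numpy

    class ArrayDataset(object):
        """Hypothetical minimal dataset wrapping an in-memory array."""

        def __init__(self, data):
            self._data = data

        def __getitem__(self, index):
            # NumPy arrays already support both integer and slice indexing,
            # which is what the built-in iterators require.
            return self._data[index]

        def __len__(self):
            return len(self._data)

    dataset = ArrayDataset(numpy.arange(10, dtype=numpy.float32))
    print(dataset[2])    # a single example
    print(dataset[1:4])  # a slice of examples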

Iterator iterates over the dataset, and at each iteration it yields a minibatch of examples as a list. Iterators should support the Iterator interface, which includes the standard iterator protocol of Python. Iterators manage where to read next, which means they are stateful.

Batch conversion function converts a minibatch into arrays to feed to the neural nets. It is also responsible for sending each array to an appropriate device. Chainer currently provides concat_examples() as the only example of batch conversion functions.

These components are all customizable, and are designed with a minimal interface so as to restrict the types of datasets and the ways to handle them as little as possible. In most cases, though, the implementations provided by Chainer itself are enough to cover typical usages.

Chainer also has a lightweight system to download, manage, and cache concrete datasets. All datasets managed through this system are saved under the dataset root directory, which is determined by the CHAINER_DATASET_ROOT environment variable and can also be set by the set_dataset_root() function.

Dataset representation

For implementations of datasets, see the dataset examples.

class chainer.dataset.DatasetMixin[source]

Default implementation of dataset indexing.

DatasetMixin provides the __getitem__() operator. The default implementation uses get_example() to extract each example, and, for slice indexing, combines the results into a list. This mixin makes it easy to implement a new dataset that does not support efficient slicing.

A dataset implementation using DatasetMixin still has to provide the __len__() method explicitly.

get_example(i)[source]

Returns the i-th example.

Implementations should override it. It should raise IndexError if the index is invalid.

Parameters: i (int) – The index of the example.
Returns: The i-th example.
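
As a sketch of how DatasetMixin is typically used (the dataset below is hypothetical, not part of Chainer):

    import numpy
    import chainer

    class SquaresDataset(chainer.dataset.DatasetMixin):
        """Hypothetical dataset whose i-th example is the pair (i, i**2)."""

        def __init__(self, n):
            self._n = n

        def __len__(self):
            # DatasetMixin does not provide __len__; we must define it.
            return self._n

        def get_example(self, i):
            if not 0 <= i < self._n:
                raise IndexError('index out of range')
            x = numpy.float32(i)
            return x, x * x

    dataset = SquaresDataset(100)
    print(dataset[3])    # one example, via get_example()
    print(dataset[0:3])  # a list of examples; slicing is handled by the mixin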

Iterator interface

For implementations of dataset iterators, see the iterator examples.

class chainer.dataset.Iterator[source]

Base class of all dataset iterators.

Iterator iterates over the dataset, yielding a minibatch at each iteration. A minibatch is a list of examples. Each implementation should implement the iterator protocol (e.g., the __next__() method).

Note that, even if the iterator supports setting the batch size, it does not guarantee that each batch always contains the same number of examples. For example, if you let the iterator stop at the end of a sweep, the last batch may contain fewer examples.

The interface between the iterator and the underlying dataset is not fixed, and is up to the implementation.

Each implementation should provide the following attributes (they need not be writable).

  • batch_size: Number of examples within each minibatch.
  • epoch: Number of completed sweeps over the dataset.
  • epoch_detail: Floating point number version of the epoch. For example, if the iterator is at the middle of the dataset at the third epoch, then this value is 2.5.
  • is_new_epoch: True if the epoch count was incremented at the last update.

Each implementation should also support serialization to resume/suspend the iteration.
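
For concreteness, the following is a minimal sketch of such an implementation: a hypothetical sequential iterator that sweeps the dataset in order forever, with no shuffling (Chainer's built-in iterators provide these features and more):

    import chainer

    class SequentialIterator(chainer.dataset.Iterator):
        """Hypothetical iterator that sweeps the dataset in order, forever."""

        def __init__(self, dataset, batch_size):
            self.dataset = dataset
            self.batch_size = batch_size
            self.current_position = 0
            self.epoch = 0
            self.is_new_epoch = False

        def __next__(self):
            i = self.current_position
            i_end = min(i + self.batch_size, len(self.dataset))
            batch = [self.dataset[k] for k in range(i, i_end)]
            if i_end >= len(self.dataset):
                # End of a sweep: wrap around and bump the epoch counter.
                self.current_position = 0
                self.epoch += 1
                self.is_new_epoch = True
            else:
                self.current_position = i_end
                self.is_new_epoch = False
            return batch

        @property
        def epoch_detail(self):
            return self.epoch + float(self.current_position) / len(self.dataset)

        def serialize(self, serializer):
            # Only the iteration state is saved; user-set values such as
            # the batch size are not serialized (see the note below).
            self.current_position = serializer(
                'current_position', self.current_position)
            self.epoch = serializer('epoch', self.epoch)
            self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)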

finalize()[source]

Finalizes the iterator and possibly releases the resources.

This method does nothing by default. Implementations may override it to better handle internal resources.

next()[source]

Python 2 alternative of __next__().

It calls __next__() by default.

serialize(serializer)[source]

Serializes the internal state of the iterator.

This is a method to support the serializer protocol of Chainer.

Note

It should only serialize the internal state that changes over the iteration. It should not serialize anything that is set manually by the user, such as the batch size.

Batch conversion function

chainer.dataset.concat_examples(batch, device=None, padding=None)[source]

Concatenates a list of examples into array(s).

A dataset iterator yields a list of examples. If each example is an array, this function concatenates them along the newly inserted first axis (called the batch dimension) into one array. The basic behavior is the same for examples consisting of multiple arrays, i.e., the corresponding arrays of all examples are concatenated.

For instance, suppose each example consists of two arrays (x, y). Then, this function concatenates the x's into one array and the y's into another array, and returns a tuple of these two arrays. As another example, suppose each example is a dictionary with two entries whose keys are 'x' and 'y', respectively, and whose values are arrays. Then, this function concatenates the x's into one array and the y's into another, and returns a dictionary with two entries 'x' and 'y' whose values are the concatenated arrays.

When the arrays to concatenate have different shapes, the behavior depends on the padding value. If padding is None (default), it raises an error. Otherwise, it builds an array of the minimum shape into which the contents of all the arrays fit. The padding value is then used for the extra elements of the resulting arrays.

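As a usage sketch (the data here is made up), concatenating tuple examples and padding mismatched shapes work as follows:

    import numpy
    from chainer.dataset import concat_examples

    # Two (x, y) examples: x is a feature vector, y is a 0-d label array.
    batch = [
        (numpy.zeros(3, dtype=numpy.float32), numpy.array(0, dtype=numpy.int32)),
        (numpy.ones(3, dtype=numpy.float32), numpy.array(1, dtype=numpy.int32)),
    ]
    x, y = concat_examples(batch)
    print(x.shape)  # (2, 3): stacked along the new batch dimension
    print(y.shape)  # (2,)

    # Arrays of different shapes require a padding value.
    ragged = [numpy.zeros((2, 2), dtype=numpy.float32),
              numpy.zeros((3, 1), dtype=numpy.float32)]
    padded = concat_examples(ragged, padding=0)
    print(padded.shape)  # (2, 3, 2): minimum covering shape, padded with 0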

Parameters:
  • batch (list) – A list of examples. This is typically given by a dataset iterator.
  • device (int) – Device ID to which each array is sent. A negative value indicates the host memory (CPU). If it is omitted, all arrays are left on their original device.
  • padding – Scalar value for extra elements. If this is None (default), an error is raised on shape mismatch. Otherwise, an array of the minimum shape that can accommodate all the arrays is created, and elements outside the examples are padded with this value.
Returns:

Array, a tuple of arrays, or a dictionary of arrays. The type depends on the type of each example in the batch.

Dataset management

chainer.dataset.get_dataset_root()[source]

Gets the path to the root directory in which datasets are downloaded and cached.

Returns: The path to the dataset root directory.
Return type: str
chainer.dataset.set_dataset_root(path)[source]

Sets the root directory in which datasets are downloaded and cached.

There are two ways to set the dataset root directory. One is by setting the CHAINER_DATASET_ROOT environment variable. The other is by using this function. If both are specified, the one specified via this function is used. The default dataset root is $HOME/.chainer/dataset.

Parameters: path (str) – Path to the new dataset root directory.
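
A minimal usage sketch (the directory below is a made-up example and must be writable):

    import chainer

    chainer.dataset.set_dataset_root('/data/chainer')
    print(chainer.dataset.get_dataset_root())  # '/data/chainer'

Equivalently, the same directory can be chosen before launching Python by setting CHAINER_DATASET_ROOT in the environment.
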
chainer.dataset.cached_download(url)[source]

Downloads a file and caches it.

It downloads a file from the URL if there is no corresponding cache. After the download, this function stores the cache in a directory under the dataset root (see set_dataset_root()). If there is already a cache for the given URL, it just returns the path to the cache without downloading the same file again.

Parameters: url (str) – URL to download from.
Returns: Path to the downloaded file.
Return type: str
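
A usage sketch (the URL is a placeholder; any downloadable file works):

    from chainer.dataset import cached_download

    # The first call downloads the file; subsequent calls with the same
    # URL return the cached path without re-downloading.
    path = cached_download('https://example.com/data/train.txt')
    print(path)  # a path under the dataset root
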
chainer.dataset.cache_or_load_file(path, creator, loader)[source]

Caches a file if it does not exist, or loads it otherwise.

This is a utility function used in dataset loading routines. The creator creates the file at the given path and returns the content. If the file already exists, the loader is called instead, and it loads the file and returns the content.

Note that the path passed to the creator is a temporary one, and is not the same as the path given to this function. This function safely renames the file created by the creator to the given path, even if this function is called simultaneously by multiple threads or processes.

Parameters:
  • path (str) – Path to save the cached file.
  • creator – Function that creates the file and returns the content. It takes a path to a temporary place as the argument. Before the creator is called, there is no file at the temporary path.
  • loader – Function that loads the cached file and returns the content.
Returns:

The content returned by the creator or by the loader.
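
As a sketch, the following caches a preprocessed NumPy array under the dataset root; the dataset name and contents are made up, and creating the cache directory beforehand is an assumption of this example:

    import os
    import numpy
    from chainer.dataset import cache_or_load_file, get_dataset_root

    def creator(tmp_path):
        # Called only when the cache does not exist yet. It writes to a
        # temporary path, which is then safely renamed to the given path.
        data = numpy.arange(10, dtype=numpy.float32)
        with open(tmp_path, 'wb') as f:
            numpy.save(f, data)
        return data

    def loader(path):
        # Called when the cache already exists.
        return numpy.load(path)

    cache_dir = os.path.join(get_dataset_root(), 'my_dataset')
    if not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)  # the target directory must exist
    data = cache_or_load_file(
        os.path.join(cache_dir, 'data.npy'), creator, loader)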