Training loop abstractions

Chainer provides a standard implementation of the training loop under the chainer.training module. It is built on top of many other core features of Chainer, including Variable and Function, Link/Chain/ChainList, Optimizer, Dataset, and Reporter/Summary. Compared to the training loop abstractions of other machine learning toolkits, Chainer’s training framework aims at maximal flexibility while keeping typical usage simple. Most components are pluggable, and users can override their definitions.

The core of the training loop abstraction is Trainer, which implements the training loop itself. The training loop consists of two parts: one is Updater, which actually updates the parameters to train, and the other is Extension for arbitrary functionalities other than the parameter update.

Updater and some extensions use a dataset and an Iterator to scan the datasets and load mini-batches. The trainer also uses Reporter to collect the observed values, and some extensions use DictSummary to accumulate them and compute statistics.

You can find many examples of how to use these training utilities in the official examples. You can also look up the extension implementations in Trainer extensions.

Trainer

class chainer.training.Trainer(updater, stop_trigger=None, out='result')

The standard training loop in Chainer.

Trainer is an implementation of a training loop. Users can invoke the training by calling the run() method.

Each iteration of the training loop proceeds as follows.

  • Update of the parameters. It includes the mini-batch loading, forward and backward computations, and an execution of the update formula. These are all done by the updater object held by the trainer.
  • Invocation of trainer extensions in the descending order of their priorities. A trigger object is attached to each extension, and it decides at each iteration whether the extension should be executed. Trigger objects are callable objects that take the trainer object as the argument and return a boolean value indicating whether the extension should be called or not.
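The two steps above can be sketched in plain Python. This is a toy model, not Chainer's implementation: ToyUpdater, ToyTrainer, and every() are hypothetical names, and the priority numbers are arbitrary illustrative values.

```python
# Toy model of a Trainer-style loop; plain Python, no Chainer imports.

class ToyUpdater:
    """Hypothetical stand-in for an updater: it only counts updates."""

    def __init__(self):
        self.iteration = 0

    def update(self):
        # A real updater would load a mini-batch, run the forward and
        # backward computations, and apply the update formula here.
        self.iteration += 1


class ToyTrainer:
    """Hypothetical stand-in for a trainer; not chainer.training.Trainer."""

    def __init__(self, updater, stop_iteration):
        self.updater = updater
        self.stop_iteration = stop_iteration
        self.entries = []  # (priority, registration index, trigger, extension)

    def extend(self, extension, trigger, priority):
        self.entries.append((priority, len(self.entries), trigger, extension))

    def run(self):
        # Descending priority; ties keep their registration order.
        ordered = sorted(self.entries, key=lambda e: (-e[0], e[1]))
        while self.updater.iteration < self.stop_iteration:
            self.updater.update()              # step 1: parameter update
            for _, _, trigger, extension in ordered:
                if trigger(self):              # step 2: ask the trigger
                    extension(self)


def every(n):
    """Build a trigger that fires on every n-th iteration."""
    return lambda trainer: trainer.updater.iteration % n == 0


calls = []
trainer = ToyTrainer(ToyUpdater(), stop_iteration=4)
trainer.extend(lambda t: calls.append(('low', t.updater.iteration)),
               trigger=every(2), priority=100)
trainer.extend(lambda t: calls.append(('high', t.updater.iteration)),
               trigger=every(1), priority=300)
trainer.run()
print(calls)
```

The extension registered second runs first in each iteration because its priority is higher, and the lower-priority one is skipped on odd iterations because its trigger returns False.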

Extensions are callable objects that take the trainer object as the argument. There are two ways to define custom extensions: inheriting from the Extension class, or decorating functions with make_extension(). See Extension for more details on custom extensions.

Users can register extensions with the trainer by calling the extend() method, which also accepts the following configuration.

  • Trigger object, which is also explained above. In most cases, IntervalTrigger is used, in which case users can simply specify a tuple of the interval length and its unit, like (1000, 'iteration') or (1, 'epoch').
  • The order of execution of extensions is determined by their priorities. Extensions of higher priorities are invoked earlier. There are three standard values for the priorities:
    • PRIORITY_WRITER. This is the priority for extensions that write records to the observation dictionary. It covers both extensions that add values to the observation dictionary directly and extensions that use the chainer.report() function to report values to it.
    • PRIORITY_EDITOR. This is the priority for extensions that edit the observation dictionary based on already reported values.
    • PRIORITY_READER. This is the priority for extensions that only read records from the observation dictionary. This is also suitable for extensions that do not use the observation dictionary at all.
  • Extensions whose invoke_before_training flag is set are also invoked at the beginning of the training loop. Extensions that update the training status (e.g., changing learning rates) should set this flag to True so that resuming the training loop correctly recovers the training status.
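The role of the three priority levels can be illustrated with a minimal sketch. The numeric values and the three toy extensions are assumptions chosen for illustration, and the extensions here receive the observation dictionary directly, whereas real extensions receive the trainer object.

```python
# Illustrative priority values (assumed); only their relative order matters.
PRIORITY_WRITER, PRIORITY_EDITOR, PRIORITY_READER = 300, 200, 100

observation = {}
log = []


def writer(obs):
    # Writes raw records to the observation dictionary.
    obs['loss_sum'] = 6.0
    obs['count'] = 3


def editor(obs):
    # Derives a new value from what the writer already reported.
    obs['loss_mean'] = obs['loss_sum'] / obs['count']


def reader(obs):
    # Only reads the finished observation.
    log.append(obs['loss_mean'])


# Registered in the "wrong" order on purpose.
registered = [(PRIORITY_READER, reader),
              (PRIORITY_WRITER, writer),
              (PRIORITY_EDITOR, editor)]

# Higher priority runs earlier, so the order is writer, editor, reader.
for _, extension in sorted(registered, key=lambda pair: -pair[0]):
    extension(observation)

print(log)
```

If the reader ran before the editor, the 'loss_mean' key would not exist yet, which is why the priorities are ordered writer, editor, reader.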

The current state of the trainer object and objects handled by the trainer can be serialized through the standard serialization protocol of Chainer. It enables us to easily suspend and resume the training loop.

Note

Serialization does not recover everything in the training loop. It only recovers the states that change during training (e.g. parameters, optimizer states, the batch iterator state, extension states, etc.). You must initialize the objects correctly before deserializing the states.

On the other hand, this means that users can change the settings at deserialization time. For example, the exit condition can be changed, so users can train the model for some iterations, suspend it, and then resume it with a larger total number of iterations.
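The suspend-and-resume pattern described in this note can be sketched without Chainer. In this toy version the "objects" are a plain dictionary, the snapshot is a JSON string, and make_state() and train() are hypothetical helpers; Chainer's own serializers are not used.

```python
import json


def make_state():
    # The objects must be constructed first; loading a snapshot only
    # fills in the values that change during training.
    return {'iteration': 0, 'weight': 0.0}


def train(state, stop_iteration, lr=0.1):
    # Toy update rule: move the weight toward 1.0 at every iteration.
    while state['iteration'] < stop_iteration:
        state['weight'] += lr * (1.0 - state['weight'])
        state['iteration'] += 1
    return state


# First session: train for 3 iterations, then take a snapshot.
state = train(make_state(), stop_iteration=3)
snapshot = json.dumps(state)

# Second session: rebuild the objects, load the snapshot, and continue
# with a larger stop condition than the first session used.
resumed = make_state()
resumed.update(json.loads(snapshot))
resumed = train(resumed, stop_iteration=5)

# Reference run that was never interrupted.
straight = train(make_state(), stop_iteration=5)
print(resumed['iteration'], resumed == straight)
```

The stop condition is not part of the snapshot, so the second session is free to pick a different one.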

During the training, it also creates a Reporter object to store observed values on each update. For each iteration, it creates a fresh observation dictionary and stores it in the observation attribute.

Links of the target model of each optimizer are registered to the reporter object as observers, where the name of each observer is constructed in the format <optimizer name><link name>. The link name is given by the chainer.Link.namedlinks() method, which represents the path to each link in the hierarchy. Other observers can be registered by accessing the reporter object via the reporter attribute.

The default trainer is plain, i.e., it does not contain any extensions.

Parameters:
  • updater (Updater) – Updater object. It defines how to update the models.
  • stop_trigger – Trigger that determines when to stop the training loop. If it is not callable, it is passed to IntervalTrigger.
Variables:
  • updater – The updater object for this trainer.
  • stop_trigger – Trigger that determines when to stop the training loop. The training loop stops at the iteration on which this trigger returns True.
  • observation – Observation of values made at the last update. See the Reporter class for details.
  • out – Output directory.
  • reporter – Reporter object to report observed values.
elapsed_time

Total time used for the training.

The time is in seconds. If training is resumed from a snapshot, it includes the time spent in all previous training runs that led to the current state of the trainer.

extend(extension, name=None, trigger=None, priority=None, invoke_before_training=None)

Registers an extension to the trainer.

An extension is a callable object that is called after each update unless the corresponding trigger object decides to skip the iteration. The order of execution is determined by priorities: extensions with higher priorities are called earlier in each iteration. Extensions with the same priority are invoked in the order of registration.

If two or more extensions are registered with the same name, a suffix is added to the name of every such extension after the first. The suffix is _N, where N is the ordinal of the extension.
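A minimal sketch of such name de-duplication follows. The register() helper is hypothetical, and the assumption that the first duplicate receives the suffix _1 is mine; the text above only states that the suffix has the form _N.

```python
def register(extensions, name, extension):
    """Store extension under name, appending _N on a name clash.

    Assumption: the counter starts at 1 for the first duplicate.
    """
    ordinal = 0
    modified_name = name
    while modified_name in extensions:
        ordinal += 1
        modified_name = '%s_%d' % (name, ordinal)
    extensions[modified_name] = extension
    return modified_name


registry = {}
names = [register(registry, 'LogReport', object()) for _ in range(3)]
print(names)
```

Because the final name may differ from the requested one, code that later looks an extension up by name should use the name that was actually stored.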

See Extension for the interface of extensions.

Parameters:
  • extension – Extension to register.
  • name (str) – Name of the extension. If it is omitted, the default_name attribute of the extension is used instead. Note that the name is suffixed with an ordinal in case of duplicated names, as explained above.
  • trigger (tuple or Trigger) – Trigger object that determines when to invoke the extension. If it is None, extension.trigger is used instead. If the trigger is not callable, it is passed to IntervalTrigger to build an interval trigger.
  • priority (int) – Invocation priority of the extension. Extensions are invoked in the descending order of priorities in each iteration. If this is None, extension.priority is used instead.
  • invoke_before_training (bool or None) – If True, the extension is also invoked just before entering the training loop. If this is None, extension.invoke_before_training is used instead. This option is mainly used for extensions that alter the training configuration (e.g., learning rates); in such a case, resuming from snapshots requires calling the extension to recover the configuration before any updates.
get_extension(name)

Returns the extension of a given name.

Parameters: name (str) – Name of the extension.
Returns: Extension.
run()

Executes the training loop.

This method is the core of Trainer. It executes the whole loop of training the models.

Note that this method cannot be run more than once on a single trainer object.

Updater

class chainer.training.Updater

Interface of updater objects for trainers.

TODO(beam2d): document it.

connect_trainer(trainer)

Connects the updater to the trainer that will call it.

The typical usage of this method is to register additional links to the reporter of the trainer. This method is called at the end of the initialization of Trainer. The default implementation does nothing.

Parameters: trainer (Trainer) – Trainer object to which the updater is registered.
finalize()

Finalizes the updater object.

This method is called at the end of training loops. It should finalize each dataset iterator used in this updater.

get_all_optimizers()

Gets a dictionary of all optimizers for this updater.

Returns: Dictionary that maps names to optimizers.
Return type: dict
get_optimizer(name)

Gets the optimizer with the given name.

Updater holds one or more optimizers with names. They can be retrieved by this method.

Parameters: name (str) – Name of the optimizer.
Returns: Optimizer of the name.
Return type: Optimizer
serialize(serializer)

Serializes the current state of the updater object.

update()

Updates the parameters of the target model.

This method implements an update formula for the training task, including data loading, forward/backward computations, and actual updates of parameters.

This method is called once at each iteration of the training loop.

class chainer.training.StandardUpdater(iterator, optimizer, converter=<function concat_examples>, device=None, loss_func=None)

Standard implementation of Updater.

This is the standard implementation of Updater. It accepts one or more training datasets and one or more optimizers. The default update routine assumes that there is only one training dataset and one optimizer. Users can change this update routine by inheriting from this class and overriding the update_core() method. By default, each batch is converted to input arrays by concat_examples(); a different converter can be set through the converter argument.

Parameters:
  • iterator – Dataset iterator for the training dataset. It can also be a dictionary of iterators. If this is just an iterator, then the iterator is registered by the name 'main'.
  • optimizer – Optimizer to update parameters. It can also be a dictionary of optimizers. If this is just an optimizer, then the optimizer is registered by the name 'main'.
  • converter – Converter function to build input arrays. Each batch extracted by the main iterator and the device option are passed to this function. concat_examples() is used by default.
  • device – Device to which the training data is sent. A negative value indicates the host memory (CPU).
  • loss_func – Loss function. The target link of the main optimizer is used by default.
Variables:
  • converter – Converter function.
  • loss_func – Loss function. If it is None, the target link of the main optimizer is used instead.
  • device – Device to which the training data is sent.
  • iteration – Current number of completed updates.
get_iterator(name)

Gets the dataset iterator with the given name.

Parameters: name (str) – Name of the dataset iterator.
Returns: Corresponding dataset iterator.
Return type: Iterator
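The flow from iterator to converter to loss function in a standard update can be sketched as follows. Everything here is a hypothetical stand-in: toy_converter() uses plain lists instead of arrays and ignores the device, and the assumption that the converted tuple is unpacked into the loss function's arguments is mine, not something stated above.

```python
def toy_converter(batch, device=None):
    """Turn a list of (x, t) examples into a tuple (xs, ts).

    Hypothetical stand-in for a converter: plain lists replace arrays
    and the device argument is accepted but ignored.
    """
    return tuple(list(column) for column in zip(*batch))


def toy_update(iterator, converter, loss_func, device=None):
    batch = next(iterator)                # load one mini-batch
    in_arrays = converter(batch, device)  # build the input "arrays"
    # A real updater would now let the optimizer update the parameters;
    # this sketch only evaluates the loss.
    return loss_func(*in_arrays)


def squared_error(xs, ts):
    return sum((x - t) ** 2 for x, t in zip(xs, ts))


xs, ts = toy_converter([(1.0, 0.0), (2.0, 2.0)])
batches = iter([[(1.0, 0.0), (2.0, 2.0)], [(3.0, 1.0)]])
loss = toy_update(batches, toy_converter, squared_error)
print(xs, ts, loss)
```

A custom converter is useful when the examples are not simple tuples of equally shaped arrays, for instance variable-length sequences that need padding.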
class chainer.training.ParallelUpdater(iterator, optimizer, converter=<function concat_examples>, models=None, devices=None, loss_func=None)

Implementation of a parallel GPU Updater.

This is an implementation of Updater that uses multiple GPUs. It behaves similarly to StandardUpdater. The update routine is modified to support data-parallel computation on multiple GPUs in one machine. It is based on synchronous parallel SGD: it parallelizes the gradient computation over a mini-batch, and updates the parameters only on the main device.

Parameters:
  • iterator – Dataset iterator for the training dataset. It can also be a dictionary of iterators. If this is just an iterator, then the iterator is registered by the name 'main'.
  • optimizer – Optimizer to update parameters. It can also be a dictionary of optimizers. If this is just an optimizer, then the optimizer is registered by the name 'main'.
  • converter – Converter function to build input arrays. Each batch extracted by the main iterator is split equally between the devices and then passed to this function together with the corresponding device option. concat_examples() is used by default.
  • models – Dictionary of models. The main model should be the same model attached to the 'main' optimizer.
  • devices – Dictionary of devices to which the training data is sent. The devices should be arranged in a dictionary with the same structure as models.
  • loss_func – Loss function. The model is used as a loss function by default.
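Synchronous data-parallel SGD, as described for this updater, can be sketched on a scalar toy model y = w * x. The helpers grad_sum() and parallel_step() are hypothetical, "devices" are just dictionary keys, and the final step that copies the updated parameter back to the other replicas is an assumption about how the replicas stay in sync.

```python
def grad_sum(w, shard):
    """Sum of d/dw (w*x - t)^2 over the examples in one shard."""
    return sum(2.0 * (w * x - t) * x for x, t in shard)


def parallel_step(replicas, batch, lr):
    """One synchronous data-parallel SGD step on toy scalar models."""
    names = sorted(replicas)                  # here: ['main', 'second']
    size = len(batch) // len(names)           # equal split of the batch
    shards = {n: batch[i * size:(i + 1) * size]
              for i, n in enumerate(names)}
    # Each replica computes the gradient on its own shard.
    grads = {n: grad_sum(replicas[n], shards[n]) for n in names}
    # The gradients are gathered and only the main replica is updated.
    mean_grad = sum(grads.values()) / len(batch)
    replicas['main'] -= lr * mean_grad
    # Assumed sync step: copy the new parameter to the other replicas.
    for n in names:
        replicas[n] = replicas['main']
    return replicas


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
replicas = parallel_step({'main': 0.0, 'second': 0.0}, batch, lr=0.01)

# The same step computed on a single device over the whole batch.
single = 0.0 - 0.01 * (grad_sum(0.0, batch) / len(batch))
print(replicas, single)
```

Because the shard gradients are summed before the update, the result matches a single-device step over the full mini-batch; only the gradient computation is parallelized.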

Extension

class chainer.training.Extension

Base class of trainer extensions.

A trainer extension is a callable object that takes the trainer object as the argument. It also provides some default configurations as its attributes, e.g. the default trigger and the default priority. This class provides a set of typical default values for these attributes.

There are two ways to define your own extensions: inheriting from this class, or decorating closures with make_extension(). The decorator slightly reduces the overhead and is much easier to use, while this class provides more flexibility (for example, it can have methods to configure the behavior).

Variables:
  • trigger – Default value of trigger for this extension. It is set to (1, 'iteration') by default.
  • priority – Default priority of the extension. It is set to PRIORITY_READER by default.
  • invoke_before_training – Default flag to decide whether this extension should be invoked before the training starts. The default value is False.
default_name

Default name of the extension.

It is the name of the class by default. Implementations can override this property, or provide a class attribute that hides it.

finalize()

Finalizes the extension.

This method is called at the end of the training loop.

serialize(serializer)

Serializes the extension state.

It is called when a trainer that owns this extension is serialized. It serializes nothing by default.

chainer.training.make_extension(trigger=None, default_name=None, priority=None, invoke_before_training=False, finalizer=None)

Decorator to make given functions into trainer extensions.

This decorator just adds some attributes to a given function. The values of the attributes are given by the arguments of this decorator.

See Extension for details of trainer extensions. Most of the default values of the arguments also follow those of that class.

Parameters:
  • trigger – Default trigger of the extension.
  • default_name – Default name of the extension. The name of a given function is used by default.
  • priority (int) – Default priority of the extension.
  • invoke_before_training (bool) – Default flag to decide whether the extension should be invoked before any training.
  • finalizer – Finalizer function of this extension. The finalizer is called at the end of the training loop.
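The idea of a decorator that only attaches attributes to a function can be sketched in a few lines. toy_make_extension() is a hypothetical re-creation for illustration, not the library function: its default values are assumed, and the finalizer argument is left out.

```python
def toy_make_extension(trigger=(1, 'iteration'), default_name=None,
                       priority=100, invoke_before_training=False):
    """Hypothetical attribute-adding decorator; defaults are assumed."""
    def decorator(func):
        # The function itself is returned unchanged; it just gains the
        # attributes a trainer would look up when registering it.
        func.trigger = trigger
        func.default_name = default_name or func.__name__
        func.priority = priority
        func.invoke_before_training = invoke_before_training
        return func
    return decorator


@toy_make_extension(trigger=(1, 'epoch'))
def print_iteration(trainer):
    print('iteration', trainer.updater.iteration)


print(print_iteration.default_name, print_iteration.trigger)
```

Since the decorated object is still an ordinary function, it can be called directly with a trainer-like object, which makes such extensions easy to test in isolation.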

Trigger

A trigger is a callable object that decides when to process a specific event within the training loop. It takes a Trainer object as the argument and returns True if the event should be fired.

It is mainly used to determine when to call an extension. It is also used to determine when to quit the training loop.
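Since a trigger is just a callable that receives the trainer and returns a boolean, a custom one can be written as a small class. The sketch below is hypothetical: BelowThresholdTrigger is not part of the library, the key 'main/loss' is an assumed observation name, and a SimpleNamespace stands in for the trainer.

```python
from types import SimpleNamespace


class BelowThresholdTrigger:
    """Hypothetical trigger: fires once an observed value is small enough."""

    def __init__(self, key, threshold):
        self.key = key
        self.threshold = threshold

    def __call__(self, trainer):
        # Missing values never fire the trigger.
        value = trainer.observation.get(self.key)
        return value is not None and value < self.threshold


trigger = BelowThresholdTrigger('main/loss', 0.1)

# Stand-in for a trainer; only the observation attribute is needed here.
fake_trainer = SimpleNamespace(observation={'main/loss': 0.5})
first = trigger(fake_trainer)

fake_trainer.observation = {'main/loss': 0.05}
second = trigger(fake_trainer)
print(first, second)
```

A trigger like this could serve either purpose named above: attached to an extension it controls when the extension runs, and used as a stop trigger it ends the loop.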

chainer.training.get_trigger(trigger)

Gets a trigger object.

A trigger object is a callable that accepts a Trainer object as an argument and returns a boolean value. When it returns True, various kinds of events can occur depending on the context in which the trigger is used. For example, if the trigger is passed to the Trainer as the stop trigger, the training loop breaks when the trigger returns True. If the trigger is passed to the extend() method of a trainer, then the registered extension is invoked only when the trigger returns True.

This function returns a trigger object based on the argument. If trigger is already a callable, it just returns the trigger. If trigger is None, it returns a trigger that never fires. Otherwise, it passes the value to IntervalTrigger.

Parameters: trigger – Trigger object. It can be either an already built trigger object (i.e., a callable object that accepts a trainer object and returns a bool value), or a tuple. In the latter case, the tuple is passed to IntervalTrigger.
Returns: trigger if it is a callable, otherwise an IntervalTrigger object made from trigger.
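The three cases handled by this function (callable, None, anything else) can be sketched as below. All names are hypothetical stand-ins: ToyIntervalTrigger supports only the 'iteration' unit, and a SimpleNamespace replaces the trainer.

```python
from types import SimpleNamespace


class ToyIntervalTrigger:
    """Hypothetical interval trigger; supports the iteration unit only."""

    def __init__(self, period, unit):
        if unit != 'iteration':
            raise ValueError('this sketch only supports the iteration unit')
        self.period = period

    def __call__(self, trainer):
        return trainer.updater.iteration % self.period == 0


def never_fire(trainer):
    return False


def toy_get_trigger(trigger):
    if callable(trigger):
        return trigger                    # already a trigger object
    if trigger is None:
        return never_fire                 # a trigger that never fires
    return ToyIntervalTrigger(*trigger)   # e.g. (1000, 'iteration')


def fake_trainer(iteration):
    return SimpleNamespace(updater=SimpleNamespace(iteration=iteration))


interval = toy_get_trigger((1000, 'iteration'))
print(interval(fake_trainer(1000)), interval(fake_trainer(999)))
```

Normalizing the argument in one place is what lets callers pass a plain tuple wherever a trigger is expected.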