Function hooks

Chainer provides a function-hook mechanism that enriches the behavior of forward and backward propagation of Function objects.

Base class

class chainer.function.FunctionHook

Base class of hooks for Functions.

FunctionHook is a callback object that is registered to a Function. Registered function hooks are invoked before and after the forward and backward operations of each function.

Function hooks that derive from FunctionHook are expected to implement four methods: forward_preprocess(), forward_postprocess(), backward_preprocess(), and backward_postprocess(). By default, these methods do nothing, so a subclass only needs to override the callbacks it actually uses.

Specifically, when the __call__() method of a function is invoked, forward_preprocess() (resp. forward_postprocess()) of every function hook registered to that function is called before (resp. after) forward propagation.

Likewise, when backward() of a Variable is invoked, backward_preprocess() (resp. backward_postprocess()) of every function hook registered to each function on the backward path is called before (resp. after) that function's backward propagation.
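To make the call order above concrete, here is a minimal, self-contained sketch in plain Python. It is a toy model, not Chainer's implementation: RecordingHook and ToyFunction are hypothetical names, and ToyFunction simply doubles its inputs (and its gradients) so that forward and backward have something to compute. Only the callback names and signatures follow the FunctionHook documentation.

```python
# Toy model (not Chainer's code) of when the four FunctionHook callbacks
# fire around a function's forward and backward computation.


class RecordingHook:
    """Hypothetical hook that records each callback it receives."""

    def __init__(self, log):
        self.log = log

    def forward_preprocess(self, function, in_data):
        self.log.append(("forward_preprocess", function.name))

    def forward_postprocess(self, function, in_data):
        self.log.append(("forward_postprocess", function.name))

    def backward_preprocess(self, function, in_data, out_grad):
        self.log.append(("backward_preprocess", function.name))

    def backward_postprocess(self, function, in_data, out_grad):
        self.log.append(("backward_postprocess", function.name))


class ToyFunction:
    """Hypothetical stand-in for a Function computing y = 2 * x."""

    def __init__(self, name, hooks, log):
        self.name = name
        self.hooks = hooks
        self.log = log

    def __call__(self, *in_data):
        for hook in self.hooks:          # before forward propagation
            hook.forward_preprocess(self, in_data)
        self.log.append(("forward", self.name))
        outputs = tuple(2 * x for x in in_data)
        for hook in self.hooks:          # after forward propagation
            hook.forward_postprocess(self, in_data)
        return outputs

    def backward(self, in_data, out_grad):
        for hook in self.hooks:          # before backward propagation
            hook.backward_preprocess(self, in_data, out_grad)
        self.log.append(("backward", self.name))
        grads = tuple(2 * g for g in out_grad)   # d(2x)/dx = 2
        for hook in self.hooks:          # after backward propagation
            hook.backward_postprocess(self, in_data, out_grad)
        return grads


log = []
f = ToyFunction("double", [RecordingHook(log)], log)
y = f(3.0)
gx = f.backward((3.0,), (1.0,))
print(y, gx)
print([event for event, _ in log])
```

The recorded events show each pre/post callback wrapping exactly one forward or backward computation.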

There are two ways to register FunctionHook objects to Function objects.

The first is to use a with statement. Function hooks registered this way are applied to all functions called within the with block and are unregistered when the block exits.

The following code is a simple example in which we measure the elapsed time of part of the forward propagation procedure with TimerHook, a subclass of FunctionHook.

>>> import numpy
>>> import chainer, chainer.links as L, chainer.functions as F
>>> class Model(chainer.Chain):
...     def __call__(self, x1):
...         return F.exp(self.l(x1))
>>> model1 = Model(l=L.Linear(10, 10))
>>> model2 = Model(l=L.Linear(10, 10))
>>> x = chainer.Variable(numpy.zeros((1, 10), 'f'))
>>> with chainer.function_hooks.TimerHook() as m:
...     _ = model1(x)
...     y = model2(x)
...     print(m.total_time())  # seconds spent in the hooked forward calls
>>> # TimerHook has been unregistered here, so model3 is not timed.
>>> model3 = Model(l=L.Linear(10, 10))
>>> z = model3(y)

In this example, we measure the elapsed time of the forward propagation of every function called in model1 and model2 (specifically, LinearFunction and Exp of each model). Note that model3 is not measured, because TimerHook is unregistered before the forward propagation of model3.

Note

Chainer stores the dictionary of registered function hooks as a thread-local object, so the set of registered function hooks differs from thread to thread.
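What "thread-local" means in practice can be shown with Python's standard threading.local. The sketch below is a toy registry, not Chainer's internals: get_hooks and the stored placeholder object are hypothetical. A hook registered from the main thread is invisible to a worker thread, which lazily gets its own empty dictionary.

```python
import threading

# Toy model (not Chainer's code) of a thread-local hook registry: each
# thread sees only the hooks registered from that same thread.
_local = threading.local()


def get_hooks():
    # Lazily create a separate dictionary the first time each thread asks.
    if not hasattr(_local, "hooks"):
        _local.hooks = {}
    return _local.hooks


get_hooks()["TimerHook"] = object()  # registered in the main thread only

seen_in_worker = []


def worker():
    # The worker thread gets its own, initially empty, dictionary.
    seen_in_worker.append(sorted(get_hooks()))


t = threading.Thread(target=worker)
t.start()
t.join()

print(sorted(get_hooks()))  # ['TimerHook']
print(seen_in_worker)       # [[]]
```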

The other way is to register a hook directly to a Function object with its add_hook() method. Function hooks registered this way can be removed with the delete_hook() method. Unlike the former method, such a hook is registered only to the function on which add_hook() is called.

Parameters: name (str) – Name of this function hook.
backward_postprocess(function, in_data, out_grad)
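Per-function registration with add_hook() and delete_hook() can be sketched the same way. This is a toy model, not Chainer's Function class: ToyFunction, CountingHook, and the local_hooks attribute are hypothetical, and the hook is keyed by its name as the name parameter suggests. A hook added to f fires only for f's calls, and stops firing once it is deleted.

```python
from collections import OrderedDict

# Toy model (not Chainer's code) of per-function registration: a hook added
# with add_hook() fires only for the function it was added to, until it is
# removed with delete_hook().


class CountingHook:
    name = "CountingHook"

    def __init__(self):
        self.count = 0

    def forward_preprocess(self, function, in_data):
        self.count += 1


class ToyFunction:
    def __init__(self):
        self.local_hooks = OrderedDict()  # hook name -> hook, per function

    def add_hook(self, hook, name=None):
        self.local_hooks[name or hook.name] = hook

    def delete_hook(self, name):
        del self.local_hooks[name]

    def __call__(self, x):
        for hook in self.local_hooks.values():
            hook.forward_preprocess(self, (x,))
        return x * x


f, g = ToyFunction(), ToyFunction()
hook = CountingHook()
f.add_hook(hook)
f(2)                          # counted: hook is registered to f
g(2)                          # not counted: g has no hooks
f.delete_hook("CountingHook")
f(3)                          # not counted: hook was removed
print(hook.count)  # 1
```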

Callback function invoked after backward propagation.

Parameters:
  • function (Function) – Function object to which the function hook is registered.
  • in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
  • out_grad (tuple of numpy.ndarray or tuple of cupy.ndarray) – Gradient data of backward propagation.
backward_preprocess(function, in_data, out_grad)

Callback function invoked before backward propagation.

Parameters:
  • function (Function) – Function object to which the function hook is registered.
  • in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
  • out_grad (tuple of numpy.ndarray or tuple of cupy.ndarray) – Gradient data of backward propagation.
forward_postprocess(function, in_data)

Callback function invoked after forward propagation.

Parameters:
  • function (Function) – Function object to which the function hook is registered.
  • in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
forward_preprocess(function, in_data)

Callback function invoked before forward propagation.

Parameters:
  • function (Function) – Function object to which the function hook is registered.
  • in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
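A custom hook is a subclass that overrides some of the four callbacks with the signatures listed above. The sketch below assumes a hypothetical BaseHook that stands in for chainer.function.FunctionHook (its callbacks do nothing, as documented), and passes plain strings and lists where Chainer would pass a Function object and arrays. ArityHook is a hypothetical example that records how many inputs and output gradients each call receives.

```python
# Skeleton of a custom hook using the documented callback signatures.
# BaseHook is a hypothetical stand-in for chainer.function.FunctionHook.


class BaseHook:
    name = "BaseHook"

    def forward_preprocess(self, function, in_data):
        pass

    def forward_postprocess(self, function, in_data):
        pass

    def backward_preprocess(self, function, in_data, out_grad):
        pass

    def backward_postprocess(self, function, in_data, out_grad):
        pass


class ArityHook(BaseHook):
    """Records how many inputs / output gradients each call received."""

    name = "ArityHook"

    def __init__(self):
        self.records = []

    def forward_preprocess(self, function, in_data):
        self.records.append(("forward", function, len(in_data)))

    def backward_preprocess(self, function, in_data, out_grad):
        self.records.append(("backward", function, len(in_data), len(out_grad)))

    # forward_postprocess and backward_postprocess stay inherited no-ops.


hook = ArityHook()
# Drive the callbacks by hand; in Chainer the framework calls them itself.
# The string "Add" stands in for a Function object.
hook.forward_preprocess("Add", ([1.0], [2.0]))
hook.forward_postprocess("Add", ([1.0], [2.0]))
hook.backward_preprocess("Add", ([1.0], [2.0]), ([1.0],))
hook.backward_postprocess("Add", ([1.0], [2.0]), ([1.0],))
print(hook.records)
```

Overriding only the preprocess callbacks relies on the base class's no-op defaults for the rest.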

Concrete function hooks

class chainer.function_hooks.PrintHook(sep='', end='\n', file=sys.stdout, flush=True)

Function hook that prints debug information.

This function hook outputs debug information about the input arguments of the forward and backward methods of the hooked functions at preprocessing time (that is, just before each method is called).

The basic usage is to use it with a with statement.

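>>> import numpy
>>> import chainer, chainer.functions as F, chainer.links as L
>>> l = L.Linear(10, 10)
>>> x = chainer.Variable(numpy.zeros((1, 10), 'f'))
>>> with chainer.function_hooks.PrintHook():
...     y = l(x)        # debug info for the forward of LinearFunction
...     z = F.sum(y)    # debug info for the forward of Sum
...     z.backward()    # debug info for each backward call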

In this example, PrintHook shows the debug information of the forward propagation of LinearFunction (called implicitly by l) and Sum (called by F.sum), as well as of the backward propagation through z and y.

Unlike the simple “debug print” technique, where users insert print calls into every function to be inspected, this hook shows the information of all the functions involved with a single with statement.

Further, this hook shows the information of backward methods without inserting print calls into Chainer’s library code.

Variables:
  • sep – Separator passed to the print function.
  • end – String appended at the end of each print call.
  • file – Output file-like object to which the output is redirected.
  • flush – If True, this hook forcibly flushes the text stream at the end of preprocessing.
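The four variables map directly onto arguments of Python's built-in print(). The sketch below is a toy, not Chainer's PrintHook: ToyPrintHook is hypothetical and its message format ("function: ", "input data: ") is an assumption rather than PrintHook's actual output. It shows how sep, end, file, and flush would be forwarded, using an in-memory io.StringIO as the file-like target.

```python
import io
import sys

# Toy sketch (not Chainer's PrintHook): a preprocessing hook that forwards
# its sep / end / file / flush settings straight to the built-in print().


class ToyPrintHook:
    def __init__(self, sep="", end="\n", file=sys.stdout, flush=True):
        self.sep = sep
        self.end = end
        self.file = file
        self.flush = flush

    def _emit(self, *parts):
        print(*parts, sep=self.sep, end=self.end, file=self.file,
              flush=self.flush)

    def forward_preprocess(self, function_name, in_data):
        self._emit("function: ", function_name)
        self._emit("input data: ", in_data)

    def backward_preprocess(self, function_name, in_data, out_grad):
        self._emit("function: ", function_name)
        self._emit("input data: ", in_data)
        self._emit("output gradient: ", out_grad)


buf = io.StringIO()          # redirect output to an in-memory text stream
hook = ToyPrintHook(file=buf)
hook.forward_preprocess("Sum", (1.0, 2.0))
print(buf.getvalue(), end="")
```

Passing a StringIO (or an open log file) as file is what makes the output capturable instead of going to standard output.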
class chainer.function_hooks.TimerHook

Function hook for measuring elapsed time of functions.

Variables: call_history – List of measurement results. It consists of pairs of the function that called this hook and the elapsed time the function consumed.
total_time()

Returns total elapsed time in seconds.
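The relationship between call_history and total_time() can be sketched with a toy timing hook. ToyTimerHook and slow_square are hypothetical, and this is not Chainer's TimerHook: it times forward calls only, with CPU wall-clock time from time.perf_counter, and omits the device synchronization that timing GPU work would require. forward_preprocess records a start time, forward_postprocess appends a (function, elapsed) pair, and total_time() sums the elapsed values.

```python
import time

# Toy sketch (not Chainer's TimerHook): record a start time before forward,
# append (function, elapsed seconds) after it, and sum the history.


class ToyTimerHook:
    def __init__(self):
        self.call_history = []  # list of (function, elapsed seconds)
        self._start = None

    def forward_preprocess(self, function, in_data):
        self._start = time.perf_counter()

    def forward_postprocess(self, function, in_data):
        elapsed = time.perf_counter() - self._start
        self.call_history.append((function, elapsed))

    def total_time(self):
        """Return the total elapsed time in seconds."""
        return sum(elapsed for _, elapsed in self.call_history)


def slow_square(x):
    time.sleep(0.01)  # simulate some work
    return x * x


hook = ToyTimerHook()
for value in (1, 2, 3):
    # Drive the callbacks by hand around each "function" call.
    hook.forward_preprocess("slow_square", (value,))
    slow_square(value)
    hook.forward_postprocess("slow_square", (value,))

print(len(hook.call_history), hook.total_time())
```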