From d29ece4c975fa5138e10c94bbda3fcffcf286fab Mon Sep 17 00:00:00 2001 From: zhongpu <2013000149@qq.com> Date: Mon, 13 Apr 2020 17:11:05 +0800 Subject: [PATCH] add register_forward_hook api for Layer (#1941) * add register_forward_hook api for Layer, test=develop * fix sample code, test=develop --- doc/fluid/api_cn/dygraph_cn/Layer_cn.rst | 94 ++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/doc/fluid/api_cn/dygraph_cn/Layer_cn.rst b/doc/fluid/api_cn/dygraph_cn/Layer_cn.rst index fcc1bf6fc..b13278ad1 100644 --- a/doc/fluid/api_cn/dygraph_cn/Layer_cn.rst +++ b/doc/fluid/api_cn/dygraph_cn/Layer_cn.rst @@ -21,6 +21,100 @@ Layer的全名。组成方式为: ``name_scope`` + “/” + MyLayer.__class__ 返回类型:str +.. py:method:: register_forward_pre_hook(hook) + +为Layer注册一个 ``forward pre-hook`` 函数,该 ``hook`` 函数将会在 ``forward`` 函数调用之前被调用。 + +``hook`` 函数具有以下形式:它的 ``input`` 是 ``Layer`` 的 ``input`` ,并且可以返回一个元组或者单个修改值;如果返回单个修改值,则将值包装到一个元组中。用户可以使用该函数来查看或修改 ``Layer`` ``forward`` 函数的输入。 + +hook(Layer, input) -> None or modified input + +参数: + - **hook** (function) - 被注册为 ``forward pre-hook`` 的函数 + +返回:一个 ``HookRemoveHelper`` 类对象,可通过调用 ``hook_remove_helper.remove()`` 来删除注册的hook函数。 + +返回类型: ``HookRemoveHelper`` 类对象 + +**代码示例** + +.. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + # forward_pre_hook函数修改了layer的输入:input = input * 2 + def forward_pre_hook(layer, input): + # 改变输入值 + input_return = (input[0] * 2) + return input_return + + with fluid.dygraph.guard(): + linear = fluid.Linear(13, 5, dtype="float32") + + # 注册hook + forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook) + + value0 = np.arange(26).reshape(2, 13).astype("float32") + in0 = fluid.dygraph.to_variable(value0) + out0 = linear(in0) + + # 移除hook + forward_pre_hook_handle.remove() + + value1 = value0 * 2 + in1 = fluid.dygraph.to_variable(value1) + out1 = linear(in1) + + # hook改变了layer的输入(input = input * 2),所以out0等于out1 + assert (out0.numpy() == out1.numpy()).any() + +.. py:method:: register_forward_post_hook(hook) + +为Layer注册一个 ``forward post-hook`` 函数,该 ``hook`` 函数将会在 ``forward`` 函数调用之后被调用。 + +``hook`` 函数具有以下形式,它的 ``input`` 和 ``output`` 是 ``Layer`` 的 ``input`` 和 ``output`` 。用户可以用该函数来查看和修改 ``Layer`` ``forward`` 函数的输出。 + +hook(Layer, input, output) -> None or modified output + +参数: + - **hook** (function) - 被注册为 ``forward post-hook`` 的函数 + +返回:一个 ``HookRemoveHelper`` 类对象,可通过调用 ``hook_remove_helper.remove()`` 来删除注册的hook函数。 + +返回类型: ``HookRemoveHelper`` 类对象 + +**代码示例** + +.. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + # forward_post_hook函数改变了layer的输出:output = output * 2 + def forward_post_hook(layer, input, output): + # 改变输出值 + return output * 2 + + with fluid.dygraph.guard(): + linear = fluid.Linear(13, 5, dtype="float32") + + # 注册hook + forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook) + + value1 = np.arange(26).reshape(2, 13).astype("float32") + in1 = fluid.dygraph.to_variable(value1) + + out0 = linear(in1) + + # remove the hook + forward_post_hook_handle.remove() + + out1 = linear(in1) + + # hook改变了layer的输出(output = output * 2),所以out0等于out1 * 2 + assert (out0.numpy() == (out1.numpy()) * 2).any() + .. py:method:: create_parameter(shape, attr=None, dtype="float32", is_bias=False, default_initializer=None) 为Layer创建参数。 -- GitLab