Unverified · commit 166a1ae9, authored by zhongpu, committed by GitHub

support forward hook for dygraph (#22443)

* support forward hook for dygraph, test=develop

* add optest for forward_hook in dygraph, test=develop

* add optest, test=develop

* polish code, test=develop

* add sample code, test=develop

* rename forward_hook to forward_post_hook, test=develop

* fix the api description, test=develop

* fix api description, test=develop
Parent a62599a8
@@ -27,6 +27,7 @@ from .base import program_desc_tracing_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
import copy
import weakref
import warnings
__all__ = ['Layer']
@@ -40,6 +41,22 @@ def _convert_camel_to_snake(name):
    return _all_cap_re.sub(r'\1_\2', s1).lower()
class HookRemoveHelper(object):
    """A helper class that can be used to remove a registered hook."""

    next_hook_id = 0

    def __init__(self, hooks):
        self._hooks_ref = weakref.ref(hooks)
        self._hook_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id += 1

    def remove(self):
        hooks = self._hooks_ref()
        if hooks is not None and self._hook_id in hooks:
            del hooks[self._hook_id]
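A minimal usage sketch of the helper (illustration only; `hooks` is a stand-in for a Layer's hook dictionary and the lambda is a hypothetical hook). The helper keeps only a weak reference, so it cannot keep the hook dictionary alive by itself; this is also one reason an `OrderedDict` is used rather than a plain `dict`, since built-in dicts are not weak-referenceable:

import collections

hooks = collections.OrderedDict()            # stand-in for a Layer's hook dict
helper = HookRemoveHelper(hooks)
hooks[helper._hook_id] = lambda *args: None  # hypothetical no-op hook
helper.remove()                              # deletes the entry through the weak reference
assert helper._hook_id not in hooks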
class Layer(core.Layer):
"""Dynamic graph Layer based on OOD, includes the parameters of the layer, the structure of the forward graph and so on.
@@ -70,6 +87,9 @@ class Layer(core.Layer):
        self._sub_layers = collections.OrderedDict()
        self._loaddict_holder = collections.OrderedDict()

        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

    def train(self):
        framework._dygraph_tracer().train_mode()
@@ -84,6 +104,108 @@ class Layer(core.Layer):
"""
return self._full_name
def register_forward_post_hook(self, hook):
"""Register a forward post-hook for Layer. The hook will be called after `forward` function has been computed.
It should have the following form, `input` and `output` of the `hook` is `input` and `output` of the `Layer` respectively.
User can use forward post-hook to change the output of the Layer or perform information statistics tasks on the Layer.
hook(Layer, input, output) -> None or modified output
Parameters:
hook(function): a function registered as a forward post-hook
Returns:
HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .
Examples:
.. code-block:: python
import paddle.fluid as fluid
# the forward_post_hook change the output of the layer: output = output * 2
def forward_post_hook(layer, input, output):
# user can use layer, input and output for information statistis tasks
# change the output
return output * 2
with fluid.dygraph.guard():
linear = fluid.Linear(13, 5, dtype="float32")
# register the hook
forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)
value = np.arange(26).reshape(2, 13).astype("float32")
in = fluid.dygraph.to_variable(value0)
out0 = linear(in)
# remove the hook
forward_post_hook_handle.remove()
out1 = linear(in)
# hook change the linear's output to output * 2, so out0 is equal to out1 * 2.
assert (out0.numpy() == (out1.numpy()) * 2).any()
"""
hook_remove_helper = HookRemoveHelper(self._forward_post_hooks)
self._forward_post_hooks[hook_remove_helper._hook_id] = hook
return hook_remove_helper
    def register_forward_pre_hook(self, hook):
        """Register a forward pre-hook for the Layer. The hook will be called before the `forward` function is computed.

        It should have the following form; the `input` of the `hook` is the `input` of the `Layer`.
        The hook can return either a tuple or a single modified value. If a single value is returned
        (and that value is not already a tuple), it will be wrapped into a tuple.
        Users can use a forward pre-hook to change the input of the Layer or to perform information statistics tasks on the Layer.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                # the forward_pre_hook changes the input of the layer: input = input * 2
                def forward_pre_hook(layer, input):
                    # user can use layer and input for information statistics tasks

                    # change the input
                    input_return = (input[0] * 2)
                    return input_return

                with fluid.dygraph.guard():
                    linear = fluid.dygraph.Linear(13, 5, dtype="float32")

                    # register the hook
                    forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                    value0 = np.arange(26).reshape(2, 13).astype("float32")
                    in0 = fluid.dygraph.to_variable(value0)
                    out0 = linear(in0)

                    # remove the hook
                    forward_pre_hook_handle.remove()

                    value1 = value0 * 2
                    in1 = fluid.dygraph.to_variable(value1)
                    out1 = linear(in1)

                    # the hook changed the linear's input to input * 2, so out0 equals out1
                    assert (out0.numpy() == out1.numpy()).all()
        """
        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper
    def create_parameter(self,
                         shape,
                         attr=None,
@@ -293,6 +415,13 @@
        pass

    def __call__(self, *inputs, **kwargs):
        for forward_pre_hook in self._forward_pre_hooks.values():
            hook_result = forward_pre_hook(self, inputs)
            if hook_result is not None:
                if not isinstance(hook_result, tuple):
                    hook_result = (hook_result, )
                inputs = hook_result

        if not self._built:
            with program_desc_tracing_guard(False):
                self._build_once(*inputs, **kwargs)
@@ -302,6 +431,12 @@
            self._built = True

        outputs = self.forward(*inputs, **kwargs)

        for forward_post_hook in self._forward_post_hooks.values():
            hook_result = forward_post_hook(self, inputs, outputs)
            if hook_result is not None:
                outputs = hook_result

        return outputs
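Taken together, `__call__` now runs pre-hooks, then `forward`, then post-hooks. A minimal sketch of that order (illustration only, assuming the `fluid.dygraph.Linear` layer; the lambdas are hypothetical hooks):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(4, 4, dtype="float32")
    # pre-hook: returns a single value, which __call__ wraps back into a tuple
    linear.register_forward_pre_hook(lambda layer, inputs: inputs[0] * 2)
    # post-hook: rewrites the output after forward has run
    linear.register_forward_post_hook(lambda layer, inputs, outputs: outputs * 3)
    x = fluid.dygraph.to_variable(np.ones((1, 4)).astype("float32"))
    y = linear(x)  # equivalent to 3 * linear.forward(2 * x)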
    def forward(self, *inputs, **kwargs):
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.dygraph.base as base
from test_imperative_lod_tensor_to_selected_rows import SimpleNet
call_forward_hook = False
call_forward_pre_hook = False
def forward_hook(layer, input, output):
    global call_forward_hook
    call_forward_hook = True


def forward_pre_hook(layer, input):
    global call_forward_pre_hook
    call_forward_pre_hook = True


def forward_hook1(layer, input, output):
    return output * 2


def forward_pre_hook1(layer, input):
    input_return = (input[0] * 2, input[1])
    return input_return
class Test_Forward_Hook(unittest.TestCase):
    # test forward_pre_hook and forward_hook that return a value
    def test_forward_hook_return_value(self):
        seed = 90

        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))

        for place in places:
            with fluid.dygraph.guard(place):
                fluid.default_startup_program().random_seed = seed
                fluid.default_main_program().random_seed = seed
                backward_strategy = fluid.dygraph.BackwardStrategy()
                backward_strategy.sort_sum_gradient = True

                input_word = np.array(
                    [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
                     8]).reshape(6, 3).astype('int64')
                input_word1 = input_word * 2
                input_word = input_word.reshape((-1, 3, 1))
                input_word1 = input_word1.reshape((-1, 3, 1))
                y_data = np.array(
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
                     9]).reshape(6, 3).astype('int64')
                y_data = y_data.reshape((-1, 1))

                input = base.to_variable(input_word)
                input1 = base.to_variable(input_word1)
                y = base.to_variable(y_data)

                simplenet = SimpleNet(
                    hidden_size=20,
                    vocab_size=32,
                    num_steps=3,
                    init_scale=0.1,
                    is_sparse=False,
                    dtype="float32")

                # origin, don't register any hook
                outs_origin = simplenet(input, y)
                outs_origin1 = simplenet(input1, y)

                # register forward_pre_hook
                forward_pre_hook_handle1 = simplenet.register_forward_pre_hook(
                    forward_pre_hook1)
                outs_pre_hook = simplenet(input, y)
                self.assertTrue(
                    np.array_equal(outs_pre_hook.numpy(), outs_origin1.numpy()))

                # remove forward_pre_hook
                forward_pre_hook_handle1.remove()
                outs_pre_hook = simplenet(input, y)
                self.assertTrue(
                    np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))

                # register forward_post_hook
                forward_hook_handle1 = simplenet.register_forward_post_hook(
                    forward_hook1)
                outs_forward_hook = simplenet(input, y)
                self.assertTrue(
                    np.array_equal(outs_forward_hook.numpy(),
                                   outs_origin.numpy() * 2))

                # remove forward_post_hook
                forward_hook_handle1.remove()
                outs_forward_hook = simplenet(input, y)
                self.assertTrue(
                    np.array_equal(outs_forward_hook.numpy(),
                                   outs_origin.numpy()))
    # test forward_pre_hook and forward_hook that don't return a value
    def test_forward_hook(self):
        seed = 90

        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))

        for place in places:
            with fluid.dygraph.guard(place):
                fluid.default_startup_program().random_seed = seed
                fluid.default_main_program().random_seed = seed
                backward_strategy = fluid.dygraph.BackwardStrategy()
                backward_strategy.sort_sum_gradient = True

                global call_forward_hook
                global call_forward_pre_hook

                input_word = np.array(
                    [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
                     8]).reshape(6, 3).astype('int64')
                input_word = input_word.reshape((-1, 3, 1))
                y_data = np.array(
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
                     9]).reshape(6, 3).astype('int64')
                y_data = y_data.reshape((-1, 1))

                input = base.to_variable(input_word)
                y = base.to_variable(y_data)

                simplenet = SimpleNet(
                    hidden_size=20,
                    vocab_size=32,
                    num_steps=3,
                    init_scale=0.1,
                    is_sparse=False,
                    dtype="float32")

                # origin, don't register any hook
                outs_origin = simplenet(input, y)
                self.assertFalse(call_forward_hook)
                self.assertFalse(call_forward_pre_hook)

                # register forward_post_hook and forward_pre_hook
                forward_hook_handle = simplenet.register_forward_post_hook(
                    forward_hook)
                forward_pre_hook_handle = simplenet.register_forward_pre_hook(
                    forward_pre_hook)
                outs_hook = simplenet(input, y)
                self.assertTrue(call_forward_hook)
                self.assertTrue(call_forward_pre_hook)

                outs_hook = simplenet(input, y)
                self.assertTrue(call_forward_hook)
                self.assertTrue(call_forward_pre_hook)

                # remove forward_post_hook
                forward_hook_handle.remove()
                call_forward_hook = False
                call_forward_pre_hook = False
                outs_remove_forward_hook = simplenet(input, y)
                self.assertFalse(call_forward_hook)
                self.assertTrue(call_forward_pre_hook)

                # remove forward_pre_hook
                forward_pre_hook_handle.remove()
                call_forward_hook = False
                call_forward_pre_hook = False
                outs_remove_hook = simplenet(input, y)
                self.assertFalse(call_forward_hook)
                self.assertFalse(call_forward_pre_hook)
if __name__ == '__main__':
    unittest.main()