Unverified commit 166a1ae9, authored by zhongpu, committed by GitHub

support forward hook for dygraph (#22443)

* support forward hook for dygraph, test=develop

* add optest for forward_hook in dygraph, test=develop

* add optest, test=develop

* polish code, test=develop

* add sample code, test=develop

* rename forward_hook to forward_post_hook, test=develop

* fix the api description, test=develop

* fix api description, test=develop
Parent a62599a8
...@@ -27,6 +27,7 @@ from .base import program_desc_tracing_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
import copy
import weakref
import warnings
__all__ = ['Layer']
...@@ -40,6 +41,22 @@ def _convert_camel_to_snake(name):
    return _all_cap_re.sub(r'\1_\2', s1).lower()
class HookRemoveHelper(object):
    """ A HookRemoveHelper that can be used to remove a registered hook. """

    next_hook_id = 0

    def __init__(self, hooks):
        self._hooks_ref = weakref.ref(hooks)
        self._hook_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id += 1

    def remove(self):
        hooks = self._hooks_ref()
        if hooks is not None and self._hook_id in hooks:
            del hooks[self._hook_id]
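Illustration only, not part of the patch: because HookRemoveHelper holds the hook dictionary through weakref.ref, calling remove() stays safe even after the owning container has been garbage-collected. A minimal sketch, assuming the class above:

import collections
import gc

hooks = collections.OrderedDict()
helper = HookRemoveHelper(hooks)
hooks[helper._hook_id] = lambda *args: None

helper.remove()      # deletes the hook from the live dict
assert len(hooks) == 0

del hooks            # drop the only strong reference to the dict
gc.collect()
helper.remove()      # the weakref is now dead: remove() is a safe no-op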
class Layer(core.Layer):
    """Dynamic graph Layer based on OOD, includes the parameters of the layer, the structure of the forward graph and so on.
...@@ -70,6 +87,9 @@ class Layer(core.Layer):
        self._sub_layers = collections.OrderedDict()
        self._loaddict_holder = collections.OrderedDict()
        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

    def train(self):
        framework._dygraph_tracer().train_mode()
...@@ -84,6 +104,108 @@
        """
        return self._full_name
    def register_forward_post_hook(self, hook):
        """Register a forward post-hook for the Layer. The hook will be called after the `forward` function has been computed.

        It should have the following form; the `input` and `output` of the hook are the `input` and `output` of the `Layer` respectively.
        Users can use a forward post-hook to change the output of the Layer, or to collect statistics about the Layer.

        hook(Layer, input, output) -> None or modified output

        Parameters:
            hook(function): a function registered as a forward post-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                # the forward_post_hook changes the output of the layer: output = output * 2
                def forward_post_hook(layer, input, output):
                    # users can use layer, input and output for statistics tasks

                    # change the output
                    return output * 2

                with fluid.dygraph.guard():
                    linear = fluid.dygraph.Linear(13, 5, dtype="float32")

                    # register the hook
                    forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

                    value = np.arange(26).reshape(2, 13).astype("float32")
                    in0 = fluid.dygraph.to_variable(value)
                    out0 = linear(in0)

                    # remove the hook
                    forward_post_hook_handle.remove()

                    out1 = linear(in0)

                    # the hook changed the linear's output to output * 2, so out0 equals out1 * 2
                    assert (out0.numpy() == (out1.numpy() * 2)).all()
        """
        hook_remove_helper = HookRemoveHelper(self._forward_post_hooks)
        self._forward_post_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper
    def register_forward_pre_hook(self, hook):
        """Register a forward pre-hook for the Layer. The hook will be called before the `forward` function is computed.

        It should have the following form; the `input` of the hook is the `input` of the `Layer`.
        The hook can return either a tuple or a single modified value; a single return value will be
        wrapped into a tuple (unless that value is already a tuple).
        Users can use a forward pre-hook to change the input of the Layer, or to collect statistics about the Layer.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                # the forward_pre_hook changes the input of the layer: input = input * 2
                def forward_pre_hook(layer, input):
                    # users can use layer and input for statistics tasks

                    # change the input
                    input_return = (input[0] * 2)
                    return input_return

                with fluid.dygraph.guard():
                    linear = fluid.dygraph.Linear(13, 5, dtype="float32")

                    # register the hook
                    forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                    value0 = np.arange(26).reshape(2, 13).astype("float32")
                    in0 = fluid.dygraph.to_variable(value0)
                    out0 = linear(in0)

                    # remove the hook
                    forward_pre_hook_handle.remove()

                    value1 = value0 * 2
                    in1 = fluid.dygraph.to_variable(value1)
                    out1 = linear(in1)

                    # the hook changed the linear's input to input * 2, so out0 equals out1
                    assert (out0.numpy() == out1.numpy()).all()
        """
        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper
    def create_parameter(self,
                         shape,
                         attr=None,
...@@ -293,6 +415,13 @@ class Layer(core.Layer):
        pass

    def __call__(self, *inputs, **kwargs):
        for forward_pre_hook in self._forward_pre_hooks.values():
            hook_result = forward_pre_hook(self, inputs)
            if hook_result is not None:
                if not isinstance(hook_result, tuple):
                    hook_result = (hook_result, )
                inputs = hook_result

        if not self._built:
            with program_desc_tracing_guard(False):
                self._build_once(*inputs, **kwargs)
...@@ -302,6 +431,12 @@ class Layer(core.Layer):
            self._built = True

        outputs = self.forward(*inputs, **kwargs)

        for forward_post_hook in self._forward_post_hooks.values():
            hook_result = forward_post_hook(self, inputs, outputs)
            if hook_result is not None:
                outputs = hook_result

        return outputs

    def forward(self, *inputs, **kwargs):
...
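Since both hook containers are OrderedDicts keyed by a monotonically increasing id, hooks fire in registration order: all pre-hooks, then `forward`, then all post-hooks. A minimal sketch (illustration only, assuming a build with this patch applied) demonstrating that order:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(2, 2, dtype="float32")
    order = []

    # hooks that only record the call order; returning None leaves inputs/outputs unchanged
    linear.register_forward_pre_hook(lambda layer, input: order.append("pre0"))
    linear.register_forward_pre_hook(lambda layer, input: order.append("pre1"))
    linear.register_forward_post_hook(lambda layer, input, output: order.append("post0"))

    x = fluid.dygraph.to_variable(np.ones((1, 2)).astype("float32"))
    linear(x)
    assert order == ["pre0", "pre1", "post0"]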
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.dygraph.base as base
from test_imperative_lod_tensor_to_selected_rows import SimpleNet
call_forward_hook = False
call_forward_pre_hook = False
def forward_hook(layer, input, output):
global call_forward_hook
call_forward_hook = True
def forward_pre_hook(layer, input):
global call_forward_pre_hook
call_forward_pre_hook = True
def forward_hook1(layer, input, output):
return output * 2
def forward_pre_hook1(layer, input):
input_return = (input[0] * 2, input[1])
return input_return
class Test_Forward_Hook(unittest.TestCase):
# test forward_pre_hook and forward_hook that have return value
def test_forward_hook_return_value(self):
seed = 90
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
input_word = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
8]).reshape(6, 3).astype('int64')
input_word1 = input_word * 2
input_word = input_word.reshape((-1, 3, 1))
input_word1 = input_word1.reshape((-1, 3, 1))
y_data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
9]).reshape(6, 3).astype('int64')
y_data = y_data.reshape((-1, 1))
input = base.to_variable(input_word)
input1 = base.to_variable(input_word1)
y = base.to_variable(y_data)
simplenet = SimpleNet(
hidden_size=20,
vocab_size=32,
num_steps=3,
init_scale=0.1,
is_sparse=False,
dtype="float32")
# origin, don't register any hook
outs_origin = simplenet(input, y)
outs_origin1 = simplenet(input1, y)
# register forward_pre_hook
forward_pre_hook_handle1 = simplenet.register_forward_pre_hook(
forward_pre_hook1)
outs_pre_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin1.numpy()))
# remove forward_pre_hook
forward_pre_hook_handle1.remove()
outs_pre_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))
# register forward_hook
forward_hook_handle1 = simplenet.register_forward_post_hook(
forward_hook1)
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2))
# remove forward_hook
forward_hook_handle1.remove()
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))
# test forward_pre_hook and forward_hook that don't have return value
def test_forward_hook(self):
seed = 90
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
global call_forward_hook
global call_forward_pre_hook
input_word = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7,
8]).reshape(6, 3).astype('int64')
input_word = input_word.reshape((-1, 3, 1))
y_data = np.array(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8,
9]).reshape(6, 3).astype('int64')
y_data = y_data.reshape((-1, 1))
input = base.to_variable(input_word)
y = base.to_variable(y_data)
simplenet = SimpleNet(
hidden_size=20,
vocab_size=32,
num_steps=3,
init_scale=0.1,
is_sparse=False,
dtype="float32")
# origin, don't register any hook
outs_origin = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_pre_hook)
# register forward_hook and forward_pre_hook
forward_hook_handle = simplenet.register_forward_post_hook(
forward_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_pre_hook)
# remove forward_hook
forward_hook_handle.remove()
call_forward_hook = False
call_forward_pre_hook = False
outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertTrue(call_forward_pre_hook)
# remove forward_pre_hook
forward_pre_hook_handle.remove()
call_forward_hook = False
call_forward_pre_hook = False
outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_pre_hook)
if __name__ == '__main__':
unittest.main()
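For reference, a minimal end-to-end sketch (illustration only; the layer sizes and hook names are arbitrary) combining both hook types on a single layer and then detaching them via the returned HookRemoveHelper objects:

import numpy as np
import paddle.fluid as fluid

def double_input(layer, input):
    # a single non-tuple return value is wrapped into a tuple by __call__
    return input[0] * 2

def double_output(layer, input, output):
    return output * 2

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(4, 2, dtype="float32")
    pre_handle = linear.register_forward_pre_hook(double_input)
    post_handle = linear.register_forward_post_hook(double_output)

    x = fluid.dygraph.to_variable(np.ones((1, 4)).astype("float32"))
    out_hooked = linear(x)    # computed as linear(x * 2) * 2

    pre_handle.remove()
    post_handle.remove()
    out_plain = linear(x)     # hooks removed: plain linear(x)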