From 166a1ae9021c27a605175adbe4c2b7356b3b34c9 Mon Sep 17 00:00:00 2001 From: zhongpu <2013000149@qq.com> Date: Fri, 3 Apr 2020 14:50:59 +0800 Subject: [PATCH] support forward hook for dygraph (#22443) * support forward hook for dygraph, test=develop * add optest for forward_hook in dygraph, test=develop * add optest, test=develop * polish code, test=develop * add sample code, test=develop * rename forwrd_hook to forward_post_hook, test=develop * fix the api description, test=develop * fix api description, test=develop --- python/paddle/fluid/dygraph/layers.py | 135 ++++++++++++ .../test_imperative_hook_for_layer.py | 197 ++++++++++++++++++ 2 files changed, 332 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 58113808a98..94ac86b7741 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -27,6 +27,7 @@ from .base import program_desc_tracing_guard from paddle.fluid import framework from ..param_attr import ParamAttr import copy +import weakref import warnings __all__ = ['Layer'] @@ -40,6 +41,22 @@ def _convert_camel_to_snake(name): return _all_cap_re.sub(r'\1_\2', s1).lower() +class HookRemoveHelper(object): + """ A HookRemoveHelper that can be used to remove hook. """ + + next_hook_id = 0 + + def __init__(self, hooks): + self._hooks_ref = weakref.ref(hooks) + self._hook_id = HookRemoveHelper.next_hook_id + HookRemoveHelper.next_hook_id += 1 + + def remove(self): + hooks = self._hooks_ref() + if hooks is not None and self._hook_id in hooks: + del hooks[self._hook_id] + + class Layer(core.Layer): """Dynamic graph Layer based on OOD, includes the parameters of the layer, the structure of the forward graph and so on. @@ -70,6 +87,9 @@ class Layer(core.Layer): self._sub_layers = collections.OrderedDict() self._loaddict_holder = collections.OrderedDict() + self._forward_pre_hooks = collections.OrderedDict() + self._forward_post_hooks = collections.OrderedDict() + def train(self): framework._dygraph_tracer().train_mode() @@ -84,6 +104,108 @@ class Layer(core.Layer): """ return self._full_name + def register_forward_post_hook(self, hook): + """Register a forward post-hook for Layer. The hook will be called after `forward` function has been computed. + + It should have the following form, `input` and `output` of the `hook` is `input` and `output` of the `Layer` respectively. + User can use forward post-hook to change the output of the Layer or perform information statistics tasks on the Layer. + + hook(Layer, input, output) -> None or modified output + + Parameters: + hook(function): a function registered as a forward post-hook + + Returns: + HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` . + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + # the forward_post_hook change the output of the layer: output = output * 2 + def forward_post_hook(layer, input, output): + # user can use layer, input and output for information statistis tasks + + # change the output + return output * 2 + + with fluid.dygraph.guard(): + linear = fluid.Linear(13, 5, dtype="float32") + + # register the hook + forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook) + + value = np.arange(26).reshape(2, 13).astype("float32") + in = fluid.dygraph.to_variable(value0) + + out0 = linear(in) + + # remove the hook + forward_post_hook_handle.remove() + + out1 = linear(in) + + # hook change the linear's output to output * 2, so out0 is equal to out1 * 2. + assert (out0.numpy() == (out1.numpy()) * 2).any() + """ + hook_remove_helper = HookRemoveHelper(self._forward_post_hooks) + self._forward_post_hooks[hook_remove_helper._hook_id] = hook + return hook_remove_helper + + def register_forward_pre_hook(self, hook): + """Register a forward pre-hook for Layer. The hook will be called before `forward` function has been computed. + + It should have the following form, `input` of the `hook` is `input` of the `Layer`, + hook can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if + a single value is returned(unless that value is already a tuple). + User can use forward pre-hook to change the input of the Layer or perform information statistics tasks on the Layer. + + hook(Layer, input) -> None or modified input + + Parameters: + hook(function): a function registered as a forward pre-hook + + Returns: + HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` . + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + # the forward_post_hook change the input of the layer: input = input * 2 + def forward_pre_hook(layer, input): + # user can use layer and input for information statistis tasks + + # change the input + input_return = (input[0] * 2) + return input_return + + with fluid.dygraph.guard(): + linear = fluid.Linear(13, 5, dtype="float32") + + # register the hook + forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook) + + value0 = np.arange(26).reshape(2, 13).astype("float32") + in0 = fluid.dygraph.to_variable(value0) + out0 = linear(in0) + + # remove the hook + forward_pre_hook_handle.remove() + + value1 = value0 * 2 + in1 = fluid.dygraph.to_variable(value1) + out1 = linear(in1) + + # hook change the linear's input to input * 2, so out0 is equal to out1. + assert (out0.numpy() == out1.numpy()).any() + """ + hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks) + self._forward_pre_hooks[hook_remove_helper._hook_id] = hook + return hook_remove_helper + def create_parameter(self, shape, attr=None, @@ -293,6 +415,13 @@ class Layer(core.Layer): pass def __call__(self, *inputs, **kwargs): + for forward_pre_hook in self._forward_pre_hooks.values(): + hook_result = forward_pre_hook(self, inputs) + if hook_result is not None: + if not isinstance(hook_result, tuple): + hook_result = (hook_result, ) + inputs = hook_result + if not self._built: with program_desc_tracing_guard(False): self._build_once(*inputs, **kwargs) @@ -302,6 +431,12 @@ class Layer(core.Layer): self._built = True outputs = self.forward(*inputs, **kwargs) + + for forward_post_hook in self._forward_post_hooks.values(): + hook_result = forward_post_hook(self, inputs, outputs) + if hook_result is not None: + outputs = hook_result + return outputs def forward(self, *inputs, **kwargs): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py new file mode 100644 index 00000000000..4fe4d963ca5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py @@ -0,0 +1,197 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import contextlib +import unittest +import numpy as np +import six + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.dygraph.base as base + +from test_imperative_lod_tensor_to_selected_rows import SimpleNet + +call_forward_hook = False +call_forward_pre_hook = False + + +def forward_hook(layer, input, output): + global call_forward_hook + call_forward_hook = True + + +def forward_pre_hook(layer, input): + global call_forward_pre_hook + call_forward_pre_hook = True + + +def forward_hook1(layer, input, output): + return output * 2 + + +def forward_pre_hook1(layer, input): + input_return = (input[0] * 2, input[1]) + return input_return + + +class Test_Forward_Hook(unittest.TestCase): + # test forward_pre_hook and forward_hook that have return value + def test_forward_hook_return_value(self): + seed = 90 + + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + for place in places: + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + + input_word = np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, + 8]).reshape(6, 3).astype('int64') + input_word1 = input_word * 2 + input_word = input_word.reshape((-1, 3, 1)) + input_word1 = input_word1.reshape((-1, 3, 1)) + y_data = np.array( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, + 9]).reshape(6, 3).astype('int64') + y_data = y_data.reshape((-1, 1)) + + input = base.to_variable(input_word) + input1 = base.to_variable(input_word1) + y = base.to_variable(y_data) + + simplenet = SimpleNet( + hidden_size=20, + vocab_size=32, + num_steps=3, + init_scale=0.1, + is_sparse=False, + dtype="float32") + + # origin, don't register any hook + outs_origin = simplenet(input, y) + outs_origin1 = simplenet(input1, y) + + # register forward_pre_hook + forward_pre_hook_handle1 = simplenet.register_forward_pre_hook( + forward_pre_hook1) + outs_pre_hook = simplenet(input, y) + self.assertTrue( + np.array_equal(outs_pre_hook.numpy(), outs_origin1.numpy())) + + # remove forward_pre_hook + forward_pre_hook_handle1.remove() + outs_pre_hook = simplenet(input, y) + self.assertTrue( + np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy())) + + # register forward_hook + forward_hook_handle1 = simplenet.register_forward_post_hook( + forward_hook1) + outs_forward_hook = simplenet(input, y) + self.assertTrue( + np.array_equal(outs_forward_hook.numpy(), + outs_origin.numpy() * 2)) + + # remove forward_hook + forward_hook_handle1.remove() + outs_forward_hook = simplenet(input, y) + self.assertTrue( + np.array_equal(outs_forward_hook.numpy(), + outs_origin.numpy())) + + # test forward_pre_hook and forward_hook that don't have return value + def test_forward_hook(self): + seed = 90 + + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + for place in places: + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + + global call_forward_hook + global call_forward_pre_hook + + input_word = np.array( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, + 8]).reshape(6, 3).astype('int64') + input_word = input_word.reshape((-1, 3, 1)) + y_data = np.array( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, + 9]).reshape(6, 3).astype('int64') + y_data = y_data.reshape((-1, 1)) + + input = base.to_variable(input_word) + y = base.to_variable(y_data) + + simplenet = SimpleNet( + hidden_size=20, + vocab_size=32, + num_steps=3, + init_scale=0.1, + is_sparse=False, + dtype="float32") + + # origin, don't register any hook + outs_origin = simplenet(input, y) + self.assertFalse(call_forward_hook) + self.assertFalse(call_forward_pre_hook) + + # register forward_hook and forward_pre_hook + forward_hook_handle = simplenet.register_forward_post_hook( + forward_hook) + forward_pre_hook_handle = simplenet.register_forward_pre_hook( + forward_pre_hook) + outs_hook = simplenet(input, y) + self.assertTrue(call_forward_hook) + self.assertTrue(call_forward_pre_hook) + + outs_hook = simplenet(input, y) + self.assertTrue(call_forward_hook) + self.assertTrue(call_forward_pre_hook) + + # remove forward_hook + forward_hook_handle.remove() + call_forward_hook = False + call_forward_pre_hook = False + outs_remove_forward_hook = simplenet(input, y) + self.assertFalse(call_forward_hook) + self.assertTrue(call_forward_pre_hook) + + # remove forward_pre_hook + forward_pre_hook_handle.remove() + call_forward_hook = False + call_forward_pre_hook = False + outs_remove_hook = simplenet(input, y) + self.assertFalse(call_forward_hook) + self.assertFalse(call_forward_pre_hook) + + +if __name__ == '__main__': + unittest.main() -- GitLab