From a909bdf154fa0d44124b73dcdb1c8c4205c83999 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 16 Feb 2022 18:55:42 +0800 Subject: [PATCH] [Eager] Support eager hook_for_layer (#39531) * Update comment * [Eager] Support test_imperative_hook_for_layer with _test_eager_guard() * Polish code name style * Fix a error name * Polish code, make it clear and simple --- python/paddle/fluid/dygraph/layers.py | 2 +- .../test_imperative_hook_for_layer.py | 64 +++++++++++-------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 6a65b3bd9c6..53dbf1a66b2 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -342,7 +342,7 @@ class Layer(object): import paddle import numpy as np - # the forward_post_hook change the input of the layer: input = input * 2 + # the forward_pre_hook change the input of the layer: input = input * 2 def forward_pre_hook(layer, input): # user can use layer and input for information statistis tasks diff --git a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py index 31735368431..4c457e9345c 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_hook_for_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,14 +25,15 @@ import paddle.fluid.core as core import paddle.fluid.dygraph.base as base from test_imperative_lod_tensor_to_selected_rows import SimpleNet +from paddle.fluid.framework import _test_eager_guard -call_forward_hook = False +call_forward_post_hook = False call_forward_pre_hook = False -def forward_hook(layer, input, output): - global call_forward_hook - call_forward_hook = True +def forward_post_hook(layer, input, output): + global call_forward_post_hook + call_forward_post_hook = True def forward_pre_hook(layer, input): @@ -40,7 +41,7 @@ def forward_pre_hook(layer, input): call_forward_pre_hook = True -def forward_hook1(layer, input, output): +def forward_post_hook1(layer, input, output): return output * 2 @@ -50,8 +51,8 @@ def forward_pre_hook1(layer, input): class Test_Forward_Hook(unittest.TestCase): - # test forward_pre_hook and forward_hook that have return value - def test_forward_hook_return_value(self): + # test forward_pre_hook and forward_post_hook that have return value + def func_forward_hook_return_value(self): seed = 90 places = [fluid.CPUPlace()] @@ -104,23 +105,23 @@ class Test_Forward_Hook(unittest.TestCase): self.assertTrue( np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy())) - # register forward_hook - forward_hook_handle1 = simplenet.register_forward_post_hook( - forward_hook1) + # register forward_posst_hook + forward_post_hook_handle1 = simplenet.register_forward_post_hook( + forward_post_hook1) outs_forward_hook = simplenet(input, y) self.assertTrue( np.array_equal(outs_forward_hook.numpy(), outs_origin.numpy() * 2)) - # remove forward_hook - forward_hook_handle1.remove() + # remove forward_post_hook + forward_post_hook_handle1.remove() outs_forward_hook = simplenet(input, y) self.assertTrue( np.array_equal(outs_forward_hook.numpy(), outs_origin.numpy())) - # test forward_pre_hook and forward_hook that don't have return value - def test_forward_hook(self): + # test forward_pre_hook and forward_post_hook that don't have return value + def func_forward_hook(self): seed = 90 places = [fluid.CPUPlace()] @@ -133,7 +134,7 @@ class Test_Forward_Hook(unittest.TestCase): fluid.default_main_program().random_seed = seed fluid.set_flags({'FLAGS_sort_sum_gradient': True}) - global call_forward_hook + global call_forward_post_hook global call_forward_pre_hook input_word = np.array( @@ -158,38 +159,45 @@ class Test_Forward_Hook(unittest.TestCase): # origin, don't register any hook outs_origin = simplenet(input, y) - self.assertFalse(call_forward_hook) + self.assertFalse(call_forward_post_hook) self.assertFalse(call_forward_pre_hook) - # register forward_hook and forward_pre_hook - forward_hook_handle = simplenet.register_forward_post_hook( - forward_hook) + # register forward_post_hook and forward_pre_hook + forward_post_hook_handle = simplenet.register_forward_post_hook( + forward_post_hook) forward_pre_hook_handle = simplenet.register_forward_pre_hook( forward_pre_hook) outs_hook = simplenet(input, y) - self.assertTrue(call_forward_hook) + self.assertTrue(call_forward_post_hook) self.assertTrue(call_forward_pre_hook) outs_hook = simplenet(input, y) - self.assertTrue(call_forward_hook) + self.assertTrue(call_forward_post_hook) self.assertTrue(call_forward_pre_hook) - # remove forward_hook - forward_hook_handle.remove() - call_forward_hook = False + # remove forward_post_hook + forward_post_hook_handle.remove() + call_forward_post_hook = False call_forward_pre_hook = False outs_remove_forward_hook = simplenet(input, y) - self.assertFalse(call_forward_hook) + self.assertFalse(call_forward_post_hook) self.assertTrue(call_forward_pre_hook) # remove forward_pre_hook forward_pre_hook_handle.remove() - call_forward_hook = False + call_forward_post_hook = False call_forward_pre_hook = False outs_remove_hook = simplenet(input, y) - self.assertFalse(call_forward_hook) + self.assertFalse(call_forward_post_hook) self.assertFalse(call_forward_pre_hook) + def test_forward_hook_return_value(self): + with _test_eager_guard(): + self.func_forward_hook() + self.func_forward_hook_return_value() + self.func_forward_hook() + self.func_forward_hook_return_value() + if __name__ == '__main__': unittest.main() -- GitLab