未验证 提交 a909bdf1 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support eager hook_for_layer (#39531)

* Update comment

* [Eager] Support test_imperative_hook_for_layer with _test_eager_guard()

* Polish code name style

* Fix a error name

* Polish code, make it clear and simple
上级 24b8f63e
...@@ -342,7 +342,7 @@ class Layer(object): ...@@ -342,7 +342,7 @@ class Layer(object):
import paddle import paddle
import numpy as np import numpy as np
# the forward_post_hook change the input of the layer: input = input * 2 # the forward_pre_hook change the input of the layer: input = input * 2
def forward_pre_hook(layer, input): def forward_pre_hook(layer, input):
# user can use layer and input for information statistis tasks # user can use layer and input for information statistis tasks
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -25,14 +25,15 @@ import paddle.fluid.core as core ...@@ -25,14 +25,15 @@ import paddle.fluid.core as core
import paddle.fluid.dygraph.base as base import paddle.fluid.dygraph.base as base
from test_imperative_lod_tensor_to_selected_rows import SimpleNet from test_imperative_lod_tensor_to_selected_rows import SimpleNet
from paddle.fluid.framework import _test_eager_guard
call_forward_hook = False call_forward_post_hook = False
call_forward_pre_hook = False call_forward_pre_hook = False
def forward_hook(layer, input, output): def forward_post_hook(layer, input, output):
global call_forward_hook global call_forward_post_hook
call_forward_hook = True call_forward_post_hook = True
def forward_pre_hook(layer, input): def forward_pre_hook(layer, input):
...@@ -40,7 +41,7 @@ def forward_pre_hook(layer, input): ...@@ -40,7 +41,7 @@ def forward_pre_hook(layer, input):
call_forward_pre_hook = True call_forward_pre_hook = True
def forward_hook1(layer, input, output): def forward_post_hook1(layer, input, output):
return output * 2 return output * 2
...@@ -50,8 +51,8 @@ def forward_pre_hook1(layer, input): ...@@ -50,8 +51,8 @@ def forward_pre_hook1(layer, input):
class Test_Forward_Hook(unittest.TestCase): class Test_Forward_Hook(unittest.TestCase):
# test forward_pre_hook and forward_hook that have return value # test forward_pre_hook and forward_post_hook that have return value
def test_forward_hook_return_value(self): def func_forward_hook_return_value(self):
seed = 90 seed = 90
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
...@@ -104,23 +105,23 @@ class Test_Forward_Hook(unittest.TestCase): ...@@ -104,23 +105,23 @@ class Test_Forward_Hook(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy())) np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))
# register forward_hook # register forward_posst_hook
forward_hook_handle1 = simplenet.register_forward_post_hook( forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_hook1) forward_post_hook1)
outs_forward_hook = simplenet(input, y) outs_forward_hook = simplenet(input, y)
self.assertTrue( self.assertTrue(
np.array_equal(outs_forward_hook.numpy(), np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2)) outs_origin.numpy() * 2))
# remove forward_hook # remove forward_post_hook
forward_hook_handle1.remove() forward_post_hook_handle1.remove()
outs_forward_hook = simplenet(input, y) outs_forward_hook = simplenet(input, y)
self.assertTrue( self.assertTrue(
np.array_equal(outs_forward_hook.numpy(), np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy())) outs_origin.numpy()))
# test forward_pre_hook and forward_hook that don't have return value # test forward_pre_hook and forward_post_hook that don't have return value
def test_forward_hook(self): def func_forward_hook(self):
seed = 90 seed = 90
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
...@@ -133,7 +134,7 @@ class Test_Forward_Hook(unittest.TestCase): ...@@ -133,7 +134,7 @@ class Test_Forward_Hook(unittest.TestCase):
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True}) fluid.set_flags({'FLAGS_sort_sum_gradient': True})
global call_forward_hook global call_forward_post_hook
global call_forward_pre_hook global call_forward_pre_hook
input_word = np.array( input_word = np.array(
...@@ -158,38 +159,45 @@ class Test_Forward_Hook(unittest.TestCase): ...@@ -158,38 +159,45 @@ class Test_Forward_Hook(unittest.TestCase):
# origin, don't register any hook # origin, don't register any hook
outs_origin = simplenet(input, y) outs_origin = simplenet(input, y)
self.assertFalse(call_forward_hook) self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook) self.assertFalse(call_forward_pre_hook)
# register forward_hook and forward_pre_hook # register forward_post_hook and forward_pre_hook
forward_hook_handle = simplenet.register_forward_post_hook( forward_post_hook_handle = simplenet.register_forward_post_hook(
forward_hook) forward_post_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook( forward_pre_hook_handle = simplenet.register_forward_pre_hook(
forward_pre_hook) forward_pre_hook)
outs_hook = simplenet(input, y) outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook) self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook) self.assertTrue(call_forward_pre_hook)
outs_hook = simplenet(input, y) outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook) self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook) self.assertTrue(call_forward_pre_hook)
# remove forward_hook # remove forward_post_hook
forward_hook_handle.remove() forward_post_hook_handle.remove()
call_forward_hook = False call_forward_post_hook = False
call_forward_pre_hook = False call_forward_pre_hook = False
outs_remove_forward_hook = simplenet(input, y) outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_hook) self.assertFalse(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook) self.assertTrue(call_forward_pre_hook)
# remove forward_pre_hook # remove forward_pre_hook
forward_pre_hook_handle.remove() forward_pre_hook_handle.remove()
call_forward_hook = False call_forward_post_hook = False
call_forward_pre_hook = False call_forward_pre_hook = False
outs_remove_hook = simplenet(input, y) outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_hook) self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook) self.assertFalse(call_forward_pre_hook)
def test_forward_hook_return_value(self):
with _test_eager_guard():
self.func_forward_hook()
self.func_forward_hook_return_value()
self.func_forward_hook()
self.func_forward_hook_return_value()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册