未验证 提交 a909bdf1 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support eager hook_for_layer (#39531)

* Update comment

* [Eager] Support test_imperative_hook_for_layer with _test_eager_guard()

* Polish code name style

* Fix a error name

* Polish code, make it clear and simple
上级 24b8f63e
......@@ -342,7 +342,7 @@ class Layer(object):
import paddle
import numpy as np
# the forward_post_hook change the input of the layer: input = input * 2
# the forward_pre_hook change the input of the layer: input = input * 2
def forward_pre_hook(layer, input):
# user can use layer and input for information statistis tasks
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -25,14 +25,15 @@ import paddle.fluid.core as core
import paddle.fluid.dygraph.base as base
from test_imperative_lod_tensor_to_selected_rows import SimpleNet
from paddle.fluid.framework import _test_eager_guard
call_forward_hook = False
call_forward_post_hook = False
call_forward_pre_hook = False
def forward_hook(layer, input, output):
global call_forward_hook
call_forward_hook = True
def forward_post_hook(layer, input, output):
global call_forward_post_hook
call_forward_post_hook = True
def forward_pre_hook(layer, input):
......@@ -40,7 +41,7 @@ def forward_pre_hook(layer, input):
call_forward_pre_hook = True
def forward_hook1(layer, input, output):
def forward_post_hook1(layer, input, output):
return output * 2
......@@ -50,8 +51,8 @@ def forward_pre_hook1(layer, input):
class Test_Forward_Hook(unittest.TestCase):
# test forward_pre_hook and forward_hook that have return value
def test_forward_hook_return_value(self):
# test forward_pre_hook and forward_post_hook that have return value
def func_forward_hook_return_value(self):
seed = 90
places = [fluid.CPUPlace()]
......@@ -104,23 +105,23 @@ class Test_Forward_Hook(unittest.TestCase):
self.assertTrue(
np.array_equal(outs_pre_hook.numpy(), outs_origin.numpy()))
# register forward_hook
forward_hook_handle1 = simplenet.register_forward_post_hook(
forward_hook1)
# register forward_posst_hook
forward_post_hook_handle1 = simplenet.register_forward_post_hook(
forward_post_hook1)
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy() * 2))
# remove forward_hook
forward_hook_handle1.remove()
# remove forward_post_hook
forward_post_hook_handle1.remove()
outs_forward_hook = simplenet(input, y)
self.assertTrue(
np.array_equal(outs_forward_hook.numpy(),
outs_origin.numpy()))
# test forward_pre_hook and forward_hook that don't have return value
def test_forward_hook(self):
# test forward_pre_hook and forward_post_hook that don't have return value
def func_forward_hook(self):
seed = 90
places = [fluid.CPUPlace()]
......@@ -133,7 +134,7 @@ class Test_Forward_Hook(unittest.TestCase):
fluid.default_main_program().random_seed = seed
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
global call_forward_hook
global call_forward_post_hook
global call_forward_pre_hook
input_word = np.array(
......@@ -158,38 +159,45 @@ class Test_Forward_Hook(unittest.TestCase):
# origin, don't register any hook
outs_origin = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)
# register forward_hook and forward_pre_hook
forward_hook_handle = simplenet.register_forward_post_hook(
forward_hook)
# register forward_post_hook and forward_pre_hook
forward_post_hook_handle = simplenet.register_forward_post_hook(
forward_post_hook)
forward_pre_hook_handle = simplenet.register_forward_pre_hook(
forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)
outs_hook = simplenet(input, y)
self.assertTrue(call_forward_hook)
self.assertTrue(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)
# remove forward_hook
forward_hook_handle.remove()
call_forward_hook = False
# remove forward_post_hook
forward_post_hook_handle.remove()
call_forward_post_hook = False
call_forward_pre_hook = False
outs_remove_forward_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertTrue(call_forward_pre_hook)
# remove forward_pre_hook
forward_pre_hook_handle.remove()
call_forward_hook = False
call_forward_post_hook = False
call_forward_pre_hook = False
outs_remove_hook = simplenet(input, y)
self.assertFalse(call_forward_hook)
self.assertFalse(call_forward_post_hook)
self.assertFalse(call_forward_pre_hook)
def test_forward_hook_return_value(self):
with _test_eager_guard():
self.func_forward_hook()
self.func_forward_hook_return_value()
self.func_forward_hook()
self.func_forward_hook_return_value()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册