From ef7b6d557ec397bc96219a9b4345b240f3918d4c Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 29 Apr 2021 14:01:36 +0800 Subject: [PATCH] Add fake interface for register_hook in static mode (#32642) (#32660) * add fake interface for hook in static mode * add unittests * fix failed unittests --- python/paddle/fluid/framework.py | 14 +++--- .../fluid/tests/unittests/test_detach.py | 12 +----- .../unittests/test_tensor_register_hook.py | 43 +++++++++++++++++++ 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index a280667d03d..0e9d756848a 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -246,11 +246,11 @@ def _static_only_(func): def _fake_interface_only_(func): def __impl__(*args, **kwargs): raise AssertionError( - "'%s' should be called by imperative Varible in imperative mode, please run it in dygraph " - "mode. You can turn off paddle.enable_static() if you are in static mode, or turn off " - "ProgramTranslator if you are using @paddle.jit.to_static. If you have to run ProgramTranslator, " - "please use other API to replace '%s'" % (func.__name__, - func.__name__)) + "'%s' only can be called by `paddle.Tensor` in dynamic graph mode. Suggestions:\n" + " 1. If you are in static graph mode, you can switch to dynamic graph mode by turning off `paddle.enable_static()` or calling `paddle.disable_static()`.\n" + " 2. If you are using `@paddle.jit.to_static`, you can turn off ProgramTranslator by calling `paddle.jit.ProgramTranslator().enable(False)`. " + "If you have to translate dynamic graph to static graph, please use other API to replace '%s'." + % (func.__name__, func.__name__)) return __impl__ @@ -1306,6 +1306,10 @@ class Variable(object): """ pass + @fake_interface_only + def register_hook(self, hook): + pass + def __str__(self): return self._to_readable_code() diff --git a/python/paddle/fluid/tests/unittests/test_detach.py b/python/paddle/fluid/tests/unittests/test_detach.py index 38cdd9b727f..5a31418205c 100644 --- a/python/paddle/fluid/tests/unittests/test_detach.py +++ b/python/paddle/fluid/tests/unittests/test_detach.py @@ -152,18 +152,8 @@ class Test_Detach(unittest.TestCase): def test_detach_exception(self): x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32') y = fluid.layers.fc(input=x, size=10, bias_attr=True) - try: + with self.assertRaises(AssertionError): y_detach = y.detach() - except Exception as e: - # Here is to check - assert type(e) == AssertionError - assert str(e) == ( - "'detach' should be called by imperative Varible " - "in imperative mode, please run it in dygraph mode. You can " - "turn off paddle.enable_static() if you are in static mode, " - "or turn off ProgramTranslator if you are using " - "@paddle.jit.to_static. If you have to run ProgramTranslator, " - "please use other API to replace 'detach'") class TestInplace(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py index a03e4ae4bd9..52256766fed 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_register_hook.py @@ -39,6 +39,21 @@ class SimpleNet(nn.Layer): return ret1, out +class SimpleNetForStatic(nn.Layer): + def __init__(self, in_size, out_size): + super(SimpleNetForStatic, self).__init__() + self.linear1 = nn.Linear(in_size, in_size) + self.linear2 = nn.Linear(in_size, out_size) + + def forward(self, x): + ret1 = self.linear1(x) + ret1.register_hook(lambda grad: grad * 2) + + ret2 = self.linear2(ret1) + out = paddle.mean(ret2, axis=-1) + return out + + class TestTensorRegisterHook(unittest.TestCase): def setUp(self): self.seed = 2021 @@ -451,6 +466,34 @@ class TestTensorRegisterHook(unittest.TestCase): with self.assertRaises(RuntimeError): x.register_hook(lambda grad: grad * 2) + def test_register_hook_in_static_mode(self): + paddle.enable_static() + + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.scope_guard(paddle.static.Scope()): + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + name='x', shape=[None, self.in_size], dtype='float32') + + net = SimpleNetForStatic(self.in_size, self.out_size) + with self.assertRaises(AssertionError): + out = net(x) + + paddle.disable_static() + + def test_register_hook_in_dy2static_mode(self): + net = SimpleNetForStatic(self.in_size, self.out_size) + jit_net = paddle.jit.to_static( + net, input_spec=[paddle.static.InputSpec([None, self.in_size])]) + + data = np.random.uniform( + size=[self.batch_size, self.in_size]).astype('float32') + data_t = paddle.to_tensor(data) + + with self.assertRaises(AssertionError): + out = jit_net(data_t) + HOOK_INIT_VALUE = 10 HOOK_IS_CALLED = False -- GitLab