未验证 提交 9aad7527 编写于 作者: C Chen Weihang 提交者: GitHub

Add fake interface for register_hook in static mode (#32642)

* add fake interface for hook in static mode

* add unittests

* fix failed unittests
上级 abcb3f54
...@@ -246,11 +246,11 @@ def _static_only_(func): ...@@ -246,11 +246,11 @@ def _static_only_(func):
def _fake_interface_only_(func): def _fake_interface_only_(func):
def __impl__(*args, **kwargs): def __impl__(*args, **kwargs):
raise AssertionError( raise AssertionError(
"'%s' should be called by imperative Varible in imperative mode, please run it in dygraph " "'%s' only can be called by `paddle.Tensor` in dynamic graph mode. Suggestions:\n"
"mode. You can turn off paddle.enable_static() if you are in static mode, or turn off " " 1. If you are in static graph mode, you can switch to dynamic graph mode by turning off `paddle.enable_static()` or calling `paddle.disable_static()`.\n"
"ProgramTranslator if you are using @paddle.jit.to_static. If you have to run ProgramTranslator, " " 2. If you are using `@paddle.jit.to_static`, you can turn off ProgramTranslator by calling `paddle.jit.ProgramTranslator().enable(False)`. "
"please use other API to replace '%s'" % (func.__name__, "If you have to translate dynamic graph to static graph, please use other API to replace '%s'."
func.__name__)) % (func.__name__, func.__name__))
return __impl__ return __impl__
...@@ -1306,6 +1306,10 @@ class Variable(object): ...@@ -1306,6 +1306,10 @@ class Variable(object):
""" """
pass pass
@fake_interface_only
def register_hook(self, hook):
pass
def __str__(self): def __str__(self):
return self._to_readable_code() return self._to_readable_code()
......
...@@ -152,18 +152,8 @@ class Test_Detach(unittest.TestCase): ...@@ -152,18 +152,8 @@ class Test_Detach(unittest.TestCase):
def test_detach_exception(self): def test_detach_exception(self):
x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32') x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32')
y = fluid.layers.fc(input=x, size=10, bias_attr=True) y = fluid.layers.fc(input=x, size=10, bias_attr=True)
try: with self.assertRaises(AssertionError):
y_detach = y.detach() y_detach = y.detach()
except Exception as e:
# Here is to check
assert type(e) == AssertionError
assert str(e) == (
"'detach' should be called by imperative Varible "
"in imperative mode, please run it in dygraph mode. You can "
"turn off paddle.enable_static() if you are in static mode, "
"or turn off ProgramTranslator if you are using "
"@paddle.jit.to_static. If you have to run ProgramTranslator, "
"please use other API to replace 'detach'")
class TestInplace(unittest.TestCase): class TestInplace(unittest.TestCase):
......
...@@ -39,6 +39,21 @@ class SimpleNet(nn.Layer): ...@@ -39,6 +39,21 @@ class SimpleNet(nn.Layer):
return ret1, out return ret1, out
class SimpleNetForStatic(nn.Layer):
def __init__(self, in_size, out_size):
super(SimpleNetForStatic, self).__init__()
self.linear1 = nn.Linear(in_size, in_size)
self.linear2 = nn.Linear(in_size, out_size)
def forward(self, x):
ret1 = self.linear1(x)
ret1.register_hook(lambda grad: grad * 2)
ret2 = self.linear2(ret1)
out = paddle.mean(ret2, axis=-1)
return out
class TestTensorRegisterHook(unittest.TestCase): class TestTensorRegisterHook(unittest.TestCase):
def setUp(self): def setUp(self):
self.seed = 2021 self.seed = 2021
...@@ -451,6 +466,34 @@ class TestTensorRegisterHook(unittest.TestCase): ...@@ -451,6 +466,34 @@ class TestTensorRegisterHook(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
x.register_hook(lambda grad: grad * 2) x.register_hook(lambda grad: grad * 2)
def test_register_hook_in_static_mode(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.scope_guard(paddle.static.Scope()):
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
name='x', shape=[None, self.in_size], dtype='float32')
net = SimpleNetForStatic(self.in_size, self.out_size)
with self.assertRaises(AssertionError):
out = net(x)
paddle.disable_static()
def test_register_hook_in_dy2static_mode(self):
net = SimpleNetForStatic(self.in_size, self.out_size)
jit_net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec([None, self.in_size])])
data = np.random.uniform(
size=[self.batch_size, self.in_size]).astype('float32')
data_t = paddle.to_tensor(data)
with self.assertRaises(AssertionError):
out = jit_net(data_t)
HOOK_INIT_VALUE = 10 HOOK_INIT_VALUE = 10
HOOK_IS_CALLED = False HOOK_IS_CALLED = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册