未验证 提交 3b8bcd5a 编写于 作者: W Weilong Wu 提交者: GitHub

Update unit tests by using _test_eager_guard (#40760)

上级 aad0ae2a
......@@ -17,10 +17,11 @@ from __future__ import print_function
import unittest
import paddle.fluid as fluid
import numpy as np
from paddle.fluid.framework import _test_eager_guard
class TestImperativePartitialBackward(unittest.TestCase):
def test_partitial_backward(self):
def func_partitial_backward(self):
with fluid.dygraph.guard():
x = np.random.randn(2, 4, 5).astype("float32")
x = fluid.dygraph.to_variable(x)
......@@ -49,6 +50,11 @@ class TestImperativePartitialBackward(unittest.TestCase):
linear1.clear_gradients()
linear2.clear_gradients()
def test_partitial_backward(self):
with _test_eager_guard():
self.func_partitial_backward()
self.func_partitial_backward()
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,8 @@ import numpy as np
import paddle
import paddle.nn as nn
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode
import paddle.fluid as fluid
import paddle.fluid.core as core
class SimpleNet(nn.Layer):
......@@ -445,8 +447,7 @@ class TestTensorRegisterHook(unittest.TestCase):
self.func_multiple_hooks_for_interior_var()
self.func_multiple_hooks_for_interior_var()
# TODO(wuweilong): enable this case when DoubleGrad in eager mode is ready
def test_hook_in_double_grad(self):
def func_hook_in_double_grad(self):
def double_print_hook(grad):
grad = grad * 2
print(grad)
......@@ -461,10 +462,11 @@ class TestTensorRegisterHook(unittest.TestCase):
x.register_hook(double_print_hook)
y = x * x
fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': False})
# Since y = x * x, dx = 2 * x
dx = paddle.grad(
outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0]
fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': True})
z = y + dx
self.assertTrue(x.grad is None)
......@@ -475,9 +477,18 @@ class TestTensorRegisterHook(unittest.TestCase):
# x.gradient() = 2 * x + 2 = 4.0
# after changed by hook: 8.0
# TODO(wuweilong): enable this case when DoubleGrad in eager mode is ready
if core._in_eager_mode():
pass
else:
z.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), np.array([8.])))
def test_hook_in_double_grad(self):
with _test_eager_guard():
self.func_hook_in_double_grad()
self.func_hook_in_double_grad()
def func_remove_one_hook_multiple_times(self):
for device in self.devices:
paddle.set_device(device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册