未验证 提交 facda828 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager grad] Refactor partial grad logic (#40693)

* Refactor partial_grad/backward logic

* Add DuplicateCheck and polish code

* Refactor partial_grad/backward more clearly

* Refactor GeneralGrad by SingleInstance
上级 cc853e95
此差异已折叠。
...@@ -116,6 +116,54 @@ class TestEagerGrad(TestCase): ...@@ -116,6 +116,54 @@ class TestEagerGrad(TestCase):
self.func_simple_example_eager_grad_not_allow_unused() self.func_simple_example_eager_grad_not_allow_unused()
self.func_simple_example_eager_grad_not_allow_unused() self.func_simple_example_eager_grad_not_allow_unused()
def func_simple_example_eager_grad_duplicate_input(self):
np.random.seed(2021)
paddle.set_device('cpu')
np_x = np.random.random((3, 3))
np_y = np.random.random((3, 1))
np_z = np.random.random((3, 1))
x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
out_z = paddle.nn.functional.sigmoid(z)
out = paddle.matmul(x, y)
try:
# duplicate input will arise RuntimeError errors
dx = fluid.dygraph.grad(out, [x, x])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("duplicate") > 0
def test_simple_example_eager_grad_duplicate_input(self):
with _test_eager_guard():
self.func_simple_example_eager_grad_duplicate_input()
self.func_simple_example_eager_grad_duplicate_input()
def func_simple_example_eager_grad_duplicate_output(self):
np.random.seed(2021)
paddle.set_device('cpu')
np_x = np.random.random((3, 3))
np_y = np.random.random((3, 1))
np_z = np.random.random((3, 1))
x = paddle.to_tensor(np_x, dtype="float64", stop_gradient=False)
y = paddle.to_tensor(np_y, dtype="float64", stop_gradient=False)
z = paddle.to_tensor(np_z, dtype="float64", stop_gradient=False)
out_z = paddle.nn.functional.sigmoid(z)
out = paddle.matmul(x, y)
try:
# duplicate output will arise RuntimeError errors
dx = fluid.dygraph.grad([out, out], [x])
except RuntimeError as e:
error_msg = cpt.get_exception_message(e)
assert error_msg.find("duplicate") > 0
def test_simple_example_eager_grad_duplicate_output(self):
with _test_eager_guard():
self.func_simple_example_eager_grad_duplicate_output()
self.func_simple_example_eager_grad_duplicate_output()
class TestDygraphDoubleGrad(TestCase): class TestDygraphDoubleGrad(TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册