未验证 提交 76103c88 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Added Eager Dygraph support for user_defined_grads (#39309)

上级 75923a32
......@@ -30,6 +30,7 @@ from copy import copy
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import _in_eager_mode
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.op import Operator
......@@ -1831,6 +1832,16 @@ class OpTest(unittest.TestCase):
for no_grad_val in no_grad_set:
del (inputs[no_grad_val])
if _in_eager_mode():
core.eager.run_backward(
fluid.layers.utils.flatten(outputs), grad_outputs,
False)
grad_inputs = []
for inputs_list in inputs.values():
for inp in inputs_list:
grad_inputs.append(inp.grad.numpy())
return grad_inputs
else:
grad_inputs = paddle.grad(
outputs=fluid.layers.utils.flatten(outputs),
inputs=fluid.layers.utils.flatten(inputs),
......
......@@ -21,6 +21,7 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
class TestDiagV2Op(OpTest):
......@@ -239,6 +240,9 @@ class TestDiagV2API(unittest.TestCase):
def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative()
with _test_eager_guard():
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
......@@ -250,6 +254,8 @@ class TestDiagV2API(unittest.TestCase):
paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative()
with _test_eager_guard():
self.run_imperative()
paddle.enable_static()
with fluid.program_guard(fluid.Program()):
......
......@@ -22,6 +22,7 @@ import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.tensor as tensor
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static()
......@@ -33,10 +34,10 @@ class TestDiagonalOp(OpTest):
self.outputs = {'Out': self.target}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['Input'], 'Out')
self.check_grad(['Input'], 'Out', check_eager=True)
def init_config(self):
self.case = np.random.randn(10, 5, 2).astype('float64')
......@@ -79,7 +80,8 @@ class TestDiagonalOpCase2(TestDiagonalOp):
['Input'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
user_defined_grad_outputs=[self.grad_out],
check_eager=True)
class TestDiagonalOpCase3(TestDiagonalOp):
......@@ -122,6 +124,10 @@ class TestDiagonalAPI(unittest.TestCase):
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
paddle.enable_static()
def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_api_dygraph()
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,7 @@ import paddle
import paddle.fluid as fluid
import paddle.static as static
from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard
class TestDigammaOp(OpTest):
......@@ -94,6 +95,10 @@ class TestDigammaAPI(unittest.TestCase):
res = paddle.digamma(input_t).numpy()
self.assertEqual(np.allclose(res, sc_res, rtol=1e-05), True)
def test_in_eager_dynamic_mode(self):
with _test_eager_guard():
self.test_in_dynamic_mode()
def test_name_argument(self):
with static.program_guard(static.Program()):
x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0])
......@@ -114,6 +119,13 @@ class TestDigammaAPI(unittest.TestCase):
input_t = paddle.to_tensor(input)
res = paddle.digamma(input_t)
with self.assertRaises(RuntimeError):
with fluid.dygraph.guard():
with _test_eager_guard():
input = np.random.random(self._shape).astype("int32")
input_t = paddle.to_tensor(input)
res = paddle.digamma(input_t)
if __name__ == "__main__":
unittest.main()
......@@ -21,6 +21,7 @@ import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static()
......@@ -78,6 +79,10 @@ class TestTruncAPI(unittest.TestCase):
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
paddle.enable_static()
def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_api_dygraph()
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [20, 20], 'bool')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册