未验证 提交 76103c88 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Added Eager Dygraph support for user_defined_grads (#39309)

上级 75923a32
...@@ -30,6 +30,7 @@ from copy import copy ...@@ -30,6 +30,7 @@ from copy import copy
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import _in_eager_mode
from paddle.fluid.framework import _test_eager_guard from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
...@@ -1831,6 +1832,16 @@ class OpTest(unittest.TestCase): ...@@ -1831,6 +1832,16 @@ class OpTest(unittest.TestCase):
for no_grad_val in no_grad_set: for no_grad_val in no_grad_set:
del (inputs[no_grad_val]) del (inputs[no_grad_val])
if _in_eager_mode():
core.eager.run_backward(
fluid.layers.utils.flatten(outputs), grad_outputs,
False)
grad_inputs = []
for inputs_list in inputs.values():
for inp in inputs_list:
grad_inputs.append(inp.grad.numpy())
return grad_inputs
else:
grad_inputs = paddle.grad( grad_inputs = paddle.grad(
outputs=fluid.layers.utils.flatten(outputs), outputs=fluid.layers.utils.flatten(outputs),
inputs=fluid.layers.utils.flatten(inputs), inputs=fluid.layers.utils.flatten(inputs),
......
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
class TestDiagV2Op(OpTest): class TestDiagV2Op(OpTest):
...@@ -239,6 +240,9 @@ class TestDiagV2API(unittest.TestCase): ...@@ -239,6 +240,9 @@ class TestDiagV2API(unittest.TestCase):
def test_cpu(self): def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace()) paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative() self.run_imperative()
with _test_eager_guard():
self.run_imperative()
paddle.enable_static() paddle.enable_static()
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
...@@ -250,6 +254,8 @@ class TestDiagV2API(unittest.TestCase): ...@@ -250,6 +254,8 @@ class TestDiagV2API(unittest.TestCase):
paddle.disable_static(place=paddle.fluid.CUDAPlace(0)) paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative() self.run_imperative()
with _test_eager_guard():
self.run_imperative()
paddle.enable_static() paddle.enable_static()
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
......
...@@ -22,6 +22,7 @@ import paddle.nn.functional as F ...@@ -22,6 +22,7 @@ import paddle.nn.functional as F
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.tensor as tensor import paddle.tensor as tensor
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static() paddle.enable_static()
...@@ -33,10 +34,10 @@ class TestDiagonalOp(OpTest): ...@@ -33,10 +34,10 @@ class TestDiagonalOp(OpTest):
self.outputs = {'Out': self.target} self.outputs = {'Out': self.target}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['Input'], 'Out') self.check_grad(['Input'], 'Out', check_eager=True)
def init_config(self): def init_config(self):
self.case = np.random.randn(10, 5, 2).astype('float64') self.case = np.random.randn(10, 5, 2).astype('float64')
...@@ -79,7 +80,8 @@ class TestDiagonalOpCase2(TestDiagonalOp): ...@@ -79,7 +80,8 @@ class TestDiagonalOpCase2(TestDiagonalOp):
['Input'], ['Input'],
'Out', 'Out',
user_defined_grads=[self.grad_x], user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out]) user_defined_grad_outputs=[self.grad_out],
check_eager=True)
class TestDiagonalOpCase3(TestDiagonalOp): class TestDiagonalOpCase3(TestDiagonalOp):
...@@ -122,6 +124,10 @@ class TestDiagonalAPI(unittest.TestCase): ...@@ -122,6 +124,10 @@ class TestDiagonalAPI(unittest.TestCase):
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
paddle.enable_static() paddle.enable_static()
def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_api_dygraph()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,6 +20,7 @@ import paddle ...@@ -20,6 +20,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.static as static import paddle.static as static
from op_test import OpTest from op_test import OpTest
from paddle.fluid.framework import _test_eager_guard
class TestDigammaOp(OpTest): class TestDigammaOp(OpTest):
...@@ -94,6 +95,10 @@ class TestDigammaAPI(unittest.TestCase): ...@@ -94,6 +95,10 @@ class TestDigammaAPI(unittest.TestCase):
res = paddle.digamma(input_t).numpy() res = paddle.digamma(input_t).numpy()
self.assertEqual(np.allclose(res, sc_res, rtol=1e-05), True) self.assertEqual(np.allclose(res, sc_res, rtol=1e-05), True)
def test_in_eager_dynamic_mode(self):
with _test_eager_guard():
self.test_in_dynamic_mode()
def test_name_argument(self): def test_name_argument(self):
with static.program_guard(static.Program()): with static.program_guard(static.Program()):
x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0]) x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0])
...@@ -114,6 +119,13 @@ class TestDigammaAPI(unittest.TestCase): ...@@ -114,6 +119,13 @@ class TestDigammaAPI(unittest.TestCase):
input_t = paddle.to_tensor(input) input_t = paddle.to_tensor(input)
res = paddle.digamma(input_t) res = paddle.digamma(input_t)
with self.assertRaises(RuntimeError):
with fluid.dygraph.guard():
with _test_eager_guard():
input = np.random.random(self._shape).astype("int32")
input_t = paddle.to_tensor(input)
res = paddle.digamma(input_t)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static() paddle.enable_static()
...@@ -78,6 +79,10 @@ class TestTruncAPI(unittest.TestCase): ...@@ -78,6 +79,10 @@ class TestTruncAPI(unittest.TestCase):
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True)
paddle.enable_static() paddle.enable_static()
def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_api_dygraph()
def test_errors(self): def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [20, 20], 'bool') x = paddle.fluid.data('X', [20, 20], 'bool')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册