From dbcf879758db039d68b5c6018b9229f4548e8702 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 2 Mar 2022 10:17:29 +0800 Subject: [PATCH] [Eager] Support gnn ptb_rnn in eager mode (#39993) --- .../paddle/fluid/tests/unittests/test_imperative_gnn.py | 8 +++++++- .../unittests/test_imperative_ptb_rnn_sorted_gradient.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gnn.py b/python/paddle/fluid/tests/unittests/test_imperative_gnn.py index c813aeede6f..a5a90461551 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gnn.py @@ -23,6 +23,7 @@ import paddle.fluid.core as core from paddle.fluid.optimizer import AdamOptimizer from test_imperative_base import new_program_scope from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.framework import _test_eager_guard def gen_data(): @@ -60,7 +61,7 @@ class GCN(fluid.Layer): class TestDygraphGNN(unittest.TestCase): - def test_gnn_float32(self): + def func_gnn_float32(self): paddle.seed(90) paddle.framework.random._manual_program_seed(90) startup = fluid.Program() @@ -168,6 +169,11 @@ class TestDygraphGNN(unittest.TestCase): self.assertTrue(np.allclose(static_weight, model2_gc_weight_value)) sys.stderr.write('%s %s\n' % (static_loss, loss_value)) + def test_gnn_float32(self): + with _test_eager_guard(): + self.func_gnn_float32() + self.func_gnn_float32() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py index e5453eed136..f659d834354 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py @@ -26,10 +26,11 @@ from test_imperative_base import new_program_scope from test_imperative_ptb_rnn import PtbModel import numpy as np import six +from paddle.fluid.framework import _test_eager_guard class TestDygraphPtbRnnSortGradient(unittest.TestCase): - def test_ptb_rnn_sort_gradient(self): + def func_ptb_rnn_sort_gradient(self): for is_sparse in [True, False]: self.ptb_rnn_sort_gradient_cpu_float32(is_sparse) @@ -171,6 +172,11 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase): for key, value in six.iteritems(static_param_updated): self.assertTrue(np.array_equal(value, dy_param_updated[key])) + def test_ptb_rnn_sort_gradient(self): + with _test_eager_guard(): + self.func_ptb_rnn_sort_gradient() + self.func_ptb_rnn_sort_gradient() + if __name__ == '__main__': unittest.main() -- GitLab