diff --git a/python/paddle/fluid/tests/unittests/test_imperative_gnn.py b/python/paddle/fluid/tests/unittests/test_imperative_gnn.py index c813aeede6fe4555ececea9f7e00a479226f2d27..a5a90461551ff868b6d15a5fa3b9de2850cb460a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_gnn.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_gnn.py @@ -23,6 +23,7 @@ import paddle.fluid.core as core from paddle.fluid.optimizer import AdamOptimizer from test_imperative_base import new_program_scope from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.framework import _test_eager_guard def gen_data(): @@ -60,7 +61,7 @@ class GCN(fluid.Layer): class TestDygraphGNN(unittest.TestCase): - def test_gnn_float32(self): + def func_gnn_float32(self): paddle.seed(90) paddle.framework.random._manual_program_seed(90) startup = fluid.Program() @@ -168,6 +169,11 @@ class TestDygraphGNN(unittest.TestCase): self.assertTrue(np.allclose(static_weight, model2_gc_weight_value)) sys.stderr.write('%s %s\n' % (static_loss, loss_value)) + def test_gnn_float32(self): + with _test_eager_guard(): + self.func_gnn_float32() + self.func_gnn_float32() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py index e5453eed136c20f4effe5a8b81292ca9a37f4929..f659d8343543310493b979334f32a2c9f5e19c8d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn_sorted_gradient.py @@ -26,10 +26,11 @@ from test_imperative_base import new_program_scope from test_imperative_ptb_rnn import PtbModel import numpy as np import six +from paddle.fluid.framework import _test_eager_guard class TestDygraphPtbRnnSortGradient(unittest.TestCase): - def test_ptb_rnn_sort_gradient(self): + def func_ptb_rnn_sort_gradient(self): for is_sparse in [True, False]: self.ptb_rnn_sort_gradient_cpu_float32(is_sparse) @@ -171,6 +172,11 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase): for key, value in six.iteritems(static_param_updated): self.assertTrue(np.array_equal(value, dy_param_updated[key])) + def test_ptb_rnn_sort_gradient(self): + with _test_eager_guard(): + self.func_ptb_rnn_sort_gradient() + self.func_ptb_rnn_sort_gradient() + if __name__ == '__main__': unittest.main()