diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py
index 010c8aeccacd6550a5b8963c63f3d28af898550e..531c89fb19ec6a5817a940e968345bcb33c71e00 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py
@@ -21,7 +21,7 @@ from paddle.fluid import Embedding, LayerNorm, Linear, Layer
 from paddle.fluid.dygraph import to_variable, guard
 from paddle.fluid.dygraph import TracedLayer
 from test_imperative_base import new_program_scope
-from paddle.fluid.framework import _test_eager_guard
+from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode, _in_legacy_dygraph
 from paddle.fluid import core
 import numpy as np
 import six
@@ -1041,8 +1041,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
 
         with guard():
             fluid.set_flags({'FLAGS_sort_sum_gradient': True})
-            dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
-                dy_param_init, dy_param_updated = run_dygraph()
+            if _in_legacy_dygraph():
+                dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
+                    dy_param_init, dy_param_updated = run_dygraph()
 
         with new_program_scope():
             paddle.seed(seed)
@@ -1116,21 +1117,22 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
 
             for k in range(4, len(out)):
                 static_param_updated[static_param_name_list[k - 4]] = out[k]
-
-        self.assertTrue(
-            np.array_equal(static_avg_cost_value, dy_avg_cost_value))
-        self.assertTrue(
-            np.array_equal(static_sum_cost_value, dy_sum_cost_value))
-        self.assertTrue(np.array_equal(static_predict_value, dy_predict_value))
-        self.assertTrue(
-            np.array_equal(static_token_num_value, dy_token_num_value))
-
-        for key, value in six.iteritems(static_param_init):
-            self.assertTrue(np.array_equal(value, dy_param_init[key]))
-        for key, value in six.iteritems(static_param_updated):
-            self.assertTrue(np.array_equal(value, dy_param_updated[key]))
-
-        # check eager result
+        if _in_legacy_dygraph():
+            self.assertTrue(
+                np.array_equal(static_avg_cost_value, dy_avg_cost_value))
+            self.assertTrue(
+                np.array_equal(static_sum_cost_value, dy_sum_cost_value))
+            self.assertTrue(
+                np.array_equal(static_predict_value, dy_predict_value))
+            self.assertTrue(
+                np.array_equal(static_token_num_value, dy_token_num_value))
+
+            for key, value in six.iteritems(static_param_init):
+                self.assertTrue(np.array_equal(value, dy_param_init[key]))
+            for key, value in six.iteritems(static_param_updated):
+                self.assertTrue(np.array_equal(value, dy_param_updated[key]))
+
+        # compare eager result with imperative result
         with guard():
             fluid.set_flags({'FLAGS_sort_sum_gradient': False})
             dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
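
Note: below is a minimal sketch (not part of the patch) of the gating pattern the hunks above introduce, using the in_dygraph_mode and _in_legacy_dygraph helpers imported from paddle.fluid.framework. The names maybe_compare_with_static, run_dygraph, and check_against_static are hypothetical placeholders, not functions from this test; the idea is only that the legacy-dygraph comparison against the static-graph baseline runs when _in_legacy_dygraph() is true, while the eager-mode pass skips it and is verified in its own block.

# Sketch only, not part of the patch; placeholder callables are assumed.
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode


def maybe_compare_with_static(run_dygraph, check_against_static):
    # run_dygraph / check_against_static: hypothetical stand-ins for the
    # test's own imperative run and its static-vs-dygraph assertions.
    if _in_legacy_dygraph():
        # Legacy dygraph: run the imperative pass and compare its outputs
        # with the static-graph baseline, as the test did before the patch.
        dy_results = run_dygraph()
        check_against_static(dy_results)
    elif in_dygraph_mode():
        # Eager mode: skip the legacy comparison; the eager path is
        # exercised separately in the test.
        pass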