diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 4368ef69f4a2b253de3b22c9530d0c1b49119263..ae74fbd1c1e0956332d83baec998d54a2ac465da 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -52,6 +52,11 @@ from paddle.fluid.tests.unittests.white_list import ( no_grad_set_white_list, ) from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs +# For switch new eager mode globally +g_is_in_eager = _in_eager_without_dygraph_check() +g_enable_legacy_dygraph = _enable_legacy_dygraph if g_is_in_eager else lambda: None +g_disable_legacy_dygraph = _disable_legacy_dygraph if g_is_in_eager else lambda: None + def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs): """ @@ -1583,9 +1588,14 @@ class OpTest(unittest.TestCase): static_checker.check() outs, fetch_list = static_checker.outputs, static_checker.fetch_list if check_dygraph: + # always enable legacy dygraph + g_enable_legacy_dygraph() + dygraph_checker = DygraphChecker(self, self.outputs) dygraph_checker.check() dygraph_outs = dygraph_checker.outputs + # yield the original state + g_disable_legacy_dygraph() if check_eager: eager_checker = EagerChecker(self, self.outputs) eager_checker.check()