From 922e076e5d1cf925c7cd7f47478af4b788010094 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 30 Mar 2022 12:07:44 +0800 Subject: [PATCH] [Eager] Fix legacy always make sense (#41048) --- python/paddle/fluid/tests/unittests/op_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 4368ef69f4..ae74fbd1c1 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -52,6 +52,11 @@ from paddle.fluid.tests.unittests.white_list import ( no_grad_set_white_list, ) from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs +# For switch new eager mode globally +g_is_in_eager = _in_eager_without_dygraph_check() +g_enable_legacy_dygraph = _enable_legacy_dygraph if g_is_in_eager else lambda: None +g_disable_legacy_dygraph = _disable_legacy_dygraph if g_is_in_eager else lambda: None + def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs): """ @@ -1583,9 +1588,14 @@ class OpTest(unittest.TestCase): static_checker.check() outs, fetch_list = static_checker.outputs, static_checker.fetch_list if check_dygraph: + # always enable legacy dygraph + g_enable_legacy_dygraph() + dygraph_checker = DygraphChecker(self, self.outputs) dygraph_checker.check() dygraph_outs = dygraph_checker.outputs + # yield the original state + g_disable_legacy_dygraph() if check_eager: eager_checker = EagerChecker(self, self.outputs) eager_checker.check() -- GitLab