From 72cbb6da9c8ad45e95713fda75915865dad22047 Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Tue, 28 Feb 2023 15:08:18 +0800 Subject: [PATCH] add static guard (#50971) --- python/paddle/fluid/framework.py | 12 ++++++++++++ python/paddle/fluid/tests/unittests/eager_op_test.py | 6 +++--- python/paddle/fluid/tests/unittests/op_test.py | 6 +++--- python/paddle/fluid/tests/unittests/prim_op_test.py | 4 ++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 55ad7305e9d..6ec5d5ffed7 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -7524,6 +7524,18 @@ def _dygraph_guard(tracer): global_var._dygraph_tracer_ = tmp_tracer +@signature_safe_contextmanager +def _static_guard(): + tmp_tracer = global_var._dygraph_tracer_ + global_var._dygraph_tracer_ = None + try: + yield + finally: + if tmp_tracer is not None: + core._switch_tracer(tmp_tracer) + global_var._dygraph_tracer_ = tmp_tracer + + @signature_safe_contextmanager def _dygraph_place_guard(place): global _global_expected_place_ diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index 46c612cd075..81e70955697 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -920,7 +920,7 @@ class OpTest(unittest.TestCase): enable_inplace=None, for_inplace_test=None, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): program = Program() block = program.global_block() op = self._append_ops(block) @@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase): Returns: res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. """ - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): ( fwd_outs, fwd_fetch_list, @@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs=None, parallel=False, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): prog = Program() scope = core.Scope() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index c31d88ddeda..60f175203c9 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1018,7 +1018,7 @@ class OpTest(unittest.TestCase): enable_inplace=None, for_inplace_test=None, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): program = Program() block = program.global_block() op = self._append_ops(block) @@ -1347,7 +1347,7 @@ class OpTest(unittest.TestCase): Returns: res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. """ - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): ( fwd_outs, fwd_fetch_list, @@ -2666,7 +2666,7 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs=None, parallel=False, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): prog = Program() scope = core.Scope() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/prim_op_test.py b/python/paddle/fluid/tests/unittests/prim_op_test.py index 73986391dc6..c3b1d44bb2c 100644 --- a/python/paddle/fluid/tests/unittests/prim_op_test.py +++ b/python/paddle/fluid/tests/unittests/prim_op_test.py @@ -28,7 +28,7 @@ from paddle.jit.dy2static.utils import parse_arg_and_kwargs def flatten(nest_list): out = [] for i in nest_list: - if isinstance(i, list or tuple): + if isinstance(i, (list, tuple)): tmp_list = flatten(i) for j in tmp_list: out.append(j) @@ -40,7 +40,7 @@ def flatten(nest_list): def _as_list(x): if x is None: return [] - return list(x) if isinstance(x, list or tuple) else [x] + return list(x) if isinstance(x, (list, tuple)) else [x] def convert_uint16_to_float(in_list): -- GitLab