diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 55ad7305e9d279c95503766251817af37c36d8f3..6ec5d5ffed79c14ebd0c629af1195cea098422b1 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -7524,6 +7524,18 @@ def _dygraph_guard(tracer): global_var._dygraph_tracer_ = tmp_tracer +@signature_safe_contextmanager +def _static_guard(): + tmp_tracer = global_var._dygraph_tracer_ + global_var._dygraph_tracer_ = None + try: + yield + finally: + if tmp_tracer is not None: + core._switch_tracer(tmp_tracer) + global_var._dygraph_tracer_ = tmp_tracer + + @signature_safe_contextmanager def _dygraph_place_guard(place): global _global_expected_place_ diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index 46c612cd07584db217f2bde51b3c184f99e2fdbe..81e709556974d8a7f1dbc26cf68e219c1caf0394 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -920,7 +920,7 @@ class OpTest(unittest.TestCase): enable_inplace=None, for_inplace_test=None, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): program = Program() block = program.global_block() op = self._append_ops(block) @@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase): Returns: res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. """ - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): ( fwd_outs, fwd_fetch_list, @@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs=None, parallel=False, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): prog = Program() scope = core.Scope() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index c31d88ddeda2f6efc40126989e550b1487cd2827..60f175203c94d301e76098ec3f5aae6329eaf39a 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1018,7 +1018,7 @@ class OpTest(unittest.TestCase): enable_inplace=None, for_inplace_test=None, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): program = Program() block = program.global_block() op = self._append_ops(block) @@ -1347,7 +1347,7 @@ class OpTest(unittest.TestCase): Returns: res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. """ - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): ( fwd_outs, fwd_fetch_list, @@ -2666,7 +2666,7 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs=None, parallel=False, ): - with paddle.static.program_guard(paddle.static.Program()): + with paddle.fluid.framework._static_guard(): prog = Program() scope = core.Scope() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/prim_op_test.py b/python/paddle/fluid/tests/unittests/prim_op_test.py index 73986391dc6b5c35cd626279f53d6b476e0b0a53..c3b1d44bb2c0a0ff0aeccbd3e166584c3c25ad2c 100644 --- a/python/paddle/fluid/tests/unittests/prim_op_test.py +++ b/python/paddle/fluid/tests/unittests/prim_op_test.py @@ -28,7 +28,7 @@ from paddle.jit.dy2static.utils import parse_arg_and_kwargs def flatten(nest_list): out = [] for i in nest_list: - if isinstance(i, list or tuple): + if isinstance(i, (list, tuple)): tmp_list = flatten(i) for j in tmp_list: out.append(j) @@ -40,7 +40,7 @@ def flatten(nest_list): def _as_list(x): if x is None: return [] - return list(x) if isinstance(x, list or tuple) else [x] + return list(x) if isinstance(x, (list, tuple)) else [x] def convert_uint16_to_float(in_list):