未验证 提交 72cbb6da 编写于 作者: C Charles-hit 提交者: GitHub

add static guard (#50971)

上级 f265a313
......@@ -7524,6 +7524,18 @@ def _dygraph_guard(tracer):
global_var._dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager
def _static_guard():
tmp_tracer = global_var._dygraph_tracer_
global_var._dygraph_tracer_ = None
try:
yield
finally:
if tmp_tracer is not None:
core._switch_tracer(tmp_tracer)
global_var._dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager
def _dygraph_place_guard(place):
global _global_expected_place_
......
......@@ -920,7 +920,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None,
for_inplace_test=None,
):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.fluid.framework._static_guard():
program = Program()
block = program.global_block()
op = self._append_ops(block)
......@@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase):
Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
"""
with paddle.static.program_guard(paddle.static.Program()):
with paddle.fluid.framework._static_guard():
(
fwd_outs,
fwd_fetch_list,
......@@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
parallel=False,
):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.fluid.framework._static_guard():
prog = Program()
scope = core.Scope()
block = prog.global_block()
......
......@@ -1018,7 +1018,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None,
for_inplace_test=None,
):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.fluid.framework._static_guard():
program = Program()
block = program.global_block()
op = self._append_ops(block)
......@@ -1347,7 +1347,7 @@ class OpTest(unittest.TestCase):
Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
"""
with paddle.static.program_guard(paddle.static.Program()):
with paddle.fluid.framework._static_guard():
(
fwd_outs,
fwd_fetch_list,
......@@ -2666,7 +2666,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
parallel=False,
):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.fluid.framework._static_guard():
prog = Program()
scope = core.Scope()
block = prog.global_block()
......
......@@ -28,7 +28,7 @@ from paddle.jit.dy2static.utils import parse_arg_and_kwargs
def flatten(nest_list):
out = []
for i in nest_list:
if isinstance(i, list or tuple):
if isinstance(i, (list, tuple)):
tmp_list = flatten(i)
for j in tmp_list:
out.append(j)
......@@ -40,7 +40,7 @@ def flatten(nest_list):
def _as_list(x):
if x is None:
return []
return list(x) if isinstance(x, list or tuple) else [x]
return list(x) if isinstance(x, (list, tuple)) else [x]
def convert_uint16_to_float(in_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册