未验证 提交 72cbb6da 编写于 作者: C Charles-hit 提交者: GitHub

add static guard (#50971)

上级 f265a313
...@@ -7524,6 +7524,18 @@ def _dygraph_guard(tracer): ...@@ -7524,6 +7524,18 @@ def _dygraph_guard(tracer):
global_var._dygraph_tracer_ = tmp_tracer global_var._dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager
def _static_guard():
tmp_tracer = global_var._dygraph_tracer_
global_var._dygraph_tracer_ = None
try:
yield
finally:
if tmp_tracer is not None:
core._switch_tracer(tmp_tracer)
global_var._dygraph_tracer_ = tmp_tracer
@signature_safe_contextmanager @signature_safe_contextmanager
def _dygraph_place_guard(place): def _dygraph_place_guard(place):
global _global_expected_place_ global _global_expected_place_
......
...@@ -920,7 +920,7 @@ class OpTest(unittest.TestCase): ...@@ -920,7 +920,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None, enable_inplace=None,
for_inplace_test=None, for_inplace_test=None,
): ):
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
program = Program() program = Program()
block = program.global_block() block = program.global_block()
op = self._append_ops(block) op = self._append_ops(block)
...@@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase): ...@@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase):
Returns: Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
""" """
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
( (
fwd_outs, fwd_outs,
fwd_fetch_list, fwd_fetch_list,
...@@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase): ...@@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None, user_defined_grad_outputs=None,
parallel=False, parallel=False,
): ):
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
prog = Program() prog = Program()
scope = core.Scope() scope = core.Scope()
block = prog.global_block() block = prog.global_block()
......
...@@ -1018,7 +1018,7 @@ class OpTest(unittest.TestCase): ...@@ -1018,7 +1018,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None, enable_inplace=None,
for_inplace_test=None, for_inplace_test=None,
): ):
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
program = Program() program = Program()
block = program.global_block() block = program.global_block()
op = self._append_ops(block) op = self._append_ops(block)
...@@ -1347,7 +1347,7 @@ class OpTest(unittest.TestCase): ...@@ -1347,7 +1347,7 @@ class OpTest(unittest.TestCase):
Returns: Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
""" """
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
( (
fwd_outs, fwd_outs,
fwd_fetch_list, fwd_fetch_list,
...@@ -2666,7 +2666,7 @@ class OpTest(unittest.TestCase): ...@@ -2666,7 +2666,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None, user_defined_grad_outputs=None,
parallel=False, parallel=False,
): ):
with paddle.static.program_guard(paddle.static.Program()): with paddle.fluid.framework._static_guard():
prog = Program() prog = Program()
scope = core.Scope() scope = core.Scope()
block = prog.global_block() block = prog.global_block()
......
...@@ -28,7 +28,7 @@ from paddle.jit.dy2static.utils import parse_arg_and_kwargs ...@@ -28,7 +28,7 @@ from paddle.jit.dy2static.utils import parse_arg_and_kwargs
def flatten(nest_list): def flatten(nest_list):
out = [] out = []
for i in nest_list: for i in nest_list:
if isinstance(i, list or tuple): if isinstance(i, (list, tuple)):
tmp_list = flatten(i) tmp_list = flatten(i)
for j in tmp_list: for j in tmp_list:
out.append(j) out.append(j)
...@@ -40,7 +40,7 @@ def flatten(nest_list): ...@@ -40,7 +40,7 @@ def flatten(nest_list):
def _as_list(x): def _as_list(x):
if x is None: if x is None:
return [] return []
return list(x) if isinstance(x, list or tuple) else [x] return list(x) if isinstance(x, (list, tuple)) else [x]
def convert_uint16_to_float(in_list): def convert_uint16_to_float(in_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册