未验证 提交 287cbde8 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix ForLoop Transformation with single return (#40683)

* [Dy2Stat] Fix ForLoop Transformation with single return

* [Dy2Stat] Fix ForLoop Transformation with single return
上级 c03186f9
...@@ -693,7 +693,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -693,7 +693,7 @@ class LoopTransformer(gast.NodeTransformer):
new_body = node.body new_body = node.body
new_body.append( new_body.append(
gast.Return(value=generate_name_node( gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load()))) loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef( body_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_BODY_PREFIX), name=unique_name.generate(WHILE_BODY_PREFIX),
args=gast.arguments( args=gast.arguments(
......
...@@ -121,6 +121,13 @@ def for_loop_dyfunc_not_support(max_len): ...@@ -121,6 +121,13 @@ def for_loop_dyfunc_not_support(max_len):
return ret return ret
def for_break_single_return(max_len):
for i in range(3):
if i == 2:
break
return i
def while_loop_bool_op(x): def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x) i = fluid.dygraph.to_variable(x)
...@@ -297,7 +304,10 @@ class TestTransformWhileLoop(unittest.TestCase): ...@@ -297,7 +304,10 @@ class TestTransformWhileLoop(unittest.TestCase):
ret = declarative(self.dyfunc)(tensor_x) ret = declarative(self.dyfunc)(tensor_x)
else: else:
ret = self.dyfunc(tensor_x) ret = self.dyfunc(tensor_x)
if hasattr(ret, "numpy"):
return ret.numpy() return ret.numpy()
else:
return ret
def test_ast_to_func(self): def test_ast_to_func(self):
static_numpy = self._run_static() static_numpy = self._run_static()
...@@ -320,6 +330,11 @@ class TestTransformWhileLoopWithNone(TestTransformWhileLoop): ...@@ -320,6 +330,11 @@ class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
self.dyfunc = while_loop_dyfunc_with_none self.dyfunc = while_loop_dyfunc_with_none
class TestForBreakSingleReturn(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = for_break_single_return
class TestWhileLoopBoolOp(TestTransformWhileLoop): class TestWhileLoopBoolOp(TestTransformWhileLoop):
def _init_dyfunc(self): def _init_dyfunc(self):
self.dyfunc = while_loop_bool_op self.dyfunc = while_loop_bool_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册