未验证 提交 649868ff 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Fix the bug that loop_body_func may return single element (#31806)

Our old `loop_body` function may return single element when `loop_vars` just contains only 1 element, which can cause bug. The key point of this PR is forcing `loop_body` functions always return tuple.
上级 e5f7a834
...@@ -594,7 +594,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -594,7 +594,7 @@ class LoopTransformer(gast.NodeTransformer):
# append return values for loop body # append return values for loop body
body_stmts.append( body_stmts.append(
gast.Return(value=generate_name_node( gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load()))) loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True)))
body_func_node = gast.FunctionDef( body_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_BODY_PREFIX), name=unique_name.generate(FOR_BODY_PREFIX),
args=gast.arguments( args=gast.arguments(
......
...@@ -381,9 +381,15 @@ def get_attribute_full_name(node): ...@@ -381,9 +381,15 @@ def get_attribute_full_name(node):
return astor.to_source(gast.gast_to_ast(node)).strip() return astor.to_source(gast.gast_to_ast(node)).strip()
def generate_name_node(name_ids, ctx=gast.Load()): def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False):
""" """
Generate list or gast.Tuple of ast.Name for Return statement. If name_ids is list or tuple or set with multiple strings, this function
generates gast.Tuple of gast.Name.
If the name_ids is single string or contains only 1 string, this function
returns gast.Name if gen_tuple_if_single==False else returns gast.Tuple
with only one gast.Name
This function is used at several gast.Return statements.
""" """
if isinstance(name_ids, six.string_types): if isinstance(name_ids, six.string_types):
name_ids = [name_ids] name_ids = [name_ids]
...@@ -395,7 +401,7 @@ def generate_name_node(name_ids, ctx=gast.Load()): ...@@ -395,7 +401,7 @@ def generate_name_node(name_ids, ctx=gast.Load()):
id=name_id, ctx=ctx, annotation=None, type_comment=None) id=name_id, ctx=ctx, annotation=None, type_comment=None)
for name_id in name_ids for name_id in name_ids
] ]
if len(gast_names) == 1: if len(gast_names) == 1 and not gen_tuple_if_single:
name_node = gast_names[0] name_node = gast_names[0]
else: else:
name_node = gast.Tuple(elts=gast_names, ctx=ctx) name_node = gast.Tuple(elts=gast_names, ctx=ctx)
......
...@@ -233,6 +233,7 @@ def for_iter_var_idx(x_array): ...@@ -233,6 +233,7 @@ def for_iter_var_idx(x_array):
return z return z
# 17. for a,b,c in z: (a, b, c) is a tuple
@paddle.jit.to_static @paddle.jit.to_static
def for_tuple_as_iter_var(x_array): def for_tuple_as_iter_var(x_array):
x = paddle.to_tensor(x_array) x = paddle.to_tensor(x_array)
...@@ -250,6 +251,7 @@ def for_tuple_as_iter_var(x_array): ...@@ -250,6 +251,7 @@ def for_tuple_as_iter_var(x_array):
return a_result, b_result, c_result return a_result, b_result, c_result
# 18. for t in enumerate(collection): t is tuple of (idx, element)
@paddle.jit.to_static @paddle.jit.to_static
def for_tuple_as_enumerate_iter(x_array): def for_tuple_as_enumerate_iter(x_array):
x = paddle.to_tensor(x_array) x = paddle.to_tensor(x_array)
...@@ -263,6 +265,7 @@ def for_tuple_as_enumerate_iter(x_array): ...@@ -263,6 +265,7 @@ def for_tuple_as_enumerate_iter(x_array):
return a_result return a_result
# 19. for i, (a, b, c, d, e) in enumerate(collection): (a, b, c, d, e) is a tuple
@paddle.jit.to_static @paddle.jit.to_static
def for_tuple_as_enumerate_value(x_array): def for_tuple_as_enumerate_value(x_array):
x = paddle.to_tensor(x_array) x = paddle.to_tensor(x_array)
...@@ -284,6 +287,23 @@ def for_tuple_as_enumerate_value(x_array): ...@@ -284,6 +287,23 @@ def for_tuple_as_enumerate_value(x_array):
return a_result return a_result
# 20. test for function in a class
class ForwardContainsForLayer(paddle.nn.Layer):
def __init__(self):
super(ForwardContainsForLayer, self).__init__()
self.high = 5
self.low = 3
@paddle.jit.to_static
def forward(self, x):
# just for test case, x is useless in this method
y = paddle.zeros([10, 2, 3])
z = []
for i in range(self.high - self.low):
z.append(y[i].clone())
return z
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
...@@ -313,11 +333,11 @@ class TestTransformBase(unittest.TestCase): ...@@ -313,11 +333,11 @@ class TestTransformBase(unittest.TestCase):
class TestTransform(TestTransformBase): class TestTransform(TestTransformBase):
def transformed_result_compare(self): def transformed_result_compare(self):
dy_outs = self.get_dygraph_output() dy_outs = self.get_dygraph_output()
if not isinstance(dy_outs, tuple): if not isinstance(dy_outs, (tuple, list)):
dy_outs = (dy_outs, ) dy_outs = (dy_outs, )
st_outs = self.get_static_output() st_outs = self.get_static_output()
if not isinstance(st_outs, tuple): if not isinstance(st_outs, (tuple, list)):
st_outs = (st_outs, ) st_outs = (st_outs, )
for x, y in zip(dy_outs, st_outs): for x, y in zip(dy_outs, st_outs):
...@@ -446,5 +466,10 @@ class TestForTupleAsEnumerateValue(TestForIterVarNumpy): ...@@ -446,5 +466,10 @@ class TestForTupleAsEnumerateValue(TestForIterVarNumpy):
self.dygraph_func = for_tuple_as_enumerate_value self.dygraph_func = for_tuple_as_enumerate_value
class TestForwardContainsForLayer(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = ForwardContainsForLayer()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册