diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index b7ef000938a1518592d06abaf3ae6da437acd9a8..bd89a79c805c98d6092f31573b86d8c7cea20a26 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -594,7 +594,7 @@ class LoopTransformer(gast.NodeTransformer): # append return values for loop body body_stmts.append( gast.Return(value=generate_name_node( - loop_var_names, ctx=gast.Load()))) + loop_var_names, ctx=gast.Load(), gen_tuple_if_single=True))) body_func_node = gast.FunctionDef( name=unique_name.generate(FOR_BODY_PREFIX), args=gast.arguments( diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 1071fc1350bfeb8f8b768a81204f134b345101b5..624ca085ac6c2dfa5c2ca7e2894b0f2afc6f6ea2 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -381,9 +381,15 @@ def get_attribute_full_name(node): return astor.to_source(gast.gast_to_ast(node)).strip() -def generate_name_node(name_ids, ctx=gast.Load()): +def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False): """ - Generate list or gast.Tuple of ast.Name for Return statement. + If name_ids is list or tuple or set with multiple strings, this function + generates gast.Tuple of gast.Name. + If the name_ids is single string or contains only 1 string, this function + returns gast.Name if gen_tuple_if_single==False else returns gast.Tuple + with only one gast.Name + + This function is used at several gast.Return statements. """ if isinstance(name_ids, six.string_types): name_ids = [name_ids] @@ -395,7 +401,7 @@ def generate_name_node(name_ids, ctx=gast.Load()): id=name_id, ctx=ctx, annotation=None, type_comment=None) for name_id in name_ids ] - if len(gast_names) == 1: + if len(gast_names) == 1 and not gen_tuple_if_single: name_node = gast_names[0] else: name_node = gast.Tuple(elts=gast_names, ctx=ctx) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index c28997c5c1c673c4d7774610dd24036074a1b54c..517cff39a276f490c2d4b6a0d951d40dc00f57b3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -233,6 +233,7 @@ def for_iter_var_idx(x_array): return z +# 17. for a,b,c in z: (a, b, c) is a tuple @paddle.jit.to_static def for_tuple_as_iter_var(x_array): x = paddle.to_tensor(x_array) @@ -250,6 +251,7 @@ def for_tuple_as_iter_var(x_array): return a_result, b_result, c_result +# 18. for t in enumerate(collection): t is tuple of (idx, element) @paddle.jit.to_static def for_tuple_as_enumerate_iter(x_array): x = paddle.to_tensor(x_array) @@ -263,6 +265,7 @@ def for_tuple_as_enumerate_iter(x_array): return a_result +# 19. for i, (a, b, c, d, e) in enumerate(collection): (a, b, c, d, e) is a tuple @paddle.jit.to_static def for_tuple_as_enumerate_value(x_array): x = paddle.to_tensor(x_array) @@ -284,6 +287,23 @@ def for_tuple_as_enumerate_value(x_array): return a_result +# 20. test for function in a class +class ForwardContainsForLayer(paddle.nn.Layer): + def __init__(self): + super(ForwardContainsForLayer, self).__init__() + self.high = 5 + self.low = 3 + + @paddle.jit.to_static + def forward(self, x): + # just for test case, x is useless in this method + y = paddle.zeros([10, 2, 3]) + z = [] + for i in range(self.high - self.low): + z.append(y[i].clone()) + return z + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -313,11 +333,11 @@ class TestTransformBase(unittest.TestCase): class TestTransform(TestTransformBase): def transformed_result_compare(self): dy_outs = self.get_dygraph_output() - if not isinstance(dy_outs, tuple): + if not isinstance(dy_outs, (tuple, list)): dy_outs = (dy_outs, ) st_outs = self.get_static_output() - if not isinstance(st_outs, tuple): + if not isinstance(st_outs, (tuple, list)): st_outs = (st_outs, ) for x, y in zip(dy_outs, st_outs): @@ -446,5 +466,10 @@ class TestForTupleAsEnumerateValue(TestForIterVarNumpy): self.dygraph_func = for_tuple_as_enumerate_value +class TestForwardContainsForLayer(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = ForwardContainsForLayer() + + if __name__ == '__main__': unittest.main()