From c137578341a7ad771580d744fc2ac186da0d2d19 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 18 Feb 2021 15:11:04 +0800 Subject: [PATCH] Add Support for Tuple in for Loop (#30998) Dy2stat didn't support tuple as iteration variable in the past. This PR added there main cases: 1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as: for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): var1 = FOR_ITER_TUPLE_PREFIX_x[0] var2 = FOR_ITER_TUPLE_PREFIX_x[1] 2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as: for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy): t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x) 3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will be re-written as: for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): var1 = FOR_ITER_TUPLE_PREFIX_x[0] var2 = FOR_ITER_TUPLE_PREFIX_x[1][0] var3 = FOR_ITER_TUPLE_PREFIX_x[1][1] --- .../dygraph_to_static/loop_transformer.py | 4 +- .../fluid/dygraph/dygraph_to_static/utils.py | 151 ++++++++++++++++++ .../dygraph_to_static/test_for_enumerate.py | 66 ++++++++ 3 files changed, 220 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 924143049e..140c57f710 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -25,6 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name +from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node @@ -427,9 +428,10 @@ class LoopTransformer(gast.NodeTransformer): ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node - self.name_visitor = NameVisitor(self.root) def transform(self): + ForLoopTuplePreTransformer(self.wrapper_root).transform() + self.name_visitor = NameVisitor(self.root) self.visit(self.root) def visit(self, node): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 9e61b8aa1e..e9f8afc06c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -75,6 +75,8 @@ dygraph_class_to_static_api = { } FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' +FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple' +FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var' @@ -810,6 +812,155 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): return node +class ForLoopTuplePreTransformer(gast.NodeTransformer): + """ + ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable): + 1). for x in range(var[*]|var.numpy()[*]) + 2). for x in var|var.numpy() + 3). for i, x in enumerate(var|var.numpy()) + + We chose these 3 types because they are easier (x can be variable name iterating in var). + However, users can write tuples in Python for loop, such as + 1). for var1, var2 in var|var.numpy() + 2). for t in enumerate(var|var.numpy()) + 2). for i, (var1, var2, va3) in enumerate(var|var.numpy()) + + To handle these case, this method will do the rewrite tuple pre-process: + 1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as: + for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): + var1 = FOR_ITER_TUPLE_PREFIX_x[0] + var2 = FOR_ITER_TUPLE_PREFIX_x[1] + 2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as: + for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy): + t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x) + 3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will + be re-written as: + for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy(): + var1 = FOR_ITER_TUPLE_PREFIX_x[0] + var2 = FOR_ITER_TUPLE_PREFIX_x[1][0] + var3 = FOR_ITER_TUPLE_PREFIX_x[1][1] + """ + + def __init__(self, wrapper_root): + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + def transform(self): + self.visit(self.root) + + def visit_For(self, node): + if self.is_for_enumerate_iter(node): + if isinstance(node.target, (gast.Name, gast.Attribute)): + # Out tuple case + out_tuple_name = ast_to_source_code(node.target).strip() + tuple_iter_name = unique_name.generate( + FOR_ITER_TUPLE_INDEX_PREFIX) + tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) + node.target = gast.Tuple( + elts=[ + gast.Name( + id=tuple_iter_name, + ctx=gast.Store(), + annotation=None, + type_comment=None), gast.Name( + id=tuple_var_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + ctx=gast.Store()) + node.body.insert( + 0, + gast.Assign( + targets=[ + gast.Name( + id=out_tuple_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=gast.Tuple( + elts=[ + gast.Name( + id=tuple_iter_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), gast.Name( + id=tuple_var_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + ], + ctx=gast.Load()))) + elif isinstance(node.target, ( + gast.List, + gast.Tuple)) and len(node.target.elts) >= 2 and isinstance( + node.target.elts[1], (gast.List, gast.Tuple)): + # Inner tuple case + inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) + origin_inner_tuple_node = node.target.elts[1] + node.target.elts[1] = gast.Name( + id=inner_tuple_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node, + inner_tuple_name) + elif self.is_for_iter(node) and isinstance(node.target, + (gast.List, gast.Tuple)): + # Non-enumrate case: + tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX) + origin_tuple_node = node.target + node.target = gast.Name( + id=tuple_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name) + return node + + def tuple_to_stmts(self, node, tuple_name, idx=[]): + if not isinstance(node, (gast.Tuple, gast.List)): + value_node = gast.Name( + id=tuple_name, + ctx=gast.Load(), + annotation=None, + type_comment=None) + for i in idx: + value_node = gast.Subscript( + value=value_node, + slice=gast.Index(value=gast.Constant( + value=i, kind=None)), + ctx=gast.Load()) + return [gast.Assign(targets=[node], value=value_node)] + # isinstance(node, (gast.Tuple, gast.List)) + ret = [] + for i, element in enumerate(node.elts): + ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i]) + return ret + + def is_for_iter(self, for_node): + assert isinstance(for_node, + gast.For), "Input node is not gast.For node." + if isinstance(for_node.iter, (gast.Name, gast.Attribute)): + return True + elif isinstance(for_node.iter, gast.Call) and isinstance( + for_node.iter.func, + gast.Attribute) and for_node.iter.func.attr == 'numpy': + return True + elif isinstance(for_node.iter, gast.Subscript): + return True + else: + return False + + def is_for_enumerate_iter(self, for_node): + assert isinstance(for_node, + gast.For), "Input node is not gast.For node." + return isinstance(for_node.iter, gast.Call) and isinstance( + for_node.iter.func, + gast.Name) and for_node.iter.func.id == "enumerate" + + class ForNodeVisitor(object): """ This class parses python for statement, get transformed 3 statement components of for node diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 18995238a3..c28997c5c1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -233,6 +233,57 @@ def for_iter_var_idx(x_array): return z +@paddle.jit.to_static +def for_tuple_as_iter_var(x_array): + x = paddle.to_tensor(x_array) + z = paddle.to_tensor(np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])) + + a_result = paddle.zeros([3]) + b_result = paddle.zeros([3]) + c_result = paddle.zeros([3]) + + for a, b, c in z: + a_result += a + b_result += b + c_result += c + + return a_result, b_result, c_result + + +@paddle.jit.to_static +def for_tuple_as_enumerate_iter(x_array): + x = paddle.to_tensor(x_array) + x_list = [x, x, x] + + a_result = paddle.zeros([5]) + + for t in enumerate(x_list): + a_result += t[1] + + return a_result + + +@paddle.jit.to_static +def for_tuple_as_enumerate_value(x_array): + x = paddle.to_tensor(x_array) + x_list = [x, x, x] + + a_result = paddle.zeros([1]) + b_result = paddle.zeros([1]) + c_result = paddle.zeros([1]) + d_result = paddle.zeros([1]) + e_result = paddle.zeros([1]) + + for i, (a, b, c, d, e) in enumerate(x_list): + a_result += a + b_result += b + c_result += c + d_result += d + e_result += e + + return a_result + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -380,5 +431,20 @@ class TestForEnumerateVarList(TestForInRange): self.dygraph_func = for_enumerate_var_list +class TestForTupleAsIterVar(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_tuple_as_iter_var + + +class TestForTupleAsEnumerateIter(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_tuple_as_enumerate_iter + + +class TestForTupleAsEnumerateValue(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_tuple_as_enumerate_value + + if __name__ == '__main__': unittest.main() -- GitLab