未验证 提交 c1375783 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Support for Tuple in for Loop (#30998)

Dy2stat didn't support tuple as iteration variable in the past. This PR added there main cases:

       1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as:
          for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
            var1 = FOR_ITER_TUPLE_PREFIX_x[0]
            var2 = FOR_ITER_TUPLE_PREFIX_x[1]
        2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as:
          for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy):
            t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x)
        3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will
        be re-written as:
          for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
            var1 = FOR_ITER_TUPLE_PREFIX_x[0]
            var2 = FOR_ITER_TUPLE_PREFIX_x[1][0]
            var3 = FOR_ITER_TUPLE_PREFIX_x[1][1]
上级 2497f439
......@@ -25,6 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
......@@ -427,9 +428,10 @@ class LoopTransformer(gast.NodeTransformer):
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_visitor = NameVisitor(self.root)
def transform(self):
ForLoopTuplePreTransformer(self.wrapper_root).transform()
self.name_visitor = NameVisitor(self.root)
self.visit(self.root)
def visit(self, node):
......
......@@ -75,6 +75,8 @@ dygraph_class_to_static_api = {
}
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_TUPLE_PREFIX = '__for_loop_iter_tuple'
FOR_ITER_TUPLE_INDEX_PREFIX = '__for_loop_iter_tuple_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
......@@ -810,6 +812,155 @@ class NameNodeReplaceTransformer(gast.NodeTransformer):
return node
class ForLoopTuplePreTransformer(gast.NodeTransformer):
"""
ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable):
1). for x in range(var[*]|var.numpy()[*])
2). for x in var|var.numpy()
3). for i, x in enumerate(var|var.numpy())
We chose these 3 types because they are easier (x can be variable name iterating in var).
However, users can write tuples in Python for loop, such as
1). for var1, var2 in var|var.numpy()
2). for t in enumerate(var|var.numpy())
2). for i, (var1, var2, va3) in enumerate(var|var.numpy())
To handle these case, this method will do the rewrite tuple pre-process:
1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as:
for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
var1 = FOR_ITER_TUPLE_PREFIX_x[0]
var2 = FOR_ITER_TUPLE_PREFIX_x[1]
2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as:
for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy):
t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x)
3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will
be re-written as:
for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
var1 = FOR_ITER_TUPLE_PREFIX_x[0]
var2 = FOR_ITER_TUPLE_PREFIX_x[1][0]
var3 = FOR_ITER_TUPLE_PREFIX_x[1][1]
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def transform(self):
self.visit(self.root)
def visit_For(self, node):
if self.is_for_enumerate_iter(node):
if isinstance(node.target, (gast.Name, gast.Attribute)):
# Out tuple case
out_tuple_name = ast_to_source_code(node.target).strip()
tuple_iter_name = unique_name.generate(
FOR_ITER_TUPLE_INDEX_PREFIX)
tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
node.target = gast.Tuple(
elts=[
gast.Name(
id=tuple_iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None), gast.Name(
id=tuple_var_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
ctx=gast.Store())
node.body.insert(
0,
gast.Assign(
targets=[
gast.Name(
id=out_tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Tuple(
elts=[
gast.Name(
id=tuple_iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None), gast.Name(
id=tuple_var_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
ctx=gast.Load())))
elif isinstance(node.target, (
gast.List,
gast.Tuple)) and len(node.target.elts) >= 2 and isinstance(
node.target.elts[1], (gast.List, gast.Tuple)):
# Inner tuple case
inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
origin_inner_tuple_node = node.target.elts[1]
node.target.elts[1] = gast.Name(
id=inner_tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node,
inner_tuple_name)
elif self.is_for_iter(node) and isinstance(node.target,
(gast.List, gast.Tuple)):
# Non-enumrate case:
tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
origin_tuple_node = node.target
node.target = gast.Name(
id=tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name)
return node
def tuple_to_stmts(self, node, tuple_name, idx=[]):
if not isinstance(node, (gast.Tuple, gast.List)):
value_node = gast.Name(
id=tuple_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
for i in idx:
value_node = gast.Subscript(
value=value_node,
slice=gast.Index(value=gast.Constant(
value=i, kind=None)),
ctx=gast.Load())
return [gast.Assign(targets=[node], value=value_node)]
# isinstance(node, (gast.Tuple, gast.List))
ret = []
for i, element in enumerate(node.elts):
ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i])
return ret
def is_for_iter(self, for_node):
assert isinstance(for_node,
gast.For), "Input node is not gast.For node."
if isinstance(for_node.iter, (gast.Name, gast.Attribute)):
return True
elif isinstance(for_node.iter, gast.Call) and isinstance(
for_node.iter.func,
gast.Attribute) and for_node.iter.func.attr == 'numpy':
return True
elif isinstance(for_node.iter, gast.Subscript):
return True
else:
return False
def is_for_enumerate_iter(self, for_node):
assert isinstance(for_node,
gast.For), "Input node is not gast.For node."
return isinstance(for_node.iter, gast.Call) and isinstance(
for_node.iter.func,
gast.Name) and for_node.iter.func.id == "enumerate"
class ForNodeVisitor(object):
"""
This class parses python for statement, get transformed 3 statement components of for node
......
......@@ -233,6 +233,57 @@ def for_iter_var_idx(x_array):
return z
@paddle.jit.to_static
def for_tuple_as_iter_var(x_array):
x = paddle.to_tensor(x_array)
z = paddle.to_tensor(np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]]))
a_result = paddle.zeros([3])
b_result = paddle.zeros([3])
c_result = paddle.zeros([3])
for a, b, c in z:
a_result += a
b_result += b
c_result += c
return a_result, b_result, c_result
@paddle.jit.to_static
def for_tuple_as_enumerate_iter(x_array):
x = paddle.to_tensor(x_array)
x_list = [x, x, x]
a_result = paddle.zeros([5])
for t in enumerate(x_list):
a_result += t[1]
return a_result
@paddle.jit.to_static
def for_tuple_as_enumerate_value(x_array):
x = paddle.to_tensor(x_array)
x_list = [x, x, x]
a_result = paddle.zeros([1])
b_result = paddle.zeros([1])
c_result = paddle.zeros([1])
d_result = paddle.zeros([1])
e_result = paddle.zeros([1])
for i, (a, b, c, d, e) in enumerate(x_list):
a_result += a
b_result += b
c_result += c
d_result += d
e_result += e
return a_result
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -380,5 +431,20 @@ class TestForEnumerateVarList(TestForInRange):
self.dygraph_func = for_enumerate_var_list
class TestForTupleAsIterVar(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_tuple_as_iter_var
class TestForTupleAsEnumerateIter(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_tuple_as_enumerate_iter
class TestForTupleAsEnumerateValue(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_tuple_as_enumerate_value
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册