未验证 提交 a1c1b59d 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2Static] Support for iter & enumerate VarBase (#24856)

* support for iter & enumerate varbase, test=develop

* revert IsControlFlowVisitor change, test=develop
上级 5ea82e8a
......@@ -424,9 +424,9 @@ class LoopTransformer(gast.NodeTransformer):
# 1. check whether need to transform
# NOTE: Current need transform cases:
# 1). for x in range(VarBase.numpy()[0])
# 2). for x in VarBase.numpy()
# 3). for i, x in enumerate(VarBase.numpy())
# 1). for x in range(VarBase[0]|VarBase.numpy()[0])
# 2). for x in VarBase|VarBase.numpy()
# 3). for i, x in enumerate(VarBase|VarBase.numpy())
if not self.name_visitor.is_control_flow_loop(node):
return [node]
......
......@@ -239,7 +239,15 @@ def update_args_of_func(node, dygraph_node, method_name):
def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node, (gast.Attribute, gast.Subscript))
assert isinstance(tensor_shape_node,
(gast.Name, gast.Attribute, gast.Subscript))
if isinstance(tensor_shape_node, gast.Name):
api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value,
args=[tensor_shape_node],
keywords=[])
return api_shape_node
if isinstance(tensor_shape_node, gast.Attribute):
api_shape_node = gast.Call(
......@@ -453,8 +461,10 @@ class IsControlFlowVisitor(gast.NodeVisitor):
gast.While must meet at least one of the requirements 1 to 5:
4. has `break` statement.
5. has `continue` statement.
gast.For must meet at least one of the requirements 4 to 6:
gast.For must meet at least one of the requirements 4 to 8:
6. calls `range` function in `for` statement and the argument of range is Tensor.
7. calls `enumerate` function in `for` statement and the argument of enumerate is Tensor.
8. the iterable varaible in `for` statement is Tensor.
TODO: Support non-range case
The following examples should not be considered as control_flow_if:
......@@ -507,22 +517,25 @@ class IsControlFlowVisitor(gast.NodeVisitor):
def _visit_For(self, node):
assert isinstance(node, gast.For)
if not isinstance(node.iter, gast.Call):
return
# for in range(v.numpy()) or for in enumerate(v.numpy())
if isinstance(node.iter.func, gast.Name):
if node.iter.func.id == "range" or node.iter.func.id == "enumerate":
for arg in node.iter.args:
self.visit(arg)
else:
return
# for in v.numpy()
elif isinstance(node.iter.func, gast.Attribute):
if node.iter.func.attr == 'numpy':
self._visit_Call(node.iter)
if isinstance(node.iter, gast.Call):
# for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy())
if isinstance(node.iter.func, gast.Name):
if node.iter.func.id == "range" or node.iter.func.id == "enumerate":
for arg in node.iter.args:
self.visit(arg)
else:
return
# for in var.numpy()
elif isinstance(node.iter.func, gast.Attribute):
if node.iter.func.attr == 'numpy':
self._visit_Call(node.iter)
else:
return
else:
return
elif isinstance(node.iter, gast.Name):
# for in var
self.visit(node.iter)
else:
return
......@@ -655,10 +668,10 @@ class ForNodeVisitor(object):
In this process, the semantics of for does not change.
Now only can parse 3 type statements:
1). for x in range(***)
2). for x in var.numpy()
3). for i, x enumerate(var.numpy())
Now only can parse 3 type statements (Here var is VarBase(Tensor)):
1). for x in range(var[*]|var.numpy()[*])
2). for x in var|var.numpy()
3). for i, x enumerate(var|var.numpy())
"""
def __init__(self, for_node):
......@@ -678,28 +691,29 @@ class ForNodeVisitor(object):
# 3. key shared node or names
# - x:
# - for x in range(***)
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
# - for x in var|var.numpy()
# - for i, x enumerate(var|var.numpy())
self.iter_var_name = self._get_iter_var_name()
# - created index var to slice Variable: __for_loop_var_index_0
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
# - for x in var|var.numpy()
# - for i, x enumerate(var|var.numpy())
self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX)
# - created shape var to build loop condition: __for_loop_var_shape_0
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
# - for x in var|var.numpy()
# - for i, x enumerate(var|var.numpy())
# - for x in var
self.iter_var_shape_name = unique_name.generate(
FOR_ITER_VAR_SHAPE_PREFIX)
# - var.numpy()
# - for x in var.numpy()
# - for i, x enumerate(var.numpy())
# - var.numpy()/var
# - for x in var|var.numpy()
# - for i, x enumerate(var|var.numpy())
self.iter_node = self._get_iter_node()
# - enumeate i:
# - for i, x enumerate(var.numpy())
# - for i, x enumerate(var|var.numpy())
self.enum_idx_name = self._get_enum_idx_name()
# - range/enumerate args length
......@@ -717,17 +731,24 @@ class ForNodeVisitor(object):
raise None
def is_for_range_iter(self):
return isinstance(self.node.iter.func,
gast.Name) and self.node.iter.func.id == "range"
return isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self):
return isinstance(
self.node.iter.func,
gast.Attribute) and self.node.iter.func.attr == 'numpy'
if isinstance(self.node.iter, gast.Name):
return True
elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
gast.Attribute) and self.node.iter.func.attr == 'numpy':
return True
else:
return False
def is_for_enumerate_iter(self):
return isinstance(self.node.iter.func,
gast.Name) and self.node.iter.func.id == "enumerate"
return isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
gast.Name) and self.node.iter.func.id == "enumerate"
def _args_check(self):
if self.is_for_range_iter():
......@@ -811,6 +832,10 @@ class ForNodeVisitor(object):
def _build_var_shape_assign_node(self):
# get variable shape as iter length
if isinstance(self.iter_node, gast.Call):
iter_var = self.iter_node.func
else:
iter_var = self.iter_node
return gast.Assign(
targets=[
gast.Name(
......@@ -819,7 +844,7 @@ class ForNodeVisitor(object):
annotation=None,
type_comment=None)
],
value=create_api_shape_node(self.iter_node.func))
value=create_api_shape_node(iter_var))
def _build_enum_init_node(self):
enum_init_node = get_constant_variable_node(
......
......@@ -24,9 +24,9 @@ from paddle.fluid.dygraph.jit import declarative
program_translator = ProgramTranslator()
# 0. for in range with var case
# 0. for in range var.numpy()[0]
@declarative
def dygraph_for_in_range(x):
def for_in_range(x):
z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]):
......@@ -36,7 +36,7 @@ def dygraph_for_in_range(x):
# 1. for iter list
@declarative
def dygraph_for_iter_list(x_array):
def for_iter_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in x_array:
z = z + x
......@@ -45,7 +45,7 @@ def dygraph_for_iter_list(x_array):
# 2. for enumerate list
@declarative
def dygraph_for_enumerate_list(x_array):
def for_enumerate_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(x_array):
z = z + x + i
......@@ -54,7 +54,7 @@ def dygraph_for_enumerate_list(x_array):
# 3. for iter var.numpy()
@declarative
def dygraph_for_iter_var_numpy(x_array):
def for_iter_var_numpy(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array.numpy():
......@@ -64,7 +64,7 @@ def dygraph_for_iter_var_numpy(x_array):
# 4. for enumerate var.numpy()
@declarative
def dygraph_for_enumerate_var_numpy(x_array):
def for_enumerate_var_numpy(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -76,7 +76,7 @@ def dygraph_for_enumerate_var_numpy(x_array):
# 5. for enumerate var.numpy() with start
@declarative
def dygraph_for_enumerate_var_numpy_with_start(x_array):
def for_enumerate_var_numpy_with_start(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -88,7 +88,7 @@ def dygraph_for_enumerate_var_numpy_with_start(x_array):
# 6. for in range with break
@declarative
def dygraph_for_in_range_with_break(x):
def for_in_range_with_break(x):
z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]):
......@@ -100,7 +100,7 @@ def dygraph_for_in_range_with_break(x):
# 7. for enumerate var.numpy() with break
@declarative
def dygraph_for_enumerate_var_numpy_with_break(x_array):
def for_enumerate_var_numpy_with_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -114,7 +114,7 @@ def dygraph_for_enumerate_var_numpy_with_break(x_array):
# 8. for enumerate var.numpy() with continue
@declarative
def dygraph_for_enumerate_var_numpy_with_continue(x_array):
def for_enumerate_var_numpy_with_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -128,7 +128,7 @@ def dygraph_for_enumerate_var_numpy_with_continue(x_array):
# 9. for enumerate var.numpy() with start & break
@declarative
def dygraph_for_enumerate_var_numpy_with_start_break(x_array):
def for_enumerate_var_numpy_with_start_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -142,7 +142,7 @@ def dygraph_for_enumerate_var_numpy_with_start_break(x_array):
# 10. for enumerate var.numpy() with start & continue
@declarative
def dygraph_for_enumerate_var_numpy_with_start_continue(x_array):
def for_enumerate_var_numpy_with_start_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -154,6 +154,28 @@ def dygraph_for_enumerate_var_numpy_with_start_continue(x_array):
return y, z
# 11. for iter var
@declarative
def for_iter_var(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array:
z = z + x
return z
# 12. for enumerate var
@declarative
def for_enumerate_var(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array):
y = y + i
z = z + x
return y, z
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -206,7 +228,7 @@ class TestForInRange(TestTransform):
self.input = np.array([5])
def set_test_func(self):
self.dygraph_func = dygraph_for_in_range
self.dygraph_func = for_in_range
def test_transformed_result_compare(self):
self.transformed_result_compare()
......@@ -214,7 +236,7 @@ class TestForInRange(TestTransform):
class TestForIterList(TestTransform):
def set_test_func(self):
self.dygraph_func = dygraph_for_iter_list
self.dygraph_func = for_iter_list
def test_transformed_result_compare(self):
self.transformed_result_compare()
......@@ -222,12 +244,12 @@ class TestForIterList(TestTransform):
class TestForEnumerateSimple(TestForIterList):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_list
self.dygraph_func = for_enumerate_list
class TestForInRangeWithBreak(TestForInRange):
def set_test_func(self):
self.dygraph_func = dygraph_for_in_range_with_break
self.dygraph_func = for_in_range_with_break
class TestForIterVarNumpy(TestTransform):
......@@ -235,7 +257,7 @@ class TestForIterVarNumpy(TestTransform):
self.input = np.array([1, 2, 3, 4, 5])
def set_test_func(self):
self.dygraph_func = dygraph_for_iter_var_numpy
self.dygraph_func = for_iter_var_numpy
def test_transformed_result_compare(self):
self.transformed_result_compare()
......@@ -243,32 +265,42 @@ class TestForIterVarNumpy(TestTransform):
class TestForEnumerateVarNumpy(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy
self.dygraph_func = for_enumerate_var_numpy
class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start
self.dygraph_func = for_enumerate_var_numpy_with_start
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_break
self.dygraph_func = for_enumerate_var_numpy_with_break
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_continue
self.dygraph_func = for_enumerate_var_numpy_with_continue
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_break
self.dygraph_func = for_enumerate_var_numpy_with_start_break
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_continue
self.dygraph_func = for_enumerate_var_numpy_with_start_continue
class TestForIterVar(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_iter_var
class TestForEnumerateVar(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_enumerate_var
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册