未验证 提交 a1c1b59d 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2Static] Support for iter & enumerate VarBase (#24856)

* support for iter & enumerate varbase, test=develop

* revert IsControlFlowVisitor change, test=develop
上级 5ea82e8a
...@@ -424,9 +424,9 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -424,9 +424,9 @@ class LoopTransformer(gast.NodeTransformer):
# 1. check whether need to transform # 1. check whether need to transform
# NOTE: Current need transform cases: # NOTE: Current need transform cases:
# 1). for x in range(VarBase.numpy()[0]) # 1). for x in range(VarBase[0]|VarBase.numpy()[0])
# 2). for x in VarBase.numpy() # 2). for x in VarBase|VarBase.numpy()
# 3). for i, x in enumerate(VarBase.numpy()) # 3). for i, x in enumerate(VarBase|VarBase.numpy())
if not self.name_visitor.is_control_flow_loop(node): if not self.name_visitor.is_control_flow_loop(node):
return [node] return [node]
......
...@@ -239,7 +239,15 @@ def update_args_of_func(node, dygraph_node, method_name): ...@@ -239,7 +239,15 @@ def update_args_of_func(node, dygraph_node, method_name):
def create_api_shape_node(tensor_shape_node): def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node, (gast.Attribute, gast.Subscript)) assert isinstance(tensor_shape_node,
(gast.Name, gast.Attribute, gast.Subscript))
if isinstance(tensor_shape_node, gast.Name):
api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value,
args=[tensor_shape_node],
keywords=[])
return api_shape_node
if isinstance(tensor_shape_node, gast.Attribute): if isinstance(tensor_shape_node, gast.Attribute):
api_shape_node = gast.Call( api_shape_node = gast.Call(
...@@ -453,8 +461,10 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -453,8 +461,10 @@ class IsControlFlowVisitor(gast.NodeVisitor):
gast.While must meet at least one of the requirements 1 to 5: gast.While must meet at least one of the requirements 1 to 5:
4. has `break` statement. 4. has `break` statement.
5. has `continue` statement. 5. has `continue` statement.
gast.For must meet at least one of the requirements 4 to 6: gast.For must meet at least one of the requirements 4 to 8:
6. calls `range` function in `for` statement and the argument of range is Tensor. 6. calls `range` function in `for` statement and the argument of range is Tensor.
7. calls `enumerate` function in `for` statement and the argument of enumerate is Tensor.
8. the iterable varaible in `for` statement is Tensor.
TODO: Support non-range case TODO: Support non-range case
The following examples should not be considered as control_flow_if: The following examples should not be considered as control_flow_if:
...@@ -507,17 +517,15 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -507,17 +517,15 @@ class IsControlFlowVisitor(gast.NodeVisitor):
def _visit_For(self, node): def _visit_For(self, node):
assert isinstance(node, gast.For) assert isinstance(node, gast.For)
if not isinstance(node.iter, gast.Call): if isinstance(node.iter, gast.Call):
return # for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy())
# for in range(v.numpy()) or for in enumerate(v.numpy())
if isinstance(node.iter.func, gast.Name): if isinstance(node.iter.func, gast.Name):
if node.iter.func.id == "range" or node.iter.func.id == "enumerate": if node.iter.func.id == "range" or node.iter.func.id == "enumerate":
for arg in node.iter.args: for arg in node.iter.args:
self.visit(arg) self.visit(arg)
else: else:
return return
# for in v.numpy() # for in var.numpy()
elif isinstance(node.iter.func, gast.Attribute): elif isinstance(node.iter.func, gast.Attribute):
if node.iter.func.attr == 'numpy': if node.iter.func.attr == 'numpy':
self._visit_Call(node.iter) self._visit_Call(node.iter)
...@@ -525,6 +533,11 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -525,6 +533,11 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return return
else: else:
return return
elif isinstance(node.iter, gast.Name):
# for in var
self.visit(node.iter)
else:
return
for child_node in gast.walk(node): for child_node in gast.walk(node):
if isinstance(child_node, (gast.Continue, gast.Break)): if isinstance(child_node, (gast.Continue, gast.Break)):
...@@ -655,10 +668,10 @@ class ForNodeVisitor(object): ...@@ -655,10 +668,10 @@ class ForNodeVisitor(object):
In this process, the semantics of for does not change. In this process, the semantics of for does not change.
Now only can parse 3 type statements: Now only can parse 3 type statements (Here var is VarBase(Tensor)):
1). for x in range(***) 1). for x in range(var[*]|var.numpy()[*])
2). for x in var.numpy() 2). for x in var|var.numpy()
3). for i, x enumerate(var.numpy()) 3). for i, x enumerate(var|var.numpy())
""" """
def __init__(self, for_node): def __init__(self, for_node):
...@@ -678,28 +691,29 @@ class ForNodeVisitor(object): ...@@ -678,28 +691,29 @@ class ForNodeVisitor(object):
# 3. key shared node or names # 3. key shared node or names
# - x: # - x:
# - for x in range(***) # - for x in range(***)
# - for x in var.numpy() # - for x in var|var.numpy()
# - for i, x enumerate(var.numpy()) # - for i, x enumerate(var|var.numpy())
self.iter_var_name = self._get_iter_var_name() self.iter_var_name = self._get_iter_var_name()
# - created index var to slice Variable: __for_loop_var_index_0 # - created index var to slice Variable: __for_loop_var_index_0
# - for x in var.numpy() # - for x in var|var.numpy()
# - for i, x enumerate(var.numpy()) # - for i, x enumerate(var|var.numpy())
self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX) self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX)
# - created shape var to build loop condition: __for_loop_var_shape_0 # - created shape var to build loop condition: __for_loop_var_shape_0
# - for x in var.numpy() # - for x in var|var.numpy()
# - for i, x enumerate(var.numpy()) # - for i, x enumerate(var|var.numpy())
# - for x in var
self.iter_var_shape_name = unique_name.generate( self.iter_var_shape_name = unique_name.generate(
FOR_ITER_VAR_SHAPE_PREFIX) FOR_ITER_VAR_SHAPE_PREFIX)
# - var.numpy() # - var.numpy()/var
# - for x in var.numpy() # - for x in var|var.numpy()
# - for i, x enumerate(var.numpy()) # - for i, x enumerate(var|var.numpy())
self.iter_node = self._get_iter_node() self.iter_node = self._get_iter_node()
# - enumeate i: # - enumeate i:
# - for i, x enumerate(var.numpy()) # - for i, x enumerate(var|var.numpy())
self.enum_idx_name = self._get_enum_idx_name() self.enum_idx_name = self._get_enum_idx_name()
# - range/enumerate args length # - range/enumerate args length
...@@ -717,16 +731,23 @@ class ForNodeVisitor(object): ...@@ -717,16 +731,23 @@ class ForNodeVisitor(object):
raise None raise None
def is_for_range_iter(self): def is_for_range_iter(self):
return isinstance(self.node.iter.func, return isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
gast.Name) and self.node.iter.func.id == "range" gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self): def is_for_iter(self):
return isinstance( if isinstance(self.node.iter, gast.Name):
return True
elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func, self.node.iter.func,
gast.Attribute) and self.node.iter.func.attr == 'numpy' gast.Attribute) and self.node.iter.func.attr == 'numpy':
return True
else:
return False
def is_for_enumerate_iter(self): def is_for_enumerate_iter(self):
return isinstance(self.node.iter.func, return isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
gast.Name) and self.node.iter.func.id == "enumerate" gast.Name) and self.node.iter.func.id == "enumerate"
def _args_check(self): def _args_check(self):
...@@ -811,6 +832,10 @@ class ForNodeVisitor(object): ...@@ -811,6 +832,10 @@ class ForNodeVisitor(object):
def _build_var_shape_assign_node(self): def _build_var_shape_assign_node(self):
# get variable shape as iter length # get variable shape as iter length
if isinstance(self.iter_node, gast.Call):
iter_var = self.iter_node.func
else:
iter_var = self.iter_node
return gast.Assign( return gast.Assign(
targets=[ targets=[
gast.Name( gast.Name(
...@@ -819,7 +844,7 @@ class ForNodeVisitor(object): ...@@ -819,7 +844,7 @@ class ForNodeVisitor(object):
annotation=None, annotation=None,
type_comment=None) type_comment=None)
], ],
value=create_api_shape_node(self.iter_node.func)) value=create_api_shape_node(iter_var))
def _build_enum_init_node(self): def _build_enum_init_node(self):
enum_init_node = get_constant_variable_node( enum_init_node = get_constant_variable_node(
......
...@@ -24,9 +24,9 @@ from paddle.fluid.dygraph.jit import declarative ...@@ -24,9 +24,9 @@ from paddle.fluid.dygraph.jit import declarative
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
# 0. for in range with var case # 0. for in range var.numpy()[0]
@declarative @declarative
def dygraph_for_in_range(x): def for_in_range(x):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]): for i in range(x.numpy()[0]):
...@@ -36,7 +36,7 @@ def dygraph_for_in_range(x): ...@@ -36,7 +36,7 @@ def dygraph_for_in_range(x):
# 1. for iter list # 1. for iter list
@declarative @declarative
def dygraph_for_iter_list(x_array): def for_iter_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
for x in x_array: for x in x_array:
z = z + x z = z + x
...@@ -45,7 +45,7 @@ def dygraph_for_iter_list(x_array): ...@@ -45,7 +45,7 @@ def dygraph_for_iter_list(x_array):
# 2. for enumerate list # 2. for enumerate list
@declarative @declarative
def dygraph_for_enumerate_list(x_array): def for_enumerate_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(x_array): for i, x in enumerate(x_array):
z = z + x + i z = z + x + i
...@@ -54,7 +54,7 @@ def dygraph_for_enumerate_list(x_array): ...@@ -54,7 +54,7 @@ def dygraph_for_enumerate_list(x_array):
# 3. for iter var.numpy() # 3. for iter var.numpy()
@declarative @declarative
def dygraph_for_iter_var_numpy(x_array): def for_iter_var_numpy(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
for x in x_array.numpy(): for x in x_array.numpy():
...@@ -64,7 +64,7 @@ def dygraph_for_iter_var_numpy(x_array): ...@@ -64,7 +64,7 @@ def dygraph_for_iter_var_numpy(x_array):
# 4. for enumerate var.numpy() # 4. for enumerate var.numpy()
@declarative @declarative
def dygraph_for_enumerate_var_numpy(x_array): def for_enumerate_var_numpy(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -76,7 +76,7 @@ def dygraph_for_enumerate_var_numpy(x_array): ...@@ -76,7 +76,7 @@ def dygraph_for_enumerate_var_numpy(x_array):
# 5. for enumerate var.numpy() with start # 5. for enumerate var.numpy() with start
@declarative @declarative
def dygraph_for_enumerate_var_numpy_with_start(x_array): def for_enumerate_var_numpy_with_start(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -88,7 +88,7 @@ def dygraph_for_enumerate_var_numpy_with_start(x_array): ...@@ -88,7 +88,7 @@ def dygraph_for_enumerate_var_numpy_with_start(x_array):
# 6. for in range with break # 6. for in range with break
@declarative @declarative
def dygraph_for_in_range_with_break(x): def for_in_range_with_break(x):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
for i in range(x.numpy()[0]): for i in range(x.numpy()[0]):
...@@ -100,7 +100,7 @@ def dygraph_for_in_range_with_break(x): ...@@ -100,7 +100,7 @@ def dygraph_for_in_range_with_break(x):
# 7. for enumerate var.numpy() with break # 7. for enumerate var.numpy() with break
@declarative @declarative
def dygraph_for_enumerate_var_numpy_with_break(x_array): def for_enumerate_var_numpy_with_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -114,7 +114,7 @@ def dygraph_for_enumerate_var_numpy_with_break(x_array): ...@@ -114,7 +114,7 @@ def dygraph_for_enumerate_var_numpy_with_break(x_array):
# 8. for enumerate var.numpy() with continue # 8. for enumerate var.numpy() with continue
@declarative @declarative
def dygraph_for_enumerate_var_numpy_with_continue(x_array): def for_enumerate_var_numpy_with_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -128,7 +128,7 @@ def dygraph_for_enumerate_var_numpy_with_continue(x_array): ...@@ -128,7 +128,7 @@ def dygraph_for_enumerate_var_numpy_with_continue(x_array):
# 9. for enumerate var.numpy() with start & break # 9. for enumerate var.numpy() with start & break
@declarative @declarative
def dygraph_for_enumerate_var_numpy_with_start_break(x_array): def for_enumerate_var_numpy_with_start_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -142,7 +142,7 @@ def dygraph_for_enumerate_var_numpy_with_start_break(x_array): ...@@ -142,7 +142,7 @@ def dygraph_for_enumerate_var_numpy_with_start_break(x_array):
# 10. for enumerate var.numpy() with start & continue # 10. for enumerate var.numpy() with start & continue
@declarative @declarative
def dygraph_for_enumerate_var_numpy_with_start_continue(x_array): def for_enumerate_var_numpy_with_start_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -154,6 +154,28 @@ def dygraph_for_enumerate_var_numpy_with_start_continue(x_array): ...@@ -154,6 +154,28 @@ def dygraph_for_enumerate_var_numpy_with_start_continue(x_array):
return y, z return y, z
# 11. for iter var
@declarative
def for_iter_var(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array:
z = z + x
return z
# 12. for enumerate var
@declarative
def for_enumerate_var(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, x in enumerate(x_array):
y = y + i
z = z + x
return y, z
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
...@@ -206,7 +228,7 @@ class TestForInRange(TestTransform): ...@@ -206,7 +228,7 @@ class TestForInRange(TestTransform):
self.input = np.array([5]) self.input = np.array([5])
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_in_range self.dygraph_func = for_in_range
def test_transformed_result_compare(self): def test_transformed_result_compare(self):
self.transformed_result_compare() self.transformed_result_compare()
...@@ -214,7 +236,7 @@ class TestForInRange(TestTransform): ...@@ -214,7 +236,7 @@ class TestForInRange(TestTransform):
class TestForIterList(TestTransform): class TestForIterList(TestTransform):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_iter_list self.dygraph_func = for_iter_list
def test_transformed_result_compare(self): def test_transformed_result_compare(self):
self.transformed_result_compare() self.transformed_result_compare()
...@@ -222,12 +244,12 @@ class TestForIterList(TestTransform): ...@@ -222,12 +244,12 @@ class TestForIterList(TestTransform):
class TestForEnumerateSimple(TestForIterList): class TestForEnumerateSimple(TestForIterList):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_list self.dygraph_func = for_enumerate_list
class TestForInRangeWithBreak(TestForInRange): class TestForInRangeWithBreak(TestForInRange):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_in_range_with_break self.dygraph_func = for_in_range_with_break
class TestForIterVarNumpy(TestTransform): class TestForIterVarNumpy(TestTransform):
...@@ -235,7 +257,7 @@ class TestForIterVarNumpy(TestTransform): ...@@ -235,7 +257,7 @@ class TestForIterVarNumpy(TestTransform):
self.input = np.array([1, 2, 3, 4, 5]) self.input = np.array([1, 2, 3, 4, 5])
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_iter_var_numpy self.dygraph_func = for_iter_var_numpy
def test_transformed_result_compare(self): def test_transformed_result_compare(self):
self.transformed_result_compare() self.transformed_result_compare()
...@@ -243,32 +265,42 @@ class TestForIterVarNumpy(TestTransform): ...@@ -243,32 +265,42 @@ class TestForIterVarNumpy(TestTransform):
class TestForEnumerateVarNumpy(TestForIterVarNumpy): class TestForEnumerateVarNumpy(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy self.dygraph_func = for_enumerate_var_numpy
class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start self.dygraph_func = for_enumerate_var_numpy_with_start
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_break self.dygraph_func = for_enumerate_var_numpy_with_break
class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_continue self.dygraph_func = for_enumerate_var_numpy_with_continue
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_break self.dygraph_func = for_enumerate_var_numpy_with_start_break
class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_continue self.dygraph_func = for_enumerate_var_numpy_with_start_continue
class TestForIterVar(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_iter_var
class TestForEnumerateVar(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_enumerate_var
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册