diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index b43c20424c3b71a529941e00fdaef18a3ec9d713..24e7bf08e0f5b06d15553e37eb2a5873f641df4e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -424,9 +424,9 @@ class LoopTransformer(gast.NodeTransformer): # 1. check whether need to transform # NOTE: Current need transform cases: - # 1). for x in range(VarBase.numpy()[0]) - # 2). for x in VarBase.numpy() - # 3). for i, x in enumerate(VarBase.numpy()) + # 1). for x in range(VarBase[0]|VarBase.numpy()[0]) + # 2). for x in VarBase|VarBase.numpy() + # 3). for i, x in enumerate(VarBase|VarBase.numpy()) if not self.name_visitor.is_control_flow_loop(node): return [node] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 460435c38d3ac39256af305a001efd201f57eb38..2b3fbc725ef54f2ed5cef08a22ac2dd3800d63db 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -239,7 +239,15 @@ def update_args_of_func(node, dygraph_node, method_name): def create_api_shape_node(tensor_shape_node): - assert isinstance(tensor_shape_node, (gast.Attribute, gast.Subscript)) + assert isinstance(tensor_shape_node, + (gast.Name, gast.Attribute, gast.Subscript)) + + if isinstance(tensor_shape_node, gast.Name): + api_shape_node = gast.Call( + func=gast.parse('fluid.layers.shape').body[0].value, + args=[tensor_shape_node], + keywords=[]) + return api_shape_node if isinstance(tensor_shape_node, gast.Attribute): api_shape_node = gast.Call( @@ -453,8 +461,10 @@ class IsControlFlowVisitor(gast.NodeVisitor): gast.While must meet at least one of the requirements 1 to 5: 4. has `break` statement. 5. has `continue` statement. - gast.For must meet at least one of the requirements 4 to 6: + gast.For must meet at least one of the requirements 4 to 8: 6. calls `range` function in `for` statement and the argument of range is Tensor. + 7. calls `enumerate` function in `for` statement and the argument of enumerate is Tensor. + 8. the iterable varaible in `for` statement is Tensor. TODO: Support non-range case The following examples should not be considered as control_flow_if: @@ -507,22 +517,25 @@ class IsControlFlowVisitor(gast.NodeVisitor): def _visit_For(self, node): assert isinstance(node, gast.For) - if not isinstance(node.iter, gast.Call): - return - - # for in range(v.numpy()) or for in enumerate(v.numpy()) - if isinstance(node.iter.func, gast.Name): - if node.iter.func.id == "range" or node.iter.func.id == "enumerate": - for arg in node.iter.args: - self.visit(arg) - else: - return - # for in v.numpy() - elif isinstance(node.iter.func, gast.Attribute): - if node.iter.func.attr == 'numpy': - self._visit_Call(node.iter) + if isinstance(node.iter, gast.Call): + # for in range(var[0]|var.numpy()[0]) or for in enumerate(var|var.numpy()) + if isinstance(node.iter.func, gast.Name): + if node.iter.func.id == "range" or node.iter.func.id == "enumerate": + for arg in node.iter.args: + self.visit(arg) + else: + return + # for in var.numpy() + elif isinstance(node.iter.func, gast.Attribute): + if node.iter.func.attr == 'numpy': + self._visit_Call(node.iter) + else: + return else: return + elif isinstance(node.iter, gast.Name): + # for in var + self.visit(node.iter) else: return @@ -655,10 +668,10 @@ class ForNodeVisitor(object): In this process, the semantics of for does not change. - Now only can parse 3 type statements: - 1). for x in range(***) - 2). for x in var.numpy() - 3). for i, x enumerate(var.numpy()) + Now only can parse 3 type statements (Here var is VarBase(Tensor)): + 1). for x in range(var[*]|var.numpy()[*]) + 2). for x in var|var.numpy() + 3). for i, x enumerate(var|var.numpy()) """ def __init__(self, for_node): @@ -678,28 +691,29 @@ class ForNodeVisitor(object): # 3. key shared node or names # - x: # - for x in range(***) - # - for x in var.numpy() - # - for i, x enumerate(var.numpy()) + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) self.iter_var_name = self._get_iter_var_name() # - created index var to slice Variable: __for_loop_var_index_0 - # - for x in var.numpy() - # - for i, x enumerate(var.numpy()) + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX) # - created shape var to build loop condition: __for_loop_var_shape_0 - # - for x in var.numpy() - # - for i, x enumerate(var.numpy()) + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) + # - for x in var self.iter_var_shape_name = unique_name.generate( FOR_ITER_VAR_SHAPE_PREFIX) - # - var.numpy() - # - for x in var.numpy() - # - for i, x enumerate(var.numpy()) + # - var.numpy()/var + # - for x in var|var.numpy() + # - for i, x enumerate(var|var.numpy()) self.iter_node = self._get_iter_node() # - enumeate i: - # - for i, x enumerate(var.numpy()) + # - for i, x enumerate(var|var.numpy()) self.enum_idx_name = self._get_enum_idx_name() # - range/enumerate args length @@ -717,17 +731,24 @@ class ForNodeVisitor(object): raise None def is_for_range_iter(self): - return isinstance(self.node.iter.func, - gast.Name) and self.node.iter.func.id == "range" + return isinstance(self.node.iter, gast.Call) and isinstance( + self.node.iter.func, + gast.Name) and self.node.iter.func.id == "range" def is_for_iter(self): - return isinstance( - self.node.iter.func, - gast.Attribute) and self.node.iter.func.attr == 'numpy' + if isinstance(self.node.iter, gast.Name): + return True + elif isinstance(self.node.iter, gast.Call) and isinstance( + self.node.iter.func, + gast.Attribute) and self.node.iter.func.attr == 'numpy': + return True + else: + return False def is_for_enumerate_iter(self): - return isinstance(self.node.iter.func, - gast.Name) and self.node.iter.func.id == "enumerate" + return isinstance(self.node.iter, gast.Call) and isinstance( + self.node.iter.func, + gast.Name) and self.node.iter.func.id == "enumerate" def _args_check(self): if self.is_for_range_iter(): @@ -811,6 +832,10 @@ class ForNodeVisitor(object): def _build_var_shape_assign_node(self): # get variable shape as iter length + if isinstance(self.iter_node, gast.Call): + iter_var = self.iter_node.func + else: + iter_var = self.iter_node return gast.Assign( targets=[ gast.Name( @@ -819,7 +844,7 @@ class ForNodeVisitor(object): annotation=None, type_comment=None) ], - value=create_api_shape_node(self.iter_node.func)) + value=create_api_shape_node(iter_var)) def _build_enum_init_node(self): enum_init_node = get_constant_variable_node( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 2924f1713665c8e78a1a97edd56beb6f8769d3c8..3b15e477e7da8a5857719dab457d795272561ed4 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -24,9 +24,9 @@ from paddle.fluid.dygraph.jit import declarative program_translator = ProgramTranslator() -# 0. for in range with var case +# 0. for in range var.numpy()[0] @declarative -def dygraph_for_in_range(x): +def for_in_range(x): z = fluid.layers.fill_constant([1], 'int32', 0) x = fluid.dygraph.to_variable(x) for i in range(x.numpy()[0]): @@ -36,7 +36,7 @@ def dygraph_for_in_range(x): # 1. for iter list @declarative -def dygraph_for_iter_list(x_array): +def for_iter_list(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) for x in x_array: z = z + x @@ -45,7 +45,7 @@ def dygraph_for_iter_list(x_array): # 2. for enumerate list @declarative -def dygraph_for_enumerate_list(x_array): +def for_enumerate_list(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) for i, x in enumerate(x_array): z = z + x + i @@ -54,7 +54,7 @@ def dygraph_for_enumerate_list(x_array): # 3. for iter var.numpy() @declarative -def dygraph_for_iter_var_numpy(x_array): +def for_iter_var_numpy(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) for x in x_array.numpy(): @@ -64,7 +64,7 @@ def dygraph_for_iter_var_numpy(x_array): # 4. for enumerate var.numpy() @declarative -def dygraph_for_enumerate_var_numpy(x_array): +def for_enumerate_var_numpy(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -76,7 +76,7 @@ def dygraph_for_enumerate_var_numpy(x_array): # 5. for enumerate var.numpy() with start @declarative -def dygraph_for_enumerate_var_numpy_with_start(x_array): +def for_enumerate_var_numpy_with_start(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -88,7 +88,7 @@ def dygraph_for_enumerate_var_numpy_with_start(x_array): # 6. for in range with break @declarative -def dygraph_for_in_range_with_break(x): +def for_in_range_with_break(x): z = fluid.layers.fill_constant([1], 'int32', 0) x = fluid.dygraph.to_variable(x) for i in range(x.numpy()[0]): @@ -100,7 +100,7 @@ def dygraph_for_in_range_with_break(x): # 7. for enumerate var.numpy() with break @declarative -def dygraph_for_enumerate_var_numpy_with_break(x_array): +def for_enumerate_var_numpy_with_break(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -114,7 +114,7 @@ def dygraph_for_enumerate_var_numpy_with_break(x_array): # 8. for enumerate var.numpy() with continue @declarative -def dygraph_for_enumerate_var_numpy_with_continue(x_array): +def for_enumerate_var_numpy_with_continue(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -128,7 +128,7 @@ def dygraph_for_enumerate_var_numpy_with_continue(x_array): # 9. for enumerate var.numpy() with start & break @declarative -def dygraph_for_enumerate_var_numpy_with_start_break(x_array): +def for_enumerate_var_numpy_with_start_break(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -142,7 +142,7 @@ def dygraph_for_enumerate_var_numpy_with_start_break(x_array): # 10. for enumerate var.numpy() with start & continue @declarative -def dygraph_for_enumerate_var_numpy_with_start_continue(x_array): +def for_enumerate_var_numpy_with_start_continue(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -154,6 +154,28 @@ def dygraph_for_enumerate_var_numpy_with_start_continue(x_array): return y, z +# 11. for iter var +@declarative +def for_iter_var(x_array): + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for x in x_array: + z = z + x + return z + + +# 12. for enumerate var +@declarative +def for_enumerate_var(x_array): + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, x in enumerate(x_array): + y = y + i + z = z + x + return y, z + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -206,7 +228,7 @@ class TestForInRange(TestTransform): self.input = np.array([5]) def set_test_func(self): - self.dygraph_func = dygraph_for_in_range + self.dygraph_func = for_in_range def test_transformed_result_compare(self): self.transformed_result_compare() @@ -214,7 +236,7 @@ class TestForInRange(TestTransform): class TestForIterList(TestTransform): def set_test_func(self): - self.dygraph_func = dygraph_for_iter_list + self.dygraph_func = for_iter_list def test_transformed_result_compare(self): self.transformed_result_compare() @@ -222,12 +244,12 @@ class TestForIterList(TestTransform): class TestForEnumerateSimple(TestForIterList): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_list + self.dygraph_func = for_enumerate_list class TestForInRangeWithBreak(TestForInRange): def set_test_func(self): - self.dygraph_func = dygraph_for_in_range_with_break + self.dygraph_func = for_in_range_with_break class TestForIterVarNumpy(TestTransform): @@ -235,7 +257,7 @@ class TestForIterVarNumpy(TestTransform): self.input = np.array([1, 2, 3, 4, 5]) def set_test_func(self): - self.dygraph_func = dygraph_for_iter_var_numpy + self.dygraph_func = for_iter_var_numpy def test_transformed_result_compare(self): self.transformed_result_compare() @@ -243,32 +265,42 @@ class TestForIterVarNumpy(TestTransform): class TestForEnumerateVarNumpy(TestForIterVarNumpy): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_var_numpy + self.dygraph_func = for_enumerate_var_numpy class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start + self.dygraph_func = for_enumerate_var_numpy_with_start class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_var_numpy_with_break + self.dygraph_func = for_enumerate_var_numpy_with_break class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_var_numpy_with_continue + self.dygraph_func = for_enumerate_var_numpy_with_continue class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_break + self.dygraph_func = for_enumerate_var_numpy_with_start_break class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy): def set_test_func(self): - self.dygraph_func = dygraph_for_enumerate_var_numpy_with_start_continue + self.dygraph_func = for_enumerate_var_numpy_with_start_continue + + +class TestForIterVar(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_iter_var + + +class TestForEnumerateVar(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_enumerate_var if __name__ == '__main__':