diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index db3024821f885fc6b2f47f90cce28f5e96f70a54..f3ab02c62f9802065f9d5a79154728bbaf017668 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -69,6 +69,7 @@ dygraph_class_to_static_api = { FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' +FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var' # FullArgSpec is valid from Python3. Defined a Namedtuple to # to make it available in Python2. @@ -772,6 +773,20 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): def __init__(self, root_node, target_name, replace_node): assert isinstance(target_name, str) + + # NOTE(liym27): + # Use gast.Name to replace gast.Name, otherwise, errors may occur. + # + # For examples: + # If using a gast.Subscript to replace gast.Name, and the original gast.Name + # is in the arguments of FunctionDef, an exception will be raised. + # + # ``` + # def func(x[i])) # x[i] can not be a argument + # # ... + # ``` + + assert isinstance(replace_node, gast.Name) self.target_name = target_name self.replace_node = replace_node @@ -908,10 +923,14 @@ class ForNodeVisitor(object): cond_stmt = self._build_cond_stmt(step_node, compare_node) body_stmts = self.body - var_slice_node = self._build_var_slice_node() + + # NOTE(liym27): Here add a gast.Assign, and the target of it is gast.Name. + # In NameNodeReplaceTransformer, using gast.Name to replace gast.Name is safe. + target_node, assign_node = self._build_assign_var_slice_node() + body_stmts[0:0] = [assign_node] for body_node in body_stmts: NameNodeReplaceTransformer(body_node, self.iter_var_name, - var_slice_node) + target_node) body_stmts.append(self._build_index_increase_node(step_node)) return init_stmts, cond_stmt, body_stmts @@ -927,10 +946,13 @@ class ForNodeVisitor(object): cond_stmt = self._build_cond_stmt(step_node, compare_node) body_stmts = self.body - var_slice_node = self._build_var_slice_node() + + target_node, assign_node = self._build_assign_var_slice_node() + body_stmts[0:0] = [assign_node] for body_node in body_stmts: NameNodeReplaceTransformer(body_node, self.iter_var_name, - var_slice_node) + target_node) + body_stmts.append(self._build_index_increase_node(step_node)) body_stmts.append(self._build_enum_increase_node()) @@ -1030,15 +1052,19 @@ class ForNodeVisitor(object): op=gast.Add(), value=step_node) - def _build_var_slice_node(self): - return gast.Subscript( + def _build_assign_var_slice_node(self): + var_slice_node = gast.Subscript( value=self.iter_node, slice=gast.Index(value=gast.Name( id=self.iter_idx_name, ctx=gast.Load(), annotation=None, type_comment=None)), - ctx=gast.Load()) + ctx=gast.Load(), ) + new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX) + target_node, assign_node = create_assign_node(new_iter_var_name, + var_slice_node) + return target_node, assign_node def _build_enum_increase_node(self): return gast.AugAssign( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 86cfcb9b3d817ded6edaf527b036085e92a77ec8..a74c56fc31766ccf39cbcbee6b1138573fe9de6a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -17,15 +17,15 @@ from __future__ import print_function import numpy as np import unittest +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator -from paddle.fluid.dygraph.jit import declarative program_translator = ProgramTranslator() # 0. for in range var.numpy()[0] -@declarative +@paddle.jit.to_static def for_in_range(x): z = fluid.layers.fill_constant([1], 'int32', 0) x = fluid.dygraph.to_variable(x) @@ -35,7 +35,7 @@ def for_in_range(x): # 1. for iter list -@declarative +@paddle.jit.to_static def for_iter_list(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) for x in x_array: @@ -44,7 +44,7 @@ def for_iter_list(x_array): # 2. for enumerate list -@declarative +@paddle.jit.to_static def for_enumerate_list(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) for i, x in enumerate(x_array): @@ -53,7 +53,7 @@ def for_enumerate_list(x_array): # 3. for iter var.numpy() -@declarative +@paddle.jit.to_static def for_iter_var_numpy(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -63,7 +63,7 @@ def for_iter_var_numpy(x_array): # 4. for enumerate var.numpy() -@declarative +@paddle.jit.to_static def for_enumerate_var_numpy(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -75,7 +75,7 @@ def for_enumerate_var_numpy(x_array): # 5. for enumerate var.numpy() with start -@declarative +@paddle.jit.to_static def for_enumerate_var_numpy_with_start(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -87,7 +87,7 @@ def for_enumerate_var_numpy_with_start(x_array): # 6. for in range with break -@declarative +@paddle.jit.to_static def for_in_range_with_break(x): z = fluid.layers.fill_constant([1], 'int32', 0) x = fluid.dygraph.to_variable(x) @@ -99,7 +99,7 @@ def for_in_range_with_break(x): # 7. for enumerate var.numpy() with break -@declarative +@paddle.jit.to_static def for_enumerate_var_numpy_with_break(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -113,7 +113,7 @@ def for_enumerate_var_numpy_with_break(x_array): # 8. for enumerate var.numpy() with continue -@declarative +@paddle.jit.to_static def for_enumerate_var_numpy_with_continue(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -127,7 +127,7 @@ def for_enumerate_var_numpy_with_continue(x_array): # 9. for enumerate var.numpy() with start & break -@declarative +@paddle.jit.to_static def for_enumerate_var_numpy_with_start_break(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -141,7 +141,7 @@ def for_enumerate_var_numpy_with_start_break(x_array): # 10. for enumerate var.numpy() with start & continue -@declarative +@paddle.jit.to_static def for_enumerate_var_numpy_with_start_continue(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -155,7 +155,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array): # 11. for iter var -@declarative +@paddle.jit.to_static def for_iter_var(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) @@ -165,7 +165,7 @@ def for_iter_var(x_array): # 12. for enumerate var -@declarative +@paddle.jit.to_static def for_enumerate_var(x_array): y = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0) @@ -177,7 +177,7 @@ def for_enumerate_var(x_array): # 13. for iter list[var] -@declarative +@paddle.jit.to_static def for_iter_var_list(x): # 1. prepare data, ref test_list.py x = fluid.dygraph.to_variable(x) @@ -193,7 +193,7 @@ def for_iter_var_list(x): # 14. for enumerate list[var] -@declarative +@paddle.jit.to_static def for_enumerate_var_list(x): # 1. prepare data, ref test_list.py x = fluid.dygraph.to_variable(x) @@ -210,6 +210,17 @@ def for_enumerate_var_list(x): return y, z +# 15. for enumerate list[var] with a nested for range +@paddle.jit.to_static +def for_enumerate_var_with_nested_range(x_array): + x = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + for i, num in enumerate(x_array): + for idx in range(num): + x = x + num + return x + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -337,6 +348,11 @@ class TestForEnumerateVar(TestForIterVarNumpy): self.dygraph_func = for_enumerate_var +class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_enumerate_var_with_nested_range + + class TestForIterVarList(TestForInRange): def set_test_func(self): self.dygraph_func = for_iter_var_list