未验证 提交 d10eb700 编写于 作者: L liym27 提交者: GitHub

[cherry-pick 2.0rc1][Dy2Stat] Fix bug: Do not use gast.Subscript to replace...

[cherry-pick 2.0rc1][Dy2Stat] Fix bug: Do not use gast.Subscript to replace gast.Name in when transforming for_enumerate_loop (#29310) (#29361)
上级 0e7539e7
...@@ -69,6 +69,7 @@ dygraph_class_to_static_api = { ...@@ -69,6 +69,7 @@ dygraph_class_to_static_api = {
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
# FullArgSpec is valid from Python3. Defined a Namedtuple to # FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2. # to make it available in Python2.
...@@ -772,6 +773,20 @@ class NameNodeReplaceTransformer(gast.NodeTransformer): ...@@ -772,6 +773,20 @@ class NameNodeReplaceTransformer(gast.NodeTransformer):
def __init__(self, root_node, target_name, replace_node): def __init__(self, root_node, target_name, replace_node):
assert isinstance(target_name, str) assert isinstance(target_name, str)
# NOTE(liym27):
# Use gast.Name to replace gast.Name, otherwise, errors may occur.
#
# For examples:
# If using a gast.Subscript to replace gast.Name, and the original gast.Name
# is in the arguments of FunctionDef, an exception will be raised.
#
# ```
# def func(x[i])) # x[i] can not be a argument
# # ...
# ```
assert isinstance(replace_node, gast.Name)
self.target_name = target_name self.target_name = target_name
self.replace_node = replace_node self.replace_node = replace_node
...@@ -908,10 +923,14 @@ class ForNodeVisitor(object): ...@@ -908,10 +923,14 @@ class ForNodeVisitor(object):
cond_stmt = self._build_cond_stmt(step_node, compare_node) cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body body_stmts = self.body
var_slice_node = self._build_var_slice_node()
# NOTE(liym27): Here add a gast.Assign, and the target of it is gast.Name.
# In NameNodeReplaceTransformer, using gast.Name to replace gast.Name is safe.
target_node, assign_node = self._build_assign_var_slice_node()
body_stmts[0:0] = [assign_node]
for body_node in body_stmts: for body_node in body_stmts:
NameNodeReplaceTransformer(body_node, self.iter_var_name, NameNodeReplaceTransformer(body_node, self.iter_var_name,
var_slice_node) target_node)
body_stmts.append(self._build_index_increase_node(step_node)) body_stmts.append(self._build_index_increase_node(step_node))
return init_stmts, cond_stmt, body_stmts return init_stmts, cond_stmt, body_stmts
...@@ -927,10 +946,13 @@ class ForNodeVisitor(object): ...@@ -927,10 +946,13 @@ class ForNodeVisitor(object):
cond_stmt = self._build_cond_stmt(step_node, compare_node) cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body body_stmts = self.body
var_slice_node = self._build_var_slice_node()
target_node, assign_node = self._build_assign_var_slice_node()
body_stmts[0:0] = [assign_node]
for body_node in body_stmts: for body_node in body_stmts:
NameNodeReplaceTransformer(body_node, self.iter_var_name, NameNodeReplaceTransformer(body_node, self.iter_var_name,
var_slice_node) target_node)
body_stmts.append(self._build_index_increase_node(step_node)) body_stmts.append(self._build_index_increase_node(step_node))
body_stmts.append(self._build_enum_increase_node()) body_stmts.append(self._build_enum_increase_node())
...@@ -1030,15 +1052,19 @@ class ForNodeVisitor(object): ...@@ -1030,15 +1052,19 @@ class ForNodeVisitor(object):
op=gast.Add(), op=gast.Add(),
value=step_node) value=step_node)
def _build_var_slice_node(self): def _build_assign_var_slice_node(self):
return gast.Subscript( var_slice_node = gast.Subscript(
value=self.iter_node, value=self.iter_node,
slice=gast.Index(value=gast.Name( slice=gast.Index(value=gast.Name(
id=self.iter_idx_name, id=self.iter_idx_name,
ctx=gast.Load(), ctx=gast.Load(),
annotation=None, annotation=None,
type_comment=None)), type_comment=None)),
ctx=gast.Load()) ctx=gast.Load(), )
new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX)
target_node, assign_node = create_assign_node(new_iter_var_name,
var_slice_node)
return target_node, assign_node
def _build_enum_increase_node(self): def _build_enum_increase_node(self):
return gast.AugAssign( return gast.AugAssign(
......
...@@ -17,15 +17,15 @@ from __future__ import print_function ...@@ -17,15 +17,15 @@ from __future__ import print_function
import numpy as np import numpy as np
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
# 0. for in range var.numpy()[0] # 0. for in range var.numpy()[0]
@declarative @paddle.jit.to_static
def for_in_range(x): def for_in_range(x):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
...@@ -35,7 +35,7 @@ def for_in_range(x): ...@@ -35,7 +35,7 @@ def for_in_range(x):
# 1. for iter list # 1. for iter list
@declarative @paddle.jit.to_static
def for_iter_list(x_array): def for_iter_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
for x in x_array: for x in x_array:
...@@ -44,7 +44,7 @@ def for_iter_list(x_array): ...@@ -44,7 +44,7 @@ def for_iter_list(x_array):
# 2. for enumerate list # 2. for enumerate list
@declarative @paddle.jit.to_static
def for_enumerate_list(x_array): def for_enumerate_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(x_array): for i, x in enumerate(x_array):
...@@ -53,7 +53,7 @@ def for_enumerate_list(x_array): ...@@ -53,7 +53,7 @@ def for_enumerate_list(x_array):
# 3. for iter var.numpy() # 3. for iter var.numpy()
@declarative @paddle.jit.to_static
def for_iter_var_numpy(x_array): def for_iter_var_numpy(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -63,7 +63,7 @@ def for_iter_var_numpy(x_array): ...@@ -63,7 +63,7 @@ def for_iter_var_numpy(x_array):
# 4. for enumerate var.numpy() # 4. for enumerate var.numpy()
@declarative @paddle.jit.to_static
def for_enumerate_var_numpy(x_array): def for_enumerate_var_numpy(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -75,7 +75,7 @@ def for_enumerate_var_numpy(x_array): ...@@ -75,7 +75,7 @@ def for_enumerate_var_numpy(x_array):
# 5. for enumerate var.numpy() with start # 5. for enumerate var.numpy() with start
@declarative @paddle.jit.to_static
def for_enumerate_var_numpy_with_start(x_array): def for_enumerate_var_numpy_with_start(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -87,7 +87,7 @@ def for_enumerate_var_numpy_with_start(x_array): ...@@ -87,7 +87,7 @@ def for_enumerate_var_numpy_with_start(x_array):
# 6. for in range with break # 6. for in range with break
@declarative @paddle.jit.to_static
def for_in_range_with_break(x): def for_in_range_with_break(x):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
...@@ -99,7 +99,7 @@ def for_in_range_with_break(x): ...@@ -99,7 +99,7 @@ def for_in_range_with_break(x):
# 7. for enumerate var.numpy() with break # 7. for enumerate var.numpy() with break
@declarative @paddle.jit.to_static
def for_enumerate_var_numpy_with_break(x_array): def for_enumerate_var_numpy_with_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -113,7 +113,7 @@ def for_enumerate_var_numpy_with_break(x_array): ...@@ -113,7 +113,7 @@ def for_enumerate_var_numpy_with_break(x_array):
# 8. for enumerate var.numpy() with continue # 8. for enumerate var.numpy() with continue
@declarative @paddle.jit.to_static
def for_enumerate_var_numpy_with_continue(x_array): def for_enumerate_var_numpy_with_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -127,7 +127,7 @@ def for_enumerate_var_numpy_with_continue(x_array): ...@@ -127,7 +127,7 @@ def for_enumerate_var_numpy_with_continue(x_array):
# 9. for enumerate var.numpy() with start & break # 9. for enumerate var.numpy() with start & break
@declarative @paddle.jit.to_static
def for_enumerate_var_numpy_with_start_break(x_array): def for_enumerate_var_numpy_with_start_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -141,7 +141,7 @@ def for_enumerate_var_numpy_with_start_break(x_array): ...@@ -141,7 +141,7 @@ def for_enumerate_var_numpy_with_start_break(x_array):
# 10. for enumerate var.numpy() with start & continue # 10. for enumerate var.numpy() with start & continue
@declarative @paddle.jit.to_static
def for_enumerate_var_numpy_with_start_continue(x_array): def for_enumerate_var_numpy_with_start_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -155,7 +155,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array): ...@@ -155,7 +155,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array):
# 11. for iter var # 11. for iter var
@declarative @paddle.jit.to_static
def for_iter_var(x_array): def for_iter_var(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
...@@ -165,7 +165,7 @@ def for_iter_var(x_array): ...@@ -165,7 +165,7 @@ def for_iter_var(x_array):
# 12. for enumerate var # 12. for enumerate var
@declarative @paddle.jit.to_static
def for_enumerate_var(x_array): def for_enumerate_var(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0) y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
...@@ -177,7 +177,7 @@ def for_enumerate_var(x_array): ...@@ -177,7 +177,7 @@ def for_enumerate_var(x_array):
# 13. for iter list[var] # 13. for iter list[var]
@declarative @paddle.jit.to_static
def for_iter_var_list(x): def for_iter_var_list(x):
# 1. prepare data, ref test_list.py # 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
...@@ -193,7 +193,7 @@ def for_iter_var_list(x): ...@@ -193,7 +193,7 @@ def for_iter_var_list(x):
# 14. for enumerate list[var] # 14. for enumerate list[var]
@declarative @paddle.jit.to_static
def for_enumerate_var_list(x): def for_enumerate_var_list(x):
# 1. prepare data, ref test_list.py # 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
...@@ -210,6 +210,17 @@ def for_enumerate_var_list(x): ...@@ -210,6 +210,17 @@ def for_enumerate_var_list(x):
return y, z return y, z
# 15. for enumerate list[var] with a nested for range
@paddle.jit.to_static
def for_enumerate_var_with_nested_range(x_array):
x = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, num in enumerate(x_array):
for idx in range(num):
x = x + num
return x
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
...@@ -337,6 +348,11 @@ class TestForEnumerateVar(TestForIterVarNumpy): ...@@ -337,6 +348,11 @@ class TestForEnumerateVar(TestForIterVarNumpy):
self.dygraph_func = for_enumerate_var self.dygraph_func = for_enumerate_var
class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_enumerate_var_with_nested_range
class TestForIterVarList(TestForInRange): class TestForIterVarList(TestForInRange):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_iter_var_list self.dygraph_func = for_iter_var_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册