未验证 提交 61a8f287 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Fix bug: Do not use gast.Subscript to replace gast.Name in when...

[Dy2Stat] Fix bug: Do not use gast.Subscript to replace gast.Name in when transforming for_enumerate_loop (#29310)
上级 74bf3bed
......@@ -69,6 +69,7 @@ dygraph_class_to_static_api = {
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var'
# FullArgSpec is valid from Python3. Defined a Namedtuple to
# to make it available in Python2.
......@@ -772,6 +773,20 @@ class NameNodeReplaceTransformer(gast.NodeTransformer):
def __init__(self, root_node, target_name, replace_node):
assert isinstance(target_name, str)
# NOTE(liym27):
# Use gast.Name to replace gast.Name, otherwise, errors may occur.
#
# For examples:
# If using a gast.Subscript to replace gast.Name, and the original gast.Name
# is in the arguments of FunctionDef, an exception will be raised.
#
# ```
# def func(x[i])) # x[i] can not be a argument
# # ...
# ```
assert isinstance(replace_node, gast.Name)
self.target_name = target_name
self.replace_node = replace_node
......@@ -908,10 +923,14 @@ class ForNodeVisitor(object):
cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body
var_slice_node = self._build_var_slice_node()
# NOTE(liym27): Here add a gast.Assign, and the target of it is gast.Name.
# In NameNodeReplaceTransformer, using gast.Name to replace gast.Name is safe.
target_node, assign_node = self._build_assign_var_slice_node()
body_stmts[0:0] = [assign_node]
for body_node in body_stmts:
NameNodeReplaceTransformer(body_node, self.iter_var_name,
var_slice_node)
target_node)
body_stmts.append(self._build_index_increase_node(step_node))
return init_stmts, cond_stmt, body_stmts
......@@ -927,10 +946,13 @@ class ForNodeVisitor(object):
cond_stmt = self._build_cond_stmt(step_node, compare_node)
body_stmts = self.body
var_slice_node = self._build_var_slice_node()
target_node, assign_node = self._build_assign_var_slice_node()
body_stmts[0:0] = [assign_node]
for body_node in body_stmts:
NameNodeReplaceTransformer(body_node, self.iter_var_name,
var_slice_node)
target_node)
body_stmts.append(self._build_index_increase_node(step_node))
body_stmts.append(self._build_enum_increase_node())
......@@ -1030,15 +1052,19 @@ class ForNodeVisitor(object):
op=gast.Add(),
value=step_node)
def _build_var_slice_node(self):
return gast.Subscript(
def _build_assign_var_slice_node(self):
var_slice_node = gast.Subscript(
value=self.iter_node,
slice=gast.Index(value=gast.Name(
id=self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)),
ctx=gast.Load())
ctx=gast.Load(), )
new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX)
target_node, assign_node = create_assign_node(new_iter_var_name,
var_slice_node)
return target_node, assign_node
def _build_enum_increase_node(self):
return gast.AugAssign(
......
......@@ -17,15 +17,15 @@ from __future__ import print_function
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
program_translator = ProgramTranslator()
# 0. for in range var.numpy()[0]
@declarative
@paddle.jit.to_static
def for_in_range(x):
z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
......@@ -35,7 +35,7 @@ def for_in_range(x):
# 1. for iter list
@declarative
@paddle.jit.to_static
def for_iter_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in x_array:
......@@ -44,7 +44,7 @@ def for_iter_list(x_array):
# 2. for enumerate list
@declarative
@paddle.jit.to_static
def for_enumerate_list(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(x_array):
......@@ -53,7 +53,7 @@ def for_enumerate_list(x_array):
# 3. for iter var.numpy()
@declarative
@paddle.jit.to_static
def for_iter_var_numpy(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -63,7 +63,7 @@ def for_iter_var_numpy(x_array):
# 4. for enumerate var.numpy()
@declarative
@paddle.jit.to_static
def for_enumerate_var_numpy(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -75,7 +75,7 @@ def for_enumerate_var_numpy(x_array):
# 5. for enumerate var.numpy() with start
@declarative
@paddle.jit.to_static
def for_enumerate_var_numpy_with_start(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -87,7 +87,7 @@ def for_enumerate_var_numpy_with_start(x_array):
# 6. for in range with break
@declarative
@paddle.jit.to_static
def for_in_range_with_break(x):
z = fluid.layers.fill_constant([1], 'int32', 0)
x = fluid.dygraph.to_variable(x)
......@@ -99,7 +99,7 @@ def for_in_range_with_break(x):
# 7. for enumerate var.numpy() with break
@declarative
@paddle.jit.to_static
def for_enumerate_var_numpy_with_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -113,7 +113,7 @@ def for_enumerate_var_numpy_with_break(x_array):
# 8. for enumerate var.numpy() with continue
@declarative
@paddle.jit.to_static
def for_enumerate_var_numpy_with_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -127,7 +127,7 @@ def for_enumerate_var_numpy_with_continue(x_array):
# 9. for enumerate var.numpy() with start & break
@declarative
@paddle.jit.to_static
def for_enumerate_var_numpy_with_start_break(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -141,7 +141,7 @@ def for_enumerate_var_numpy_with_start_break(x_array):
# 10. for enumerate var.numpy() with start & continue
@declarative
@paddle.jit.to_static
def for_enumerate_var_numpy_with_start_continue(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -155,7 +155,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array):
# 11. for iter var
@declarative
@paddle.jit.to_static
def for_iter_var(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
......@@ -165,7 +165,7 @@ def for_iter_var(x_array):
# 12. for enumerate var
@declarative
@paddle.jit.to_static
def for_enumerate_var(x_array):
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
......@@ -177,7 +177,7 @@ def for_enumerate_var(x_array):
# 13. for iter list[var]
@declarative
@paddle.jit.to_static
def for_iter_var_list(x):
# 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x)
......@@ -193,7 +193,7 @@ def for_iter_var_list(x):
# 14. for enumerate list[var]
@declarative
@paddle.jit.to_static
def for_enumerate_var_list(x):
# 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x)
......@@ -210,6 +210,17 @@ def for_enumerate_var_list(x):
return y, z
# 15. for enumerate list[var] with a nested for range
@paddle.jit.to_static
def for_enumerate_var_with_nested_range(x_array):
x = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for i, num in enumerate(x_array):
for idx in range(num):
x = x + num
return x
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -337,6 +348,11 @@ class TestForEnumerateVar(TestForIterVarNumpy):
self.dygraph_func = for_enumerate_var
class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_enumerate_var_with_nested_range
class TestForIterVarList(TestForInRange):
def set_test_func(self):
self.dygraph_func = for_iter_var_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册