未验证 提交 a0b60716 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Support grammar: for ele in var[idx] (#29541)

Support to transformfor ele in var stms in which var is a slice of Tensor.
上级 b59b6d7a
...@@ -882,6 +882,8 @@ class ForNodeVisitor(object): ...@@ -882,6 +882,8 @@ class ForNodeVisitor(object):
self.node.iter.func, self.node.iter.func,
gast.Attribute) and self.node.iter.func.attr == 'numpy': gast.Attribute) and self.node.iter.func.attr == 'numpy':
return True return True
elif isinstance(self.node.iter, gast.Subscript):
return True
else: else:
return False return False
......
...@@ -159,6 +159,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array): ...@@ -159,6 +159,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array):
def for_iter_var(x_array): def for_iter_var(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0) z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array) x_array = fluid.dygraph.to_variable(x_array)
for x in x_array: for x in x_array:
z = z + x z = z + x
return z return z
...@@ -221,6 +222,17 @@ def for_enumerate_var_with_nested_range(x_array): ...@@ -221,6 +222,17 @@ def for_enumerate_var_with_nested_range(x_array):
return x return x
# 16. for iter var[idx]
@paddle.jit.to_static
def for_iter_var_idx(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array[0:]:
z = z + x
return z
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
...@@ -343,6 +355,11 @@ class TestForIterVar(TestForIterVarNumpy): ...@@ -343,6 +355,11 @@ class TestForIterVar(TestForIterVarNumpy):
self.dygraph_func = for_iter_var self.dygraph_func = for_iter_var
class TestForIterVarIdx(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_iter_var_idx
class TestForEnumerateVar(TestForIterVarNumpy): class TestForEnumerateVar(TestForIterVarNumpy):
def set_test_func(self): def set_test_func(self):
self.dygraph_func = for_enumerate_var self.dygraph_func = for_enumerate_var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册