未验证 提交 d772a9aa 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Support `for i in [1,2,3]` statements in dy2stat (#37259)

* support `for i in [1,2,3]` statements in dy2stat

* add test case

* fix ci

* remove wrong code
上级 c98d175d
...@@ -1045,7 +1045,8 @@ class ForNodeVisitor(object): ...@@ -1045,7 +1045,8 @@ class ForNodeVisitor(object):
gast.Name) and self.node.iter.func.id == "range" gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self): def is_for_iter(self):
if isinstance(self.node.iter, (gast.Name, gast.Attribute)): if isinstance(self.node.iter,
(gast.Name, gast.Attribute, gast.List, gast.Tuple)):
return True return True
elif isinstance(self.node.iter, gast.Call) and isinstance( elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func, self.node.iter.func,
......
...@@ -304,6 +304,24 @@ class ForwardContainsForLayer(paddle.nn.Layer): ...@@ -304,6 +304,24 @@ class ForwardContainsForLayer(paddle.nn.Layer):
return z return z
# 21. for original list
@paddle.jit.to_static
def for_original_list():
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in [1, 2, 3]:
z = z + x
return z
# 22. for original tuple
@paddle.jit.to_static
def for_original_tuple():
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in (1, 2, 3):
z = z + x
return z
class TestTransformBase(unittest.TestCase): class TestTransformBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
...@@ -344,6 +362,13 @@ class TestTransform(TestTransformBase): ...@@ -344,6 +362,13 @@ class TestTransform(TestTransformBase):
self.assertTrue(np.allclose(x.numpy(), y.numpy())) self.assertTrue(np.allclose(x.numpy(), y.numpy()))
class TestTransformForOriginalList(TestTransform):
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard():
return self.dygraph_func()
class TestTransformError(TestTransformBase): class TestTransformError(TestTransformBase):
def transformed_error(self, etype): def transformed_error(self, etype):
with self.assertRaises(etype): with self.assertRaises(etype):
...@@ -471,5 +496,21 @@ class TestForwardContainsForLayer(TestForIterVarNumpy): ...@@ -471,5 +496,21 @@ class TestForwardContainsForLayer(TestForIterVarNumpy):
self.dygraph_func = ForwardContainsForLayer() self.dygraph_func = ForwardContainsForLayer()
class TestForOriginalList(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_list
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForOriginalTuple(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_tuple
def test_transformed_result_compare(self):
self.transformed_result_compare()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册