未验证 提交 d772a9aa 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Support `for i in [1,2,3]` statements in dy2stat (#37259)

* support `for i in [1,2,3]` statements in dy2stat

* add test case

* fix ci

* remove wrong code
上级 c98d175d
......@@ -1045,7 +1045,8 @@ class ForNodeVisitor(object):
gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self):
if isinstance(self.node.iter, (gast.Name, gast.Attribute)):
if isinstance(self.node.iter,
(gast.Name, gast.Attribute, gast.List, gast.Tuple)):
return True
elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
......
......@@ -304,6 +304,24 @@ class ForwardContainsForLayer(paddle.nn.Layer):
return z
# 21. for original list
@paddle.jit.to_static
def for_original_list():
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in [1, 2, 3]:
z = z + x
return z
# 22. for original tuple
@paddle.jit.to_static
def for_original_tuple():
z = fluid.layers.fill_constant([1], 'int32', 0)
for x in (1, 2, 3):
z = z + x
return z
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -344,6 +362,13 @@ class TestTransform(TestTransformBase):
self.assertTrue(np.allclose(x.numpy(), y.numpy()))
class TestTransformForOriginalList(TestTransform):
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard():
return self.dygraph_func()
class TestTransformError(TestTransformBase):
def transformed_error(self, etype):
with self.assertRaises(etype):
......@@ -471,5 +496,21 @@ class TestForwardContainsForLayer(TestForIterVarNumpy):
self.dygraph_func = ForwardContainsForLayer()
class TestForOriginalList(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_list
def test_transformed_result_compare(self):
self.transformed_result_compare()
class TestForOriginalTuple(TestTransformForOriginalList):
def set_test_func(self):
self.dygraph_func = for_original_tuple
def test_transformed_result_compare(self):
self.transformed_result_compare()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册