From 44db219a8162a7fdcd725f09d1eebbd24bd57f10 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Fri, 19 Nov 2021 16:57:33 +0800 Subject: [PATCH] [Dy2stat]Support `for i in [1,2,3]` statements in dy2stat (#37259) (#37356) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 该PR使得动转静模块能够正确转换如下的for i in [1, 2, 3]语句。 --- .../fluid/dygraph/dygraph_to_static/utils.py | 3 +- .../dygraph_to_static/test_for_enumerate.py | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 650857eefb3..320f2ef5b33 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -1044,7 +1044,8 @@ class ForNodeVisitor(object): gast.Name) and self.node.iter.func.id == "range" def is_for_iter(self): - if isinstance(self.node.iter, (gast.Name, gast.Attribute)): + if isinstance(self.node.iter, + (gast.Name, gast.Attribute, gast.List, gast.Tuple)): return True elif isinstance(self.node.iter, gast.Call) and isinstance( self.node.iter.func, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 305bdc1468b..2aab27c0311 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -304,6 +304,24 @@ class ForwardContainsForLayer(paddle.nn.Layer): return z +# 21. for original list +@paddle.jit.to_static +def for_original_list(): + z = fluid.layers.fill_constant([1], 'int32', 0) + for x in [1, 2, 3]: + z = z + x + return z + + +# 22. for original tuple +@paddle.jit.to_static +def for_original_tuple(): + z = fluid.layers.fill_constant([1], 'int32', 0) + for x in (1, 2, 3): + z = z + x + return z + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -344,6 +362,13 @@ class TestTransform(TestTransformBase): self.assertTrue(np.allclose(x.numpy(), y.numpy())) +class TestTransformForOriginalList(TestTransform): + def _run(self, to_static): + program_translator.enable(to_static) + with fluid.dygraph.guard(): + return self.dygraph_func() + + class TestTransformError(TestTransformBase): def transformed_error(self, etype): with self.assertRaises(etype): @@ -471,5 +496,21 @@ class TestForwardContainsForLayer(TestForIterVarNumpy): self.dygraph_func = ForwardContainsForLayer() +class TestForOriginalList(TestTransformForOriginalList): + def set_test_func(self): + self.dygraph_func = for_original_list + + def test_transformed_result_compare(self): + self.transformed_result_compare() + + +class TestForOriginalTuple(TestTransformForOriginalList): + def set_test_func(self): + self.dygraph_func = for_original_tuple + + def test_transformed_result_compare(self): + self.transformed_result_compare() + + if __name__ == '__main__': unittest.main() -- GitLab