From a0b60716f10c6e84b76db88d5db5fa67f3737281 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 18 Dec 2020 15:52:56 +0800 Subject: [PATCH] [Dy2Stat] Support grammar: for ele in var[idx] (#29541) Support to transformfor ele in var stms in which var is a slice of Tensor. --- .../fluid/dygraph/dygraph_to_static/utils.py | 2 ++ .../dygraph_to_static/test_for_enumerate.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index d299e63fd0..3f42137791 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -882,6 +882,8 @@ class ForNodeVisitor(object): self.node.iter.func, gast.Attribute) and self.node.iter.func.attr == 'numpy': return True + elif isinstance(self.node.iter, gast.Subscript): + return True else: return False diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index a74c56fc31..18995238a3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -159,6 +159,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array): def for_iter_var(x_array): z = fluid.layers.fill_constant([1], 'int32', 0) x_array = fluid.dygraph.to_variable(x_array) + for x in x_array: z = z + x return z @@ -221,6 +222,17 @@ def for_enumerate_var_with_nested_range(x_array): return x +# 16. for iter var[idx] +@paddle.jit.to_static +def for_iter_var_idx(x_array): + z = fluid.layers.fill_constant([1], 'int32', 0) + x_array = fluid.dygraph.to_variable(x_array) + + for x in x_array[0:]: + z = z + x + return z + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -343,6 +355,11 @@ class TestForIterVar(TestForIterVarNumpy): self.dygraph_func = for_iter_var +class TestForIterVarIdx(TestForIterVarNumpy): + def set_test_func(self): + self.dygraph_func = for_iter_var_idx + + class TestForEnumerateVar(TestForIterVarNumpy): def set_test_func(self): self.dygraph_func = for_enumerate_var -- GitLab