From 509d3ec5b7bbe4697875edc17def71645cf373e7 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 18 Jun 2020 19:48:26 +0800 Subject: [PATCH] [Dy2static] Add for iterate or enumerate variable list unittest (#25100) * add for iter var list, test=develop * add enumerate unittest, test=develop --- .../dygraph_to_static/test_for_enumerate.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 3b15e477e7d..86cfcb9b3d8 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -176,6 +176,40 @@ def for_enumerate_var(x_array): return y, z +# 13. for iter list[var] +@declarative +def for_iter_var_list(x): + # 1. prepare data, ref test_list.py + x = fluid.dygraph.to_variable(x) + iter_num = fluid.layers.fill_constant(shape=[1], value=5, dtype="int32") + a = [] + for i in range(iter_num): + a.append(x + i) + # 2. iter list[var] + y = fluid.layers.fill_constant([1], 'int32', 0) + for x in a: + y = y + x + return y + + +# 14. for enumerate list[var] +@declarative +def for_enumerate_var_list(x): + # 1. prepare data, ref test_list.py + x = fluid.dygraph.to_variable(x) + iter_num = fluid.layers.fill_constant(shape=[1], value=5, dtype="int32") + a = [] + for i in range(iter_num): + a.append(x + i) + # 2. iter list[var] + y = fluid.layers.fill_constant([1], 'int32', 0) + z = fluid.layers.fill_constant([1], 'int32', 0) + for i, x in enumerate(a): + y = y + i + z = z + x + return y, z + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -303,5 +337,15 @@ class TestForEnumerateVar(TestForIterVarNumpy): self.dygraph_func = for_enumerate_var +class TestForIterVarList(TestForInRange): + def set_test_func(self): + self.dygraph_func = for_iter_var_list + + +class TestForEnumerateVarList(TestForInRange): + def set_test_func(self): + self.dygraph_func = for_enumerate_var_list + + if __name__ == '__main__': unittest.main() -- GitLab