未验证 提交 509d3ec5 编写于 作者: C Chen Weihang 提交者: GitHub

[Dy2static] Add for iterate or enumerate variable list unittest (#25100)

* add for iter var list, test=develop

* add enumerate unittest, test=develop
上级 eb1c0901
......@@ -176,6 +176,40 @@ def for_enumerate_var(x_array):
return y, z
# 13. for iter list[var]
@declarative
def for_iter_var_list(x):
# 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(shape=[1], value=5, dtype="int32")
a = []
for i in range(iter_num):
a.append(x + i)
# 2. iter list[var]
y = fluid.layers.fill_constant([1], 'int32', 0)
for x in a:
y = y + x
return y
# 14. for enumerate list[var]
@declarative
def for_enumerate_var_list(x):
# 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(shape=[1], value=5, dtype="int32")
a = []
for i in range(iter_num):
a.append(x + i)
# 2. iter list[var]
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(a):
y = y + i
z = z + x
return y, z
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......@@ -303,5 +337,15 @@ class TestForEnumerateVar(TestForIterVarNumpy):
self.dygraph_func = for_enumerate_var
class TestForIterVarList(TestForInRange):
def set_test_func(self):
self.dygraph_func = for_iter_var_list
class TestForEnumerateVarList(TestForInRange):
def set_test_func(self):
self.dygraph_func = for_enumerate_var_list
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册