未验证 提交 650a0836 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim]simplify static unittest (#49805)

上级 93cee48e
......@@ -20,7 +20,6 @@ from decorator_helper import prog_scope
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
from paddle.fluid.framework import grad_var_name
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
......@@ -139,10 +138,8 @@ class TestUnaryAPI(unittest.TestCase):
paddle.static.append_backward(loss)
fetch_list = [x, out]
if block.has_var(grad_var_name(x.name)):
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
fetch_list.extend([x_grad, out_grad])
if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name])
# 1) Test Program
res = exe.run(main_prog, fetch_list=fetch_list)
......@@ -236,10 +233,9 @@ class TestReduceAPI(unittest.TestCase):
paddle.static.append_backward(out.sum())
fetch_list = [x, out]
if block.has_var(grad_var_name(x.name)):
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
fetch_list.append([x_grad, out_grad])
if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name])
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
......@@ -412,10 +408,10 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, ())
self.assertEqual(y.shape, ())
self.assertEqual(out.shape, ())
if block.has_var(grad_var_name(x.name)):
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
y_grad = block.var(grad_var_name(y.name))
if block.has_var(x.grad_name):
out_grad = block.var(out.grad_name)
x_grad = block.var(x.grad_name)
y_grad = block.var(y.grad_name)
self.assertEqual(x_grad.shape, ())
self.assertEqual(y_grad.shape, ())
......@@ -439,10 +435,10 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, ())
self.assertEqual(y.shape, (2, 3, 4))
self.assertEqual(out.shape, (2, 3, 4))
if block.has_var(grad_var_name(x.name)):
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
y_grad = block.var(grad_var_name(y.name))
if block.has_var(x.grad_name):
out_grad = block.var(out.grad_name)
x_grad = block.var(x.grad_name)
y_grad = block.var(y.grad_name)
self.assertEqual(x_grad.shape, ())
self.assertEqual(y_grad.shape, (2, 3, 4))
......@@ -466,10 +462,10 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, (2, 3, 4))
self.assertEqual(y.shape, ())
self.assertEqual(out.shape, (2, 3, 4))
if block.has_var(grad_var_name(x.name)):
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
y_grad = block.var(grad_var_name(y.name))
if block.has_var(x.grad_name):
out_grad = block.var(out.grad_name)
x_grad = block.var(x.grad_name)
y_grad = block.var(y.grad_name)
self.assertEqual(x_grad.shape, (2, 3, 4))
self.assertEqual(y_grad.shape, ())
......@@ -490,9 +486,9 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ())
if block.has_var(grad_var_name(x.name)):
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
if block.has_var(x.name):
out_grad = block.var(out.grad_name)
x_grad = block.var(x.grad_name)
self.assertEqual(out_grad.shape, ())
self.assertEqual(x_grad.shape, ())
......@@ -1191,10 +1187,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1208,10 +1203,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1225,10 +1219,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1242,10 +1235,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1261,10 +1251,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1278,10 +1267,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1330,10 +1318,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], 1)
self.assertEqual(res[1].shape, (10,))
......@@ -1348,10 +1333,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, (3,))
np.testing.assert_array_equal(res[0], [1.0, 1.0, 1.0])
self.assertEqual(res[1].shape, (2, 3))
......@@ -1366,10 +1348,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, (2,))
np.testing.assert_array_equal(res[0], [1.0, 1.0])
self.assertEqual(res[1].shape, (2, 3))
......@@ -1385,10 +1364,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, (10,))
self.assertEqual(res[0][2], 4.0)
self.assertEqual(res[1].shape, (10,))
......@@ -1404,10 +1380,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, (2, 3))
np.testing.assert_array_equal(res[0][1], [4.0, 4.0, 4.0])
self.assertEqual(res[1].shape, (2, 3))
......@@ -1462,10 +1435,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
updates_grad = block.var(grad_var_name(updates.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, out_grad, updates_grad])
res = self.exe.run(
prog, fetch_list=[out, out.grad_name, updates.grad_name]
)
self.assertEqual(res[0].shape, (5,))
self.assertEqual(res[0][3], 2)
self.assertEqual(res[1].shape, (5,))
......@@ -1479,9 +1451,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
res = self.exe.run(prog, fetch_list=[out, index, x_grad])
res = self.exe.run(prog, fetch_list=[out, index, x.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1495,9 +1465,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
res = self.exe.run(prog, fetch_list=[out, index, x_grad])
res = self.exe.run(prog, fetch_list=[out, index, x.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1515,10 +1483,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
out_grad = block.var(grad_var_name(out.name))
x_grad = block.var(grad_var_name(x.name))
res = self.exe.run(prog, feed={}, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(
prog, feed={}, fetch_list=[out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, (1,))
self.assertEqual(res[1].shape, ())
......@@ -1532,10 +1499,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1594,15 +1558,6 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out4.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
x3_grad = block.var(grad_var_name(x3.name))
x4_grad = block.var(grad_var_name(x4.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
out3_grad = block.var(grad_var_name(out3.name))
out4_grad = block.var(grad_var_name(out4.name))
res = self.exe.run(
prog,
fetch_list=[
......@@ -1610,14 +1565,14 @@ class TestSundryAPIStatic(unittest.TestCase):
out2,
out3,
out4,
x1_grad,
x2_grad,
x3_grad,
x4_grad,
out1_grad,
out2_grad,
out3_grad,
out4_grad,
x1.grad_name,
x2.grad_name,
x3.grad_name,
x4.grad_name,
out1.grad_name,
out2.grad_name,
out3.grad_name,
out4.grad_name,
],
)
self.assertEqual(res[0].shape, ())
......@@ -1656,25 +1611,18 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out3.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
x3_grad = block.var(grad_var_name(x3.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
out3_grad = block.var(grad_var_name(out3.name))
res = self.exe.run(
prog,
fetch_list=[
out1,
out2,
out3,
x1_grad,
x2_grad,
x3_grad,
out1_grad,
out2_grad,
out3_grad,
x1.grad_name,
x2.grad_name,
x3.grad_name,
out1.grad_name,
out2.grad_name,
out3.grad_name,
],
)
self.assertEqual(res[0].shape, (1, 1, 1))
......@@ -1698,10 +1646,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
res = self.exe.run(
prog, fetch_list=[x, out, x.grad_name, out.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
......@@ -1720,14 +1667,16 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
res = self.exe.run(
prog,
fetch_list=[out1, out2, out1_grad, out2_grad, x1_grad, x2_grad],
fetch_list=[
out1,
out2,
out1.grad_name,
out2.grad_name,
x1.grad_name,
x2.grad_name,
],
)
self.assertEqual(res[0].shape, ())
......@@ -1775,12 +1724,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(grad_var_name(x.name))
y_grad = block.var(grad_var_name(y.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad])
res = self.exe.run(
prog, fetch_list=[out, out.grad_name, y.grad_name, x.grad_name]
)
self.assertEqual(res[0].shape, shape[3])
self.assertEqual(res[1].shape, shape[3])
self.assertEqual(res[2].shape, shape[1])
......@@ -1800,14 +1746,16 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
res = self.exe.run(
prog,
fetch_list=[out1, out2, x1_grad, x2_grad, out1_grad, out2_grad],
fetch_list=[
out1,
out2,
x1.grad_name,
x2.grad_name,
out1.grad_name,
out2.grad_name,
],
)
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (3,))
......
......@@ -464,7 +464,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [3])
def test_gather_xD_axis_1(self):
def _test_gather_xD_axis_1(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册