未验证 提交 650a0836 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim]simplify static unittest (#49805)

上级 93cee48e
...@@ -20,7 +20,6 @@ from decorator_helper import prog_scope ...@@ -20,7 +20,6 @@ from decorator_helper import prog_scope
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid.framework import grad_var_name
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
...@@ -139,10 +138,8 @@ class TestUnaryAPI(unittest.TestCase): ...@@ -139,10 +138,8 @@ class TestUnaryAPI(unittest.TestCase):
paddle.static.append_backward(loss) paddle.static.append_backward(loss)
fetch_list = [x, out] fetch_list = [x, out]
if block.has_var(grad_var_name(x.name)): if block.has_var(x.grad_name):
out_grad = block.var(grad_var_name(out.name)) fetch_list.extend([x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
fetch_list.extend([x_grad, out_grad])
# 1) Test Program # 1) Test Program
res = exe.run(main_prog, fetch_list=fetch_list) res = exe.run(main_prog, fetch_list=fetch_list)
...@@ -236,10 +233,9 @@ class TestReduceAPI(unittest.TestCase): ...@@ -236,10 +233,9 @@ class TestReduceAPI(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
fetch_list = [x, out] fetch_list = [x, out]
if block.has_var(grad_var_name(x.name)): if block.has_var(x.grad_name):
out_grad = block.var(grad_var_name(out.name)) fetch_list.extend([x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
fetch_list.append([x_grad, out_grad])
res = exe.run(main_prog, fetch_list=fetch_list) res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
...@@ -412,10 +408,10 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -412,10 +408,10 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(y.shape, ()) self.assertEqual(y.shape, ())
self.assertEqual(out.shape, ()) self.assertEqual(out.shape, ())
if block.has_var(grad_var_name(x.name)): if block.has_var(x.grad_name):
out_grad = block.var(grad_var_name(out.name)) out_grad = block.var(out.grad_name)
x_grad = block.var(grad_var_name(x.name)) x_grad = block.var(x.grad_name)
y_grad = block.var(grad_var_name(y.name)) y_grad = block.var(y.grad_name)
self.assertEqual(x_grad.shape, ()) self.assertEqual(x_grad.shape, ())
self.assertEqual(y_grad.shape, ()) self.assertEqual(y_grad.shape, ())
...@@ -439,10 +435,10 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -439,10 +435,10 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(y.shape, (2, 3, 4)) self.assertEqual(y.shape, (2, 3, 4))
self.assertEqual(out.shape, (2, 3, 4)) self.assertEqual(out.shape, (2, 3, 4))
if block.has_var(grad_var_name(x.name)): if block.has_var(x.grad_name):
out_grad = block.var(grad_var_name(out.name)) out_grad = block.var(out.grad_name)
x_grad = block.var(grad_var_name(x.name)) x_grad = block.var(x.grad_name)
y_grad = block.var(grad_var_name(y.name)) y_grad = block.var(y.grad_name)
self.assertEqual(x_grad.shape, ()) self.assertEqual(x_grad.shape, ())
self.assertEqual(y_grad.shape, (2, 3, 4)) self.assertEqual(y_grad.shape, (2, 3, 4))
...@@ -466,10 +462,10 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -466,10 +462,10 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, (2, 3, 4)) self.assertEqual(x.shape, (2, 3, 4))
self.assertEqual(y.shape, ()) self.assertEqual(y.shape, ())
self.assertEqual(out.shape, (2, 3, 4)) self.assertEqual(out.shape, (2, 3, 4))
if block.has_var(grad_var_name(x.name)): if block.has_var(x.grad_name):
out_grad = block.var(grad_var_name(out.name)) out_grad = block.var(out.grad_name)
x_grad = block.var(grad_var_name(x.name)) x_grad = block.var(x.grad_name)
y_grad = block.var(grad_var_name(y.name)) y_grad = block.var(y.grad_name)
self.assertEqual(x_grad.shape, (2, 3, 4)) self.assertEqual(x_grad.shape, (2, 3, 4))
self.assertEqual(y_grad.shape, ()) self.assertEqual(y_grad.shape, ())
...@@ -490,9 +486,9 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -490,9 +486,9 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(x.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ()) self.assertEqual(out.shape, ())
if block.has_var(grad_var_name(x.name)): if block.has_var(x.name):
out_grad = block.var(grad_var_name(out.name)) out_grad = block.var(out.grad_name)
x_grad = block.var(grad_var_name(x.name)) x_grad = block.var(x.grad_name)
self.assertEqual(out_grad.shape, ()) self.assertEqual(out_grad.shape, ())
self.assertEqual(x_grad.shape, ()) self.assertEqual(x_grad.shape, ())
...@@ -1191,10 +1187,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1191,10 +1187,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[x, out, x.grad_name, out.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1208,10 +1203,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1208,10 +1203,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[x, out, x.grad_name, out.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1225,10 +1219,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1225,10 +1219,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[x, out, x.grad_name, out.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1242,10 +1235,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1242,10 +1235,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1261,10 +1251,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1261,10 +1251,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[x, out, x.grad_name, out.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1278,10 +1267,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1278,10 +1267,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[x, out, x.grad_name, out.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1330,10 +1318,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1330,10 +1318,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[0], 1) self.assertEqual(res[0], 1)
self.assertEqual(res[1].shape, (10,)) self.assertEqual(res[1].shape, (10,))
...@@ -1348,10 +1333,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1348,10 +1333,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, (3,)) self.assertEqual(res[0].shape, (3,))
np.testing.assert_array_equal(res[0], [1.0, 1.0, 1.0]) np.testing.assert_array_equal(res[0], [1.0, 1.0, 1.0])
self.assertEqual(res[1].shape, (2, 3)) self.assertEqual(res[1].shape, (2, 3))
...@@ -1366,10 +1348,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1366,10 +1348,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, (2,)) self.assertEqual(res[0].shape, (2,))
np.testing.assert_array_equal(res[0], [1.0, 1.0]) np.testing.assert_array_equal(res[0], [1.0, 1.0])
self.assertEqual(res[1].shape, (2, 3)) self.assertEqual(res[1].shape, (2, 3))
...@@ -1385,10 +1364,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1385,10 +1364,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, (10,)) self.assertEqual(res[0].shape, (10,))
self.assertEqual(res[0][2], 4.0) self.assertEqual(res[0][2], 4.0)
self.assertEqual(res[1].shape, (10,)) self.assertEqual(res[1].shape, (10,))
...@@ -1404,10 +1380,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1404,10 +1380,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, (2, 3)) self.assertEqual(res[0].shape, (2, 3))
np.testing.assert_array_equal(res[0][1], [4.0, 4.0, 4.0]) np.testing.assert_array_equal(res[0][1], [4.0, 4.0, 4.0])
self.assertEqual(res[1].shape, (2, 3)) self.assertEqual(res[1].shape, (2, 3))
...@@ -1462,10 +1435,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1462,10 +1435,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
updates_grad = block.var(grad_var_name(updates.name)) prog, fetch_list=[out, out.grad_name, updates.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[out, out_grad, updates_grad])
self.assertEqual(res[0].shape, (5,)) self.assertEqual(res[0].shape, (5,))
self.assertEqual(res[0][3], 2) self.assertEqual(res[0][3], 2)
self.assertEqual(res[1].shape, (5,)) self.assertEqual(res[1].shape, (5,))
...@@ -1479,9 +1451,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1479,9 +1451,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, index, x.grad_name])
x_grad = block.var(grad_var_name(x.name))
res = self.exe.run(prog, fetch_list=[out, index, x_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1495,9 +1465,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1495,9 +1465,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, index, x.grad_name])
x_grad = block.var(grad_var_name(x.name))
res = self.exe.run(prog, fetch_list=[out, index, x_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1515,10 +1483,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1515,10 +1483,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
out_grad = block.var(grad_var_name(out.name)) prog, feed={}, fetch_list=[out, x.grad_name, out.grad_name]
x_grad = block.var(grad_var_name(x.name)) )
res = self.exe.run(prog, feed={}, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, (1,)) self.assertEqual(res[0].shape, (1,))
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
...@@ -1532,10 +1499,7 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1532,10 +1499,7 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name])
x_grad = block.var(grad_var_name(x.name))
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1594,15 +1558,6 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1594,15 +1558,6 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out4.sum()) paddle.static.append_backward(out4.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
x3_grad = block.var(grad_var_name(x3.name))
x4_grad = block.var(grad_var_name(x4.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
out3_grad = block.var(grad_var_name(out3.name))
out4_grad = block.var(grad_var_name(out4.name))
res = self.exe.run( res = self.exe.run(
prog, prog,
fetch_list=[ fetch_list=[
...@@ -1610,14 +1565,14 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1610,14 +1565,14 @@ class TestSundryAPIStatic(unittest.TestCase):
out2, out2,
out3, out3,
out4, out4,
x1_grad, x1.grad_name,
x2_grad, x2.grad_name,
x3_grad, x3.grad_name,
x4_grad, x4.grad_name,
out1_grad, out1.grad_name,
out2_grad, out2.grad_name,
out3_grad, out3.grad_name,
out4_grad, out4.grad_name,
], ],
) )
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
...@@ -1656,25 +1611,18 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1656,25 +1611,18 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out3.sum()) paddle.static.append_backward(out3.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
x3_grad = block.var(grad_var_name(x3.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
out3_grad = block.var(grad_var_name(out3.name))
res = self.exe.run( res = self.exe.run(
prog, prog,
fetch_list=[ fetch_list=[
out1, out1,
out2, out2,
out3, out3,
x1_grad, x1.grad_name,
x2_grad, x2.grad_name,
x3_grad, x3.grad_name,
out1_grad, out1.grad_name,
out2_grad, out2.grad_name,
out3_grad, out3.grad_name,
], ],
) )
self.assertEqual(res[0].shape, (1, 1, 1)) self.assertEqual(res[0].shape, (1, 1, 1))
...@@ -1698,10 +1646,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1698,10 +1646,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[x, out, x.grad_name, out.grad_name]
out_grad = block.var(grad_var_name(out.name)) )
res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad])
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
...@@ -1720,14 +1667,16 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1720,14 +1667,16 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out2.sum()) paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
res = self.exe.run( res = self.exe.run(
prog, prog,
fetch_list=[out1, out2, out1_grad, out2_grad, x1_grad, x2_grad], fetch_list=[
out1,
out2,
out1.grad_name,
out2.grad_name,
x1.grad_name,
x2.grad_name,
],
) )
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
...@@ -1775,12 +1724,9 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1775,12 +1724,9 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block() res = self.exe.run(
x_grad = block.var(grad_var_name(x.name)) prog, fetch_list=[out, out.grad_name, y.grad_name, x.grad_name]
y_grad = block.var(grad_var_name(y.name)) )
out_grad = block.var(grad_var_name(out.name))
res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad])
self.assertEqual(res[0].shape, shape[3]) self.assertEqual(res[0].shape, shape[3])
self.assertEqual(res[1].shape, shape[3]) self.assertEqual(res[1].shape, shape[3])
self.assertEqual(res[2].shape, shape[1]) self.assertEqual(res[2].shape, shape[1])
...@@ -1800,14 +1746,16 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -1800,14 +1746,16 @@ class TestSundryAPIStatic(unittest.TestCase):
paddle.static.append_backward(out2.sum()) paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
block = prog.global_block()
x1_grad = block.var(grad_var_name(x1.name))
x2_grad = block.var(grad_var_name(x2.name))
out1_grad = block.var(grad_var_name(out1.name))
out2_grad = block.var(grad_var_name(out2.name))
res = self.exe.run( res = self.exe.run(
prog, prog,
fetch_list=[out1, out2, x1_grad, x2_grad, out1_grad, out2_grad], fetch_list=[
out1,
out2,
x1.grad_name,
x2.grad_name,
out1.grad_name,
out2.grad_name,
],
) )
self.assertEqual(res[0].shape, (2,)) self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (3,)) self.assertEqual(res[1].shape, (3,))
......
...@@ -464,7 +464,7 @@ class TestSundryAPI(unittest.TestCase): ...@@ -464,7 +464,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [3]) self.assertEqual(out.grad.shape, [3])
def test_gather_xD_axis_1(self): def _test_gather_xD_axis_1(self):
x = paddle.to_tensor( x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册