未验证 提交 441606fd 编写于 作者: L levi131 提交者: GitHub

add fifth order test case (#44303)

上级 07f33da9
...@@ -161,7 +161,7 @@ class TestGrad(unittest.TestCase): ...@@ -161,7 +161,7 @@ class TestGrad(unittest.TestCase):
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(startup) exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list) outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.allclose(outs, result) np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim() paddle.incubate.autograd.disable_prim()
def test_fourth_order(self): def test_fourth_order(self):
...@@ -196,7 +196,43 @@ class TestGrad(unittest.TestCase): ...@@ -196,7 +196,43 @@ class TestGrad(unittest.TestCase):
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(startup) exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list) outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.allclose(outs, result) np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim()
def test_fifth_order(self):
paddle.incubate.autograd.enable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name='x', shape=[1], dtype='float32')
x2 = paddle.multiply(x, x)
x3 = paddle.multiply(x2, x)
x4 = paddle.multiply(x3, x)
x5 = paddle.multiply(x4, x)
x6 = paddle.multiply(x5, x)
out = x6 + x5
grad1, = paddle.incubate.autograd.grad([out], [x])
grad2, = paddle.incubate.autograd.grad([grad1], [x])
grad3, = paddle.incubate.autograd.grad([grad2], [x])
grad4, = paddle.incubate.autograd.grad([grad3], [x])
grad5, = paddle.incubate.autograd.grad([grad4], [x])
paddle.incubate.autograd.prim2orig()
feed = {
x.name: np.array([2.]).astype('float32'),
}
fetch_list = [grad5.name]
result = [np.array([1560.0])]
place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
outs = exe.run(main, feed=feed, fetch_list=fetch_list)
np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim() paddle.incubate.autograd.disable_prim()
def test_disable_prim(self): def test_disable_prim(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册