未验证 提交 07d3568d 编写于 作者: C Ccc 提交者: GitHub

[Zero-dim] add unittest for static.nn.prelu (#50635)

上级 ae622479
......@@ -2750,6 +2750,29 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[4].shape, ())
self.assertEqual(res[5].shape, ())
def test_static_nn_prelu(self):
x1 = paddle.full([], 1.0, 'float32')
x1.stop_gradient = False
out1 = paddle.static.nn.prelu(x1, 'all')
paddle.static.append_backward(out1.sum())
prog = paddle.static.default_main_program()
self.exe.run(paddle.static.default_startup_program())
res = self.exe.run(
prog,
fetch_list=[
out1,
x1.grad_name,
out1.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[0], np.array(1))
np.testing.assert_allclose(res[1], np.array(1))
@prog_scope()
def test_while_loop(self):
def cond(i, x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册