From 07d3568dc22b5490d3f42b4854d77550ca96e9a5 Mon Sep 17 00:00:00 2001 From: Ccc <52520497+juncaipeng@users.noreply.github.com> Date: Sat, 18 Feb 2023 20:35:58 +0800 Subject: [PATCH] [Zero-dim] add unittest for static.nn.prelu (#50635) --- .../tests/unittests/test_zero_dim_tensor.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 0b3a6c20ec..99c2bc117d 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -2750,6 +2750,29 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[4].shape, ()) self.assertEqual(res[5].shape, ()) + def test_static_nn_prelu(self): + x1 = paddle.full([], 1.0, 'float32') + x1.stop_gradient = False + out1 = paddle.static.nn.prelu(x1, 'all') + paddle.static.append_backward(out1.sum()) + + prog = paddle.static.default_main_program() + self.exe.run(paddle.static.default_startup_program()) + res = self.exe.run( + prog, + fetch_list=[ + out1, + x1.grad_name, + out1.grad_name, + ], + ) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[2].shape, ()) + np.testing.assert_allclose(res[0], np.array(1)) + np.testing.assert_allclose(res[1], np.array(1)) + @prog_scope() def test_while_loop(self): def cond(i, x): -- GitLab