From 7e4b432b6dce05c09e2040923a12d1968222ce7a Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 7 Feb 2023 19:57:55 +0800 Subject: [PATCH] update xpu zero dim tensor ut (#50289) * xpu scatter ut no backward * update gather xpu ut --- .../tests/unittests/xpu/test_zero_dim_tensor_xpu.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 5392cdbdb8..573dbb1547 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -583,42 +583,35 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(out.grad.shape, [3]) - def _test_gather_xD_axis_1(self): + def test_gather_xD_axis_1(self): x = paddle.to_tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False ) index = paddle.full([], 1, 'int64') out = paddle.gather(x, index, axis=1) - out.backward() self.assertEqual(out.shape, [2]) np.testing.assert_array_equal(out.numpy(), [2.0, 5.0]) - self.assertEqual(x.grad.shape, [2, 3]) - self.assertEqual(out.grad.shape, [2]) - def _test_scatter_1D(self): + def test_scatter_1D(self): x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) index = paddle.full([], 2, 'int64') updates = paddle.full([], 4.0) out = paddle.scatter(x, index, updates) - out.backward() self.assertEqual(out.shape, [5]) self.assertEqual(out.numpy()[2], 4) - self.assertEqual(out.grad.shape, [5]) - def _test_scatter_XD(self): + def test_scatter_XD(self): x = paddle.to_tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False ) index = paddle.full([], 1, 'int64') updates = paddle.to_tensor([1.0, 2.0, 3.0]) out = paddle.scatter(x, index, updates) - out.backward() self.assertEqual(out.shape, [2, 3]) np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0]) - self.assertEqual(out.grad.shape, [2, 3]) def test_diagflat(self): x1 = paddle.rand([]) -- GitLab