未验证 提交 7e4b432b 编写于 作者: Y Yuang Liu 提交者: GitHub

update xpu zero dim tensor ut (#50289)

* xpu scatter ut no backward

* update gather xpu ut
上级 84fe2de6
...@@ -583,42 +583,35 @@ class TestSundryAPI(unittest.TestCase): ...@@ -583,42 +583,35 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [3]) self.assertEqual(out.grad.shape, [3])
def _test_gather_xD_axis_1(self): def test_gather_xD_axis_1(self):
x = paddle.to_tensor( x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
) )
index = paddle.full([], 1, 'int64') index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index, axis=1) out = paddle.gather(x, index, axis=1)
out.backward()
self.assertEqual(out.shape, [2]) self.assertEqual(out.shape, [2])
np.testing.assert_array_equal(out.numpy(), [2.0, 5.0]) np.testing.assert_array_equal(out.numpy(), [2.0, 5.0])
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [2])
def _test_scatter_1D(self): def test_scatter_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64') index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0) updates = paddle.full([], 4.0)
out = paddle.scatter(x, index, updates) out = paddle.scatter(x, index, updates)
out.backward()
self.assertEqual(out.shape, [5]) self.assertEqual(out.shape, [5])
self.assertEqual(out.numpy()[2], 4) self.assertEqual(out.numpy()[2], 4)
self.assertEqual(out.grad.shape, [5])
def _test_scatter_XD(self): def test_scatter_XD(self):
x = paddle.to_tensor( x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
) )
index = paddle.full([], 1, 'int64') index = paddle.full([], 1, 'int64')
updates = paddle.to_tensor([1.0, 2.0, 3.0]) updates = paddle.to_tensor([1.0, 2.0, 3.0])
out = paddle.scatter(x, index, updates) out = paddle.scatter(x, index, updates)
out.backward()
self.assertEqual(out.shape, [2, 3]) self.assertEqual(out.shape, [2, 3])
np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0]) np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])
self.assertEqual(out.grad.shape, [2, 3])
def test_diagflat(self): def test_diagflat(self):
x1 = paddle.rand([]) x1 = paddle.rand([])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册