未验证 提交 ce045890 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Zero-Dim] Support input 0D Tensor for masked_select (#49803)

* [Zero-Dim] Support input 0D Tensor for masked_select
上级 2242136a
......@@ -1298,6 +1298,20 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x1.grad.numpy(), 0.5)
self.assertEqual(x2.grad.numpy(), 0)
def test_maseked_select(self):
x = paddle.rand([])
x.stop_gradient = False
mask = paddle.full([], True, dtype='bool')
y = paddle.masked_select(x, mask)
y.retain_grads()
y.backward()
self.assertEqual(y.shape, [1])
self.assertEqual(y.numpy(), x.numpy())
self.assertEqual(y.grad.shape, [1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1)
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -1968,6 +1982,22 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[0].shape, ())
@prog_scope()
def test_maseked_select(self):
x = paddle.rand([])
x.stop_gradient = False
mask = paddle.full([], True, dtype='bool')
y = paddle.masked_select(x, mask)
paddle.static.append_backward(y.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[x, y, y.grad_name, x.grad_name])
self.assertEqual(res[1].shape, (1,))
self.assertEqual(res[1], res[0])
self.assertEqual(res[2].shape, (1,))
self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -787,6 +787,20 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [])
self.assertFalse(out)
def test_maseked_select(self):
x = paddle.rand([])
x.stop_gradient = False
mask = paddle.full([], True, dtype='bool')
y = paddle.masked_select(x, mask)
y.retain_grads()
y.backward()
self.assertEqual(y.shape, [1])
self.assertEqual(y.numpy(), x.numpy())
self.assertEqual(y.grad.shape, [1])
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册