未验证 提交 f9043c78 编写于 作者: H heliqi 提交者: GitHub

[Zero-Dim]Unsqueeze support 0d tensor (#49862)

* add unsqueeze test case

* add unsqueeze xpu test case

* fix unsqueeze test case

* fix unsqueeze test case

* fix unsqueeze test case

* fix unsqueeze test case

* add retain_grads
上级 60d1199a
...@@ -1345,6 +1345,23 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1345,6 +1345,23 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1) self.assertEqual(x.grad.numpy(), 1)
def test_unsqueeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
x1.retain_grads()
out1 = paddle.unsqueeze(x1, axis=0)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [1])
self.assertEqual(x1.grad.shape, [])
x2 = paddle.full([], 0, dtype='int32')
out2 = paddle.unsqueeze(x1, axis=x2)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [1])
self.assertEqual(x1.grad.shape, [])
def test_t(self): def test_t(self):
x = paddle.full([], 2.0) x = paddle.full([], 2.0)
x.stop_gradient = False x.stop_gradient = False
...@@ -2091,6 +2108,34 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2091,6 +2108,34 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[3].shape, ()) self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1) self.assertEqual(res[3], 1)
@prog_scope()
def test_unsqueeze(self):
x1 = paddle.full([], 2)
out1 = paddle.unsqueeze(x1, axis=0)
x1.stop_gradient = False
paddle.static.append_backward(out1.sum())
x2 = paddle.full([], 3)
x3 = paddle.full([], 0, dtype='int32')
x2.stop_gradient = False
out2 = paddle.unsqueeze(x2, axis=x3)
paddle.static.append_backward(out2.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
out1,
out2,
x1.grad_name,
x2.grad_name,
],
)
self.assertEqual(res[0].shape, (1,))
self.assertEqual(res[1].shape, (1,))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
@prog_scope() @prog_scope()
def test_t(self): def test_t(self):
x = paddle.full([], 2.0) x = paddle.full([], 2.0)
......
...@@ -824,6 +824,19 @@ class TestSundryAPI(unittest.TestCase): ...@@ -824,6 +824,19 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1) self.assertEqual(x.grad.numpy(), 1)
def test_unsqueeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
out1 = paddle.unsqueeze(x1, axis=0)
out1.backward()
self.assertEqual(out1.shape, [1])
self.assertEqual(x1.grad.shape, [])
x2 = paddle.full([], 0, dtype='int32')
out2 = paddle.unsqueeze(x1, axis=x2)
out2.backward()
self.assertEqual(out2.shape, [1])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册