From 5ffc22da8ebbfa32473482f930a04a7b5a887e49 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Wed, 18 Jan 2023 12:49:51 +0800 Subject: [PATCH] [Zero-Dim]Paddle.t support 0d tensor (#49880) * support paddle.t 0d tensor * fix paddle.t test case * merge from develop --- .../tests/unittests/test_zero_dim_tensor.py | 26 +++++++++++++++++++ python/paddle/tensor/linalg.py | 4 +-- 2 files changed, 28 insertions(+), 2 deletions(-) mode change 100755 => 100644 python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py old mode 100755 new mode 100644 index a97d284120..e18c0bec99 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -1345,6 +1345,17 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.numpy(), 1) + def test_t(self): + x = paddle.full([], 2.0) + x.stop_gradient = False + x.retain_grads() + out = paddle.t(x) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -2080,6 +2091,21 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[3].shape, ()) self.assertEqual(res[3], 1) + @prog_scope() + def test_t(self): + x = paddle.full([], 2.0) + x.stop_gradient = False + out = paddle.t(x) + paddle.static.append_backward(out.sum()) + prog = paddle.static.default_main_program() + res = self.exe.run( + prog, feed={}, fetch_list=[out, out.grad_name, x.grad_name] + ) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + self.assertEqual(res[2].shape, ()) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index a5492d5081..4cce1b0196 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1296,7 +1296,7 @@ def t(input, name=None): "tensor.transpose() instead." % len(input.shape) ) if in_dygraph_mode(): - if len(input.shape) == 1: + if len(input.shape) <= 1: return input # 2-D tensor perm = [1, 0] @@ -1313,7 +1313,7 @@ def t(input, name=None): helper = LayerHelper('t', **locals()) out = helper.create_variable_for_type_inference(input.dtype) input_shape = helper.create_variable_for_type_inference(input.dtype) - if len(input.shape) == 1: + if len(input.shape) <= 1: out = input else: helper.append_op( -- GitLab