“f7e4b863340b19ff112efc52f1a83e1466b2f302”上不存在“develop/doc_cn/design/kernel_hint_design.html”
未验证 提交 5ffc22da 编写于 作者: H heliqi 提交者: GitHub

[Zero-Dim]Paddle.t support 0d tensor (#49880)

* support paddle.t 0d tensor

* fix paddle.t test case

* merge from develop
上级 7242f40b
...@@ -1345,6 +1345,17 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1345,6 +1345,17 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1) self.assertEqual(x.grad.numpy(), 1)
def test_t(self):
x = paddle.full([], 2.0)
x.stop_gradient = False
x.retain_grads()
out = paddle.t(x)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -2080,6 +2091,21 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2080,6 +2091,21 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[3].shape, ()) self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1) self.assertEqual(res[3], 1)
@prog_scope()
def test_t(self):
x = paddle.full([], 2.0)
x.stop_gradient = False
out = paddle.t(x)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog, feed={}, fetch_list=[out, out.grad_name, x.grad_name]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
...@@ -1296,7 +1296,7 @@ def t(input, name=None): ...@@ -1296,7 +1296,7 @@ def t(input, name=None):
"tensor.transpose() instead." % len(input.shape) "tensor.transpose() instead." % len(input.shape)
) )
if in_dygraph_mode(): if in_dygraph_mode():
if len(input.shape) == 1: if len(input.shape) <= 1:
return input return input
# 2-D tensor # 2-D tensor
perm = [1, 0] perm = [1, 0]
...@@ -1313,7 +1313,7 @@ def t(input, name=None): ...@@ -1313,7 +1313,7 @@ def t(input, name=None):
helper = LayerHelper('t', **locals()) helper = LayerHelper('t', **locals())
out = helper.create_variable_for_type_inference(input.dtype) out = helper.create_variable_for_type_inference(input.dtype)
input_shape = helper.create_variable_for_type_inference(input.dtype) input_shape = helper.create_variable_for_type_inference(input.dtype)
if len(input.shape) == 1: if len(input.shape) <= 1:
out = input out = input
else: else:
helper.append_op( helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册