未验证 提交 a7155c5c 编写于 作者: G GGBond8488 提交者: GitHub

【0D output】add 0D output support for linalg.slogdet (#52891)

* add 0D output support for inalg.slogdet,test=allcase

* fix zerom dime test error test=allcase

* fix test error test=allcase

* add static backward test, test=allcase
上级 f1b6a76b
...@@ -94,7 +94,7 @@ void SlogDeterminantKernel(const Context& dev_ctx, ...@@ -94,7 +94,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2); std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
if (input_dim.size() == static_cast<size_t>(2)) { if (input_dim.size() == static_cast<size_t>(2)) {
// when input is a two-dimension matrix, The det value is a number. // when input is a two-dimension matrix, The det value is a number.
output_dim_vec = {1}; output_dim_vec = {};
} }
output_dim_vec.insert(output_dim_vec.begin(), output_dim_vec.insert(output_dim_vec.begin(),
2); // make the output dims as same as numpy 2); // make the output dims as same as numpy
......
...@@ -2066,6 +2066,27 @@ class TestSundryAPI(unittest.TestCase): ...@@ -2066,6 +2066,27 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0)) np.testing.assert_allclose(x.grad, np.array(1.0))
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
out.retain_grads()
out.backward()
self.assertTrue(out.shape, [2])
self.assertTrue(x.grad.shape, [3, 3])
# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
out1.retain_grads()
out1.backward()
self.assertTrue(out1.shape, [2, 3])
self.assertTrue(x1.grad.shape, [3, 3, 3])
class TestSundryAPIStatic(unittest.TestCase): class TestSundryAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -3609,6 +3630,30 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -3609,6 +3630,30 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(out1.shape, (2, 3)) self.assertEqual(out1.shape, (2, 3))
self.assertEqual(out2.shape, (2, 3)) self.assertEqual(out2.shape, (2, 3))
@prog_scope()
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (3, 3))
# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
paddle.static.append_backward(out1.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, x1.grad_name])
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (3, 3, 3))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase): class TestNoBackwardAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册