未验证 提交 a7155c5c 编写于 作者: G GGBond8488 提交者: GitHub

【0D output】add 0D output support for linalg.slogdet (#52891)

* add 0D output support for inalg.slogdet,test=allcase

* fix zerom dime test error test=allcase

* fix test error test=allcase

* add static backward test, test=allcase
上级 f1b6a76b
......@@ -94,7 +94,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
if (input_dim.size() == static_cast<size_t>(2)) {
// when input is a two-dimension matrix, The det value is a number.
output_dim_vec = {1};
output_dim_vec = {};
}
output_dim_vec.insert(output_dim_vec.begin(),
2); // make the output dims as same as numpy
......
......@@ -2066,6 +2066,27 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0))
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
out.retain_grads()
out.backward()
self.assertTrue(out.shape, [2])
self.assertTrue(x.grad.shape, [3, 3])
# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
out1.retain_grads()
out1.backward()
self.assertTrue(out1.shape, [2, 3])
self.assertTrue(x1.grad.shape, [3, 3, 3])
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -3609,6 +3630,30 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(out1.shape, (2, 3))
self.assertEqual(out2.shape, (2, 3))
@prog_scope()
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (3, 3))
# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
paddle.static.append_backward(out1.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out1, x1.grad_name])
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (3, 3, 3))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册