未验证 提交 f6e874bc 编写于 作者: R Ryan 提交者: GitHub

[Divide by 0 Error] add pinv check (#49951)

* add pinv check

* add unitest

* update unitest

* roll back

* fix not call stupid bug

* use context
上级 094e3b8c
...@@ -105,6 +105,14 @@ void SvdKernel(const Context& dev_ctx, ...@@ -105,6 +105,14 @@ void SvdKernel(const Context& dev_ctx,
// int k = std::min(rows, cols); // int k = std::min(rows, cols);
// int col_u = full ? rows : k; // int col_u = full ? rows : k;
// int col_v = full ? cols : k; // int col_v = full ? cols : k;
PADDLE_ENFORCE_LT(
0,
rows,
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(
0,
cols,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
int batches = numel / (rows * cols); int batches = numel / (rows * cols);
auto* U_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(U); auto* U_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* VH_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH); auto* VH_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
......
...@@ -217,6 +217,15 @@ void SvdKernel(const Context& dev_ctx, ...@@ -217,6 +217,15 @@ void SvdKernel(const Context& dev_ctx,
int m = dims[rank - 2]; int m = dims[rank - 2];
int n = dims[rank - 1]; int n = dims[rank - 1];
PADDLE_ENFORCE_LT(
0,
m,
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(
0,
n,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
auto* u_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(U); auto* u_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* vh_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH); auto* vh_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
auto* s_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(S); auto* s_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
......
...@@ -280,5 +280,27 @@ class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase): ...@@ -280,5 +280,27 @@ class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase):
self.hermitian = True self.hermitian = True
class TestDivByZero(unittest.TestCase):
def pinv_zero_input_static(self):
paddle.enable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.pinv(x)
def pinv_zero_input_dynamic(self):
paddle.disable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.pinv(x)
def test_div_by_zero(self):
with self.assertRaises(ValueError):
self.pinv_zero_input_dynamic()
self.pinv_zero_input_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册