diff --git a/paddle/phi/kernels/cpu/svd_kernel.cc b/paddle/phi/kernels/cpu/svd_kernel.cc index 814a9c451e7b8444fb4762e812f5fa3741223ec1..c7a7471d159931333119292a71f468c44efc019f 100644 --- a/paddle/phi/kernels/cpu/svd_kernel.cc +++ b/paddle/phi/kernels/cpu/svd_kernel.cc @@ -105,6 +105,14 @@ void SvdKernel(const Context& dev_ctx, // int k = std::min(rows, cols); // int col_u = full ? rows : k; // int col_v = full ? cols : k; + PADDLE_ENFORCE_LT( + 0, + rows, + errors::InvalidArgument("The row of Input(X) should be greater than 0.")); + PADDLE_ENFORCE_LT( + 0, + cols, + errors::InvalidArgument("The col of Input(X) should be greater than 0.")); int batches = numel / (rows * cols); auto* U_out = dev_ctx.template Alloc>(U); auto* VH_out = dev_ctx.template Alloc>(VH); diff --git a/paddle/phi/kernels/gpu/svd_kernel.cu b/paddle/phi/kernels/gpu/svd_kernel.cu index 4d4c19cde2b9c0425c1d1a1bc75397273302d8a0..3ef0584b0419f15c25e14af06795034c3de25546 100644 --- a/paddle/phi/kernels/gpu/svd_kernel.cu +++ b/paddle/phi/kernels/gpu/svd_kernel.cu @@ -217,6 +217,15 @@ void SvdKernel(const Context& dev_ctx, int m = dims[rank - 2]; int n = dims[rank - 1]; + PADDLE_ENFORCE_LT( + 0, + m, + errors::InvalidArgument("The row of Input(X) should be greater than 0.")); + PADDLE_ENFORCE_LT( + 0, + n, + errors::InvalidArgument("The col of Input(X) should be greater than 0.")); + auto* u_data = dev_ctx.template Alloc>(U); auto* vh_data = dev_ctx.template Alloc>(VH); auto* s_data = dev_ctx.template Alloc>(S); diff --git a/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py b/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py index 353b4d8da55e9c06122ab219ed2710bca0f87771..074b0fb517aa96d4ad8a8d472c51dea822b74f65 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py @@ -280,5 +280,27 @@ class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase): self.hermitian = True +class TestDivByZero(unittest.TestCase): + def pinv_zero_input_static(self): + + paddle.enable_static() + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32') + paddle.linalg.pinv(x) + + def pinv_zero_input_dynamic(self): + + paddle.disable_static() + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32') + paddle.linalg.pinv(x) + + def test_div_by_zero(self): + + with self.assertRaises(ValueError): + self.pinv_zero_input_dynamic() + self.pinv_zero_input_static() + + if __name__ == '__main__': unittest.main()