From f6e874bc657ab844109abef2d57e56d7d7c61dfb Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Mon, 30 Jan 2023 16:20:38 +0800 Subject: [PATCH] [Divide by 0 Error] add pinv check (#49951) * add pinv check * add unitest * update unitest * roll back * fix not call stupid bug * use context --- paddle/phi/kernels/cpu/svd_kernel.cc | 8 +++++++ paddle/phi/kernels/gpu/svd_kernel.cu | 9 ++++++++ .../tests/unittests/test_linalg_pinv_op.py | 22 +++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/paddle/phi/kernels/cpu/svd_kernel.cc b/paddle/phi/kernels/cpu/svd_kernel.cc index 814a9c451e7..c7a7471d159 100644 --- a/paddle/phi/kernels/cpu/svd_kernel.cc +++ b/paddle/phi/kernels/cpu/svd_kernel.cc @@ -105,6 +105,14 @@ void SvdKernel(const Context& dev_ctx, // int k = std::min(rows, cols); // int col_u = full ? rows : k; // int col_v = full ? cols : k; + PADDLE_ENFORCE_LT( + 0, + rows, + errors::InvalidArgument("The row of Input(X) should be greater than 0.")); + PADDLE_ENFORCE_LT( + 0, + cols, + errors::InvalidArgument("The col of Input(X) should be greater than 0.")); int batches = numel / (rows * cols); auto* U_out = dev_ctx.template Alloc>(U); auto* VH_out = dev_ctx.template Alloc>(VH); diff --git a/paddle/phi/kernels/gpu/svd_kernel.cu b/paddle/phi/kernels/gpu/svd_kernel.cu index 4d4c19cde2b..3ef0584b041 100644 --- a/paddle/phi/kernels/gpu/svd_kernel.cu +++ b/paddle/phi/kernels/gpu/svd_kernel.cu @@ -217,6 +217,15 @@ void SvdKernel(const Context& dev_ctx, int m = dims[rank - 2]; int n = dims[rank - 1]; + PADDLE_ENFORCE_LT( + 0, + m, + errors::InvalidArgument("The row of Input(X) should be greater than 0.")); + PADDLE_ENFORCE_LT( + 0, + n, + errors::InvalidArgument("The col of Input(X) should be greater than 0.")); + auto* u_data = dev_ctx.template Alloc>(U); auto* vh_data = dev_ctx.template Alloc>(VH); auto* s_data = dev_ctx.template Alloc>(S); diff --git a/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py b/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py index 353b4d8da55..074b0fb517a 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_pinv_op.py @@ -280,5 +280,27 @@ class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase): self.hermitian = True +class TestDivByZero(unittest.TestCase): + def pinv_zero_input_static(self): + + paddle.enable_static() + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32') + paddle.linalg.pinv(x) + + def pinv_zero_input_dynamic(self): + + paddle.disable_static() + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32') + paddle.linalg.pinv(x) + + def test_div_by_zero(self): + + with self.assertRaises(ValueError): + self.pinv_zero_input_dynamic() + self.pinv_zero_input_static() + + if __name__ == '__main__': unittest.main() -- GitLab