From 7f1a1570c68985b8649edbe484b812ab82df26bb Mon Sep 17 00:00:00 2001 From: RedContritio Date: Wed, 1 Feb 2023 10:37:18 +0800 Subject: [PATCH] Fix Python IndexError of case1: paddle.linalg.lstsq (#49985) --- .../tests/unittests/test_linalg_lstsq_op.py | 33 +++++++++++++++++++ python/paddle/tensor/linalg.py | 19 +++++++++-- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index 82576ab1bd..94dc901a56 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -278,5 +278,38 @@ class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase): self._input_shape_2 = (50, 300) +class TestLinalgLstsqAPIError(unittest.TestCase): + def setUp(self): + pass + + def test_api_errors(self): + def test_x_bad_shape(): + x = paddle.to_tensor(np.random.random(size=(5)), dtype=np.float32) + y = paddle.to_tensor( + np.random.random(size=(5, 15)), dtype=np.float32 + ) + out = paddle.linalg.lstsq(x, y, driver='gelsy') + + def test_y_bad_shape(): + x = paddle.to_tensor( + np.random.random(size=(5, 10)), dtype=np.float32 + ) + y = paddle.to_tensor(np.random.random(size=(5)), dtype=np.float32) + out = paddle.linalg.lstsq(x, y, driver='gelsy') + + def test_shape_dismatch(): + x = paddle.to_tensor( + np.random.random(size=(5, 10)), dtype=np.float32 + ) + y = paddle.to_tensor( + np.random.random(size=(4, 15)), dtype=np.float32 + ) + out = paddle.linalg.lstsq(x, y, driver='gelsy') + + self.assertRaises(ValueError, test_x_bad_shape) + self.assertRaises(ValueError, test_y_bad_shape) + self.assertRaises(ValueError, test_shape_dismatch) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 4cce1b0196..46f11130c0 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3171,13 +3171,26 @@ def lstsq(x, y, rcond=None, driver=None, name=None): else: raise RuntimeError("Only support lstsq api for CPU or CUDA device.") - if x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64): - pass - else: + if not (x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64)): raise ValueError( "Only support x and y have the same dtype such as 'float32' and 'float64'." ) + if x.ndim < 2: + raise ValueError( + f"The shape of x should be (*, M, N), but received ndim is [{x.ndim} < 2]" + ) + + if y.ndim < 2: + raise ValueError( + f"The shape of y should be (*, M, K), but received ndim is [{y.ndim} < 2]" + ) + + if x.shape[-2] != y.shape[-2]: + raise ValueError( + f"x with shape (*, M = {x.shape[-2]}, N) and y with shape (*, M = {y.shape[-2]}, K) should have same M." + ) + if rcond is None: if x.dtype == paddle.float32: rcond = 1e-7 * max(x.shape[-2], x.shape[-1]) -- GitLab