未验证 提交 7f1a1570 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case1: paddle.linalg.lstsq (#49985)

上级 3e9d8548
...@@ -278,5 +278,38 @@ class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase): ...@@ -278,5 +278,38 @@ class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase):
self._input_shape_2 = (50, 300) self._input_shape_2 = (50, 300)
class TestLinalgLstsqAPIError(unittest.TestCase):
def setUp(self):
pass
def test_api_errors(self):
def test_x_bad_shape():
x = paddle.to_tensor(np.random.random(size=(5)), dtype=np.float32)
y = paddle.to_tensor(
np.random.random(size=(5, 15)), dtype=np.float32
)
out = paddle.linalg.lstsq(x, y, driver='gelsy')
def test_y_bad_shape():
x = paddle.to_tensor(
np.random.random(size=(5, 10)), dtype=np.float32
)
y = paddle.to_tensor(np.random.random(size=(5)), dtype=np.float32)
out = paddle.linalg.lstsq(x, y, driver='gelsy')
def test_shape_dismatch():
x = paddle.to_tensor(
np.random.random(size=(5, 10)), dtype=np.float32
)
y = paddle.to_tensor(
np.random.random(size=(4, 15)), dtype=np.float32
)
out = paddle.linalg.lstsq(x, y, driver='gelsy')
self.assertRaises(ValueError, test_x_bad_shape)
self.assertRaises(ValueError, test_y_bad_shape)
self.assertRaises(ValueError, test_shape_dismatch)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -3171,13 +3171,26 @@ def lstsq(x, y, rcond=None, driver=None, name=None): ...@@ -3171,13 +3171,26 @@ def lstsq(x, y, rcond=None, driver=None, name=None):
else: else:
raise RuntimeError("Only support lstsq api for CPU or CUDA device.") raise RuntimeError("Only support lstsq api for CPU or CUDA device.")
if x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64): if not (x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64)):
pass
else:
raise ValueError( raise ValueError(
"Only support x and y have the same dtype such as 'float32' and 'float64'." "Only support x and y have the same dtype such as 'float32' and 'float64'."
) )
if x.ndim < 2:
raise ValueError(
f"The shape of x should be (*, M, N), but received ndim is [{x.ndim} < 2]"
)
if y.ndim < 2:
raise ValueError(
f"The shape of y should be (*, M, K), but received ndim is [{y.ndim} < 2]"
)
if x.shape[-2] != y.shape[-2]:
raise ValueError(
f"x with shape (*, M = {x.shape[-2]}, N) and y with shape (*, M = {y.shape[-2]}, K) should have same M."
)
if rcond is None: if rcond is None:
if x.dtype == paddle.float32: if x.dtype == paddle.float32:
rcond = 1e-7 * max(x.shape[-2], x.shape[-1]) rcond = 1e-7 * max(x.shape[-2], x.shape[-1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册