未验证 提交 6fdb316c 编写于 作者: Z zhiboniu 提交者: GitHub

add lu_unpack data check (#56311)

* add lu_unpack data check

* add error input api test

* add error type info
上级 6a42ddc6
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
...@@ -500,6 +501,16 @@ void Unpack_Pivot(const Context& dev_ctx, ...@@ -500,6 +501,16 @@ void Unpack_Pivot(const Context& dev_ctx,
arange<Context>(dev_ctx, &idt, h); arange<Context>(dev_ctx, &idt, h);
auto idlst = idt.data<int32_t>(); auto idlst = idt.data<int32_t>();
for (int j = 0; j < Pnum; j++) { for (int j = 0; j < Pnum; j++) {
PADDLE_ENFORCE_EQ(
(pdataptr[i * Pnum + j] > 0) && (pdataptr[i * Pnum + j] <= h),
true,
phi::errors::InvalidArgument(
"The data in Pivot must be between (1, x.shape[-2]],"
"but got %d in Pivot while the x.shape[-2] is %d."
"Please make sure that the inputs(x and Pivot) is the output of "
"paddle.linalg.lu.",
pdataptr[i * Pnum + j],
h));
if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue; if (idlst[pdataptr[i * Pnum + j] - 1] == idlst[j]) continue;
auto temp = idlst[j]; auto temp = idlst[j];
idlst[j] = idlst[pdataptr[i * Pnum + j] - 1]; idlst[j] = idlst[pdataptr[i * Pnum + j] - 1];
......
...@@ -2433,7 +2433,14 @@ def lu_unpack(x, y, unpack_ludata=True, unpack_pivots=True, name=None): ...@@ -2433,7 +2433,14 @@ def lu_unpack(x, y, unpack_ludata=True, unpack_pivots=True, name=None):
# one can verify : X = P @ L @ U ; # one can verify : X = P @ L @ U ;
""" """
if x.ndim < 2:
raise ValueError(
f"The shape of x should be (*, M, N), but received ndim is [{x.ndim} < 2]"
)
if y.ndim < 1:
raise ValueError(
f"The shape of Pivots should be (*, K), but received ndim is [{y.ndim} < 1]"
)
if in_dynamic_mode(): if in_dynamic_mode():
P, L, U = _C_ops.lu_unpack(x, y, unpack_ludata, unpack_pivots) P, L, U = _C_ops.lu_unpack(x, y, unpack_ludata, unpack_pivots)
return P, L, U return P, L, U
......
...@@ -315,6 +315,68 @@ class TestLU_UnpackAPI(unittest.TestCase): ...@@ -315,6 +315,68 @@ class TestLU_UnpackAPI(unittest.TestCase):
run_lu_static(tensor_shape, dtype) run_lu_static(tensor_shape, dtype)
class TestLU_UnpackAPIError(unittest.TestCase):
def test_errors_1(self):
with paddle.fluid.dygraph.guard():
# The size of input in lu should not be 0.
def test_x_size():
x = paddle.to_tensor(
np.random.uniform(-6666666, 100000000, [2]).astype(
np.float32
)
)
y = paddle.to_tensor(
np.random.uniform(-2147483648, 2147483647, [2]).astype(
np.int32
)
)
unpack_ludata = True
unpack_pivots = True
paddle.linalg.lu_unpack(x, y, unpack_ludata, unpack_pivots)
self.assertRaises(ValueError, test_x_size)
def test_errors_2(self):
with paddle.fluid.dygraph.guard():
# The size of input in lu should not be 0.
def test_y_size():
x = paddle.to_tensor(
np.random.uniform(-6666666, 100000000, [8, 4, 2]).astype(
np.float32
)
)
y = paddle.to_tensor(
np.random.uniform(-2147483648, 2147483647, []).astype(
np.int32
)
)
unpack_ludata = True
unpack_pivots = True
paddle.linalg.lu_unpack(x, y, unpack_ludata, unpack_pivots)
self.assertRaises(ValueError, test_y_size)
def test_errors_3(self):
with paddle.fluid.dygraph.guard():
# The size of input in lu should not be 0.
def test_y_data():
x = paddle.to_tensor(
np.random.uniform(-6666666, 100000000, [8, 4, 2]).astype(
np.float32
)
)
y = paddle.to_tensor(
np.random.uniform(-2147483648, 2147483647, [8, 2]).astype(
np.int32
)
)
unpack_ludata = True
unpack_pivots = True
paddle.linalg.lu_unpack(x, y, unpack_ludata, unpack_pivots)
self.assertRaises(Exception, test_y_data)
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册