未验证 提交 6f8ec229 编写于 作者: R RedContritio 提交者: GitHub

Fix 空指针 (Null pointer) of case 2 paddle.linalg.lu_unpack (#49976)

* add pivots type check and fix batchsize error

* add unittest for batchsize = 0

* fix nullptr in lu_unpack

fix batchsize error in LU_Unpack
add nullptr check in OneFunctor

* remove exception in device code
上级 2d59aa09
......@@ -443,7 +443,10 @@ void LU_Unpack(const Context& dev_ctx,
auto dim = std::min(H, W);
DenseTensor rowtensor, rt_dev;
auto batchsize = product(phi::slice_ddim(udims, 0, udims.size() - 2));
batchsize = std::max(static_cast<int>(batchsize), 1);
// if udims is [0, ..., H, W], it should be 0
if (udims.size() == 2) batchsize = std::max(static_cast<int>(batchsize), 1);
arange<Context>(dev_ctx, &rowtensor, dim, batchsize, H);
auto idtptr = rowtensor.data<int32_t>();
if (phi::AllocationType::GPU == dev_ctx.GetPlace().GetType()) {
......@@ -494,7 +497,8 @@ void Unpack_Pivot(const Context& dev_ctx,
setter(dev_ctx, P, static_cast<T>(0));
auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1));
batchsize = std::max(static_cast<int>(batchsize), 1);
if (prank == 1) batchsize = std::max(static_cast<int>(batchsize), 1);
DenseTensor idt;
for (int i = 0; i < batchsize; i++) {
arange<Context>(dev_ctx, &idt, h);
......
......@@ -51,6 +51,14 @@ void LUUnpackKernel(const Context& dev_ctx,
if (unpack_pivots) {
dev_ctx.template Alloc<T>(pmat);
PADDLE_ENFORCE_EQ(
pivots.dtype(),
phi::DataType::INT32,
phi::errors::InvalidArgument(
"The pivots of lu_unpack must be of type int32, but received [%s].",
pivots.dtype()));
Unpack_Pivot<Context, T>(dev_ctx, pivots, pmat, m, k);
}
}
......
......@@ -200,6 +200,19 @@ class TestLU_UnpackOp3(TestLU_UnpackOp):
self.dtype = "float64"
# batchsize = 0
class TestLU_UnpackOp4(TestLU_UnpackOp):
"""
case 4
"""
def config(self):
self.x_shape = [10, 12]
self.unpack_ludata = True
self.unpack_pivots = True
self.dtype = "float64"
class TestLU_UnpackAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册