From 6f8ec229991ed555c8a1ccdab9660dbe3cf9c137 Mon Sep 17 00:00:00 2001
From: RedContritio
Date: Mon, 30 Jan 2023 18:13:13 +0800
Subject: [PATCH] Fix null pointer of case 2 paddle.linalg.lu_unpack (#49976)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add pivots type check and fix batchsize error

* add unittest for batchsize = 0

* fix nullptr in lu_unpack
fix batchsize error in LU_Unpack
add nullptr check in OneFunctor

* remove exception in device code
---
 paddle/phi/kernels/impl/lu_kernel_impl.h        |  8 ++++++--
 paddle/phi/kernels/impl/lu_unpack_kernel_impl.h |  8 ++++++++
 .../fluid/tests/unittests/test_lu_unpack_op.py  | 13 +++++++++++++
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/impl/lu_kernel_impl.h b/paddle/phi/kernels/impl/lu_kernel_impl.h
index ed3cc0801d..31a83ea540 100644
--- a/paddle/phi/kernels/impl/lu_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lu_kernel_impl.h
@@ -443,7 +443,10 @@ void LU_Unpack(const Context& dev_ctx,
   auto dim = std::min(H, W);
   DenseTensor rowtensor, rt_dev;
   auto batchsize = product(phi::slice_ddim(udims, 0, udims.size() - 2));
-  batchsize = std::max(static_cast<int>(batchsize), 1);
+
+  // if udims is [0, ..., H, W], batchsize should stay 0 instead of being clamped to 1
+  if (udims.size() == 2) batchsize = std::max(static_cast<int>(batchsize), 1);
+
   arange(dev_ctx, &rowtensor, dim, batchsize, H);
   auto idtptr = rowtensor.data();
   if (phi::AllocationType::GPU == dev_ctx.GetPlace().GetType()) {
@@ -494,7 +497,8 @@ void Unpack_Pivot(const Context& dev_ctx,
   setter(dev_ctx, P, static_cast<T>(0));
 
   auto batchsize = product(phi::slice_ddim(dims, 0, prank - 1));
-  batchsize = std::max(static_cast<int>(batchsize), 1);
+  if (prank == 1) batchsize = std::max(static_cast<int>(batchsize), 1);
+
   DenseTensor idt;
   for (int i = 0; i < batchsize; i++) {
     arange(dev_ctx, &idt, h);
diff --git a/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h b/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h
index 7e77fdd171..dce73af550 100644
--- a/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lu_unpack_kernel_impl.h
@@ -51,6 +51,14 @@ void LUUnpackKernel(const Context& dev_ctx,
 
   if (unpack_pivots) {
     dev_ctx.template Alloc<T>(pmat);
+
+    PADDLE_ENFORCE_EQ(
+        pivots.dtype(),
+        phi::DataType::INT32,
+        phi::errors::InvalidArgument(
+            "The pivots of lu_unpack must be of type int32, but received [%s].",
+            pivots.dtype()));
+
     Unpack_Pivot(dev_ctx, pivots, pmat, m, k);
   }
 }
diff --git a/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py b/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py
index 677ae648fb..d05b16df25 100644
--- a/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py
@@ -200,6 +200,19 @@ class TestLU_UnpackOp3(TestLU_UnpackOp):
         self.dtype = "float64"
 
 
+# batchsize = 0
+class TestLU_UnpackOp4(TestLU_UnpackOp):
+    """
+    case 4
+    """
+
+    def config(self):
+        self.x_shape = [10, 12]
+        self.unpack_ludata = True
+        self.unpack_pivots = True
+        self.dtype = "float64"
+
+
 class TestLU_UnpackAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(2022)
--
GitLab
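
For reference, a minimal sketch of the scenario the new TestLU_UnpackOp4 case exercises: calling paddle.linalg.lu_unpack on a single 2-D input, so the pivots tensor carries no batch dimensions. This is illustrative only, not part of the patch; it assumes nothing beyond the public paddle.linalg.lu and paddle.linalg.lu_unpack APIs, and the shape mirrors the x_shape = [10, 12] used in the test.

# Sketch only: exercises lu_unpack on a single 2-D matrix (the path touched
# by this patch); not part of the committed test file.
import paddle

x = paddle.randn([10, 12], dtype="float64")
lu, pivots = paddle.linalg.lu(x, pivot=True)  # pivots is an int32 tensor
p, l, u = paddle.linalg.lu_unpack(lu, pivots, unpack_ludata=True, unpack_pivots=True)

# P @ L @ U should reconstruct x up to floating-point error.
print(paddle.allclose(p @ l @ u, x))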