From 2af286a66f2b7db159219f28ea06fdd0ba65fd94 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Wed, 13 Jul 2022 13:15:16 +0800 Subject: [PATCH] fix bugs of paddle.linalg.lstsq (#44280) --- paddle/fluid/operators/lstsq_op.cu | 23 +++++++++++----- .../tests/unittests/test_linalg_lstsq_op.py | 26 ++++++++++++++++--- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/lstsq_op.cu b/paddle/fluid/operators/lstsq_op.cu index d0b44d0ec88..82a56af7eb4 100644 --- a/paddle/fluid/operators/lstsq_op.cu +++ b/paddle/fluid/operators/lstsq_op.cu @@ -100,7 +100,7 @@ class LstsqCUDAKernel : public framework::OpKernel { true, batch_count, m, - n, + nrhs, k, x_data, x_stride, @@ -137,14 +137,17 @@ class LstsqCUDAKernel : public framework::OpKernel { // Step 2, solve R^H Z = Y Tensor trans_r = dito.Transpose(new_x); + Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn}); + Tensor res_r = dito.TrilTriu(slice_r, 0, false); + phi::TriangularSolveKernel( - phi_dev_ctx, trans_r, new_y, true, true, false, solution); + phi_dev_ctx, res_r, new_y, true, true, false, solution); // Step 3, X <- Q Z BatchedOrgqr(dev_ctx, batch_count, n, - n, + m, min_mn, x_data, n, @@ -183,8 +186,6 @@ void BatchedOrmqr( auto handle = dev_ctx.cusolver_dn_handle(); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize( handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); auto info = memory::Alloc(dev_ctx, sizeof(int)); int* info_d = reinterpret_cast(info->ptr()); @@ -192,6 +193,11 @@ void BatchedOrmqr( float* a_working_ptr = &a[i * a_stride]; float* tau_working_ptr = &tau[i * tau_stride]; float* other_working_ptr = &other[i * other_stride]; + + handle = dev_ctx.cusolver_dn_handle(); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + // compute ormgr PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cusolverDnSormqr(handle, @@ -249,8 +255,6 @@ void BatchedOrmqr( auto handle = dev_ctx.cusolver_dn_handle(); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize( handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); auto info = memory::Alloc(dev_ctx, sizeof(int)); int* info_d = reinterpret_cast(info->ptr()); @@ -258,6 +262,11 @@ void BatchedOrmqr( double* a_working_ptr = &a[i * a_stride]; double* tau_working_ptr = &tau[i * tau_stride]; double* other_working_ptr = &other[i * other_stride]; + + handle = dev_ctx.cusolver_dn_handle(); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + // compute ormgr PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cusolverDnDormqr(handle, diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index 07729ae4e79..60414b8de97 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -175,6 +175,16 @@ class LinalgLstsqTestCase2(LinalgLstsqTestCase): self._input_shape_2 = (5, 8) +class LinalgLstsqTestCase3(LinalgLstsqTestCase): + + def init_config(self): + self.dtype = 'float64' + self.rcond = 1e-15 + self.driver = "gels" + self._input_shape_1 = (10, 7, 3) + self._input_shape_2 = (10, 7, 6) + + class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase): def init_config(self): @@ -192,7 +202,17 @@ class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase): self.rcond = None self.driver = "gels" self._input_shape_1 = (10, 5) - self._input_shape_2 = (10, 2) + self._input_shape_2 = (10, 8) + + +class LinalgLstsqTestCaseGelsFloat64(LinalgLstsqTestCase): + + def init_config(self): + self.dtype = 'float32' + self.rcond = None + self.driver = "gels" + self._input_shape_1 = (3, 2, 8) + self._input_shape_2 = (3, 2, 15) class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase): @@ -230,9 +250,9 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase): def init_config(self): self.dtype = 'float64' self.rcond = 1e-15 - self.driver = "gelss" + self.driver = "gels" self._input_shape_1 = (10, 8, 6) - self._input_shape_2 = (10, 8, 2) + self._input_shape_2 = (10, 8, 10) class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase): -- GitLab