Unverified · Commit 2af286a6 · Authored by: H Haohongxiang · Committed by: GitHub

fix bugs of paddle.linalg.lstsq (#44280)

Parent: 7cf72a38
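The commit targets the CUDA (`gels`) path of `paddle.linalg.lstsq`. For orientation, here is a minimal usage sketch; it assumes a GPU build of PaddlePaddle, and the shapes mirror the batched test case updated below:

```python
import paddle

paddle.set_device('gpu')  # the fixed kernel is the cuSOLVER-based GPU path

# Batched over-determined systems: 10 batches of an 8 x 6 matrix,
# each with 10 right-hand sides (see LinalgLstsqTestCaseBatch2 below).
x = paddle.randn([10, 8, 6], dtype='float64')
y = paddle.randn([10, 8, 10], dtype='float64')

solution, residuals, rank, singular_values = paddle.linalg.lstsq(
    x, y, driver='gels')
print(solution.shape)  # [10, 6, 10]
```

On CUDA, `gels` is the only supported driver, which is why the new test cases below all set `driver = "gels"`.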
@@ -100,7 +100,7 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
                                    true,
                                    batch_count,
                                    m,
-                                   n,
+                                   nrhs,
                                    k,
                                    x_data,
                                    x_stride,
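This one-line change swaps `n` for `nrhs` in what appears to be the `BatchedOrmqr` call: in cuSOLVER's `ormqr`, the `n` argument is the number of columns of the matrix C being multiplied by Q^H, and here C is the right-hand-side block Y with `nrhs` columns, not the coefficient matrix with `n` columns.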
@@ -137,14 +137,17 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
     // Step 2, solve R^H Z = Y
     Tensor trans_r = dito.Transpose(new_x);
+    Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
+    Tensor res_r = dito.TrilTriu(slice_r, 0, false);
     phi::TriangularSolveKernel<T, Context>(
-        phi_dev_ctx, trans_r, new_y, true, true, false, solution);
+        phi_dev_ctx, res_r, new_y, true, true, false, solution);

     // Step 3, X <- Q Z
     BatchedOrgqr<DeviceContext, T>(dev_ctx,
                                    batch_count,
                                    n,
-                                   n,
+                                   m,
                                    min_mn,
                                    x_data,
                                    n,
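Steps 2 and 3 compute the minimum-norm solution from the QR factors of the transposed input: solve R^H Z = Y, then form X = Q Z. The added `Slice` + `TrilTriu` pass only the leading min(m, n) × min(m, n) triangular block of R to the solve, rather than the whole transposed buffer. A NumPy sketch of the same three steps, as an illustration of the math rather than Paddle code:

```python
import numpy as np

m, n = 5, 8          # under-determined: fewer equations than unknowns
rng = np.random.default_rng(0)
A = rng.standard_normal((m, n))
y = rng.standard_normal(m)

# Step 1: QR of A^H (reduced form: Q is n x min_mn, R is min_mn x min_mn).
Q, R = np.linalg.qr(A.conj().T)
min_mn = min(m, n)

# Step 2: solve R^H z = y, keeping only the leading min_mn x min_mn
# triangular block of R -- the slice/TrilTriu the commit adds.
R_tri = np.tril(R[:min_mn, :min_mn].conj().T)
z = np.linalg.solve(R_tri, y)

# Step 3: x = Q z is the minimum-norm solution of A x = y.
x = Q[:, :min_mn] @ z
assert np.allclose(A @ x, y)
```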
@@ -183,8 +186,6 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
   auto handle = dev_ctx.cusolver_dn_handle();
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize(
       handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
-  auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
-  float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
   auto info = memory::Alloc(dev_ctx, sizeof(int));
   int* info_d = reinterpret_cast<int*>(info->ptr());
@@ -192,6 +193,11 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
     float* a_working_ptr = &a[i * a_stride];
     float* tau_working_ptr = &tau[i * tau_stride];
     float* other_working_ptr = &other[i * other_stride];
+    handle = dev_ctx.cusolver_dn_handle();
+    auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
+    float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
     // compute ormgr
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::cusolverDnSormqr(handle,
@@ -249,8 +255,6 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
   auto handle = dev_ctx.cusolver_dn_handle();
   PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize(
       handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
-  auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
-  double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
   auto info = memory::Alloc(dev_ctx, sizeof(int));
   int* info_d = reinterpret_cast<int*>(info->ptr());
@@ -258,6 +262,11 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
     double* a_working_ptr = &a[i * a_stride];
     double* tau_working_ptr = &tau[i * tau_stride];
     double* other_working_ptr = &other[i * other_stride];
+    handle = dev_ctx.cusolver_dn_handle();
+    auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
+    double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
     // compute ormgr
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::cusolverDnDormqr(handle,
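Both the float and double specializations get the same second fix: the `lwork`-sized cuSOLVER workspace, previously allocated once before the per-batch loop, is now allocated inside it (with the handle re-fetched each iteration), so each batch element's `ormqr` call gets a fresh scratch buffer. The shape of that change in a purely illustrative NumPy loop (names hypothetical; `matmul` stands in for `ormqr`):

```python
import numpy as np

batch, m, nrhs = 4, 8, 3
rng = np.random.default_rng(1)
qs = [np.linalg.qr(rng.standard_normal((m, m)))[0] for _ in range(batch)]
Y = rng.standard_normal((batch, m, nrhs))

out = np.empty_like(Y)
for i in range(batch):
    workspace = np.empty((m, nrhs))           # fresh per-iteration workspace
    np.matmul(qs[i].T, Y[i], out=workspace)   # stands in for ormqr on batch i
    out[i] = workspace
```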
......
@@ -175,6 +175,16 @@ class LinalgLstsqTestCase2(LinalgLstsqTestCase):
         self._input_shape_2 = (5, 8)


+class LinalgLstsqTestCase3(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 1e-15
+        self.driver = "gels"
+        self._input_shape_1 = (10, 7, 3)
+        self._input_shape_2 = (10, 7, 6)
+
+
 class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):
     def init_config(self):
@@ -192,7 +202,17 @@ class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
         self.rcond = None
         self.driver = "gels"
         self._input_shape_1 = (10, 5)
-        self._input_shape_2 = (10, 2)
+        self._input_shape_2 = (10, 8)
+
+
+class LinalgLstsqTestCaseGelsFloat64(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float32'
+        self.rcond = None
+        self.driver = "gels"
+        self._input_shape_1 = (3, 2, 8)
+        self._input_shape_2 = (3, 2, 15)


 class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase):
@@ -230,9 +250,9 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
     def init_config(self):
         self.dtype = 'float64'
         self.rcond = 1e-15
-        self.driver = "gelss"
+        self.driver = "gels"
         self._input_shape_1 = (10, 8, 6)
-        self._input_shape_2 = (10, 8, 2)
+        self._input_shape_2 = (10, 8, 10)


 class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
......