未验证 提交 2af286a6 编写于 作者: H Haohongxiang 提交者: GitHub

fix bugs of paddle.linalg.lstsq (#44280)

上级 7cf72a38
......@@ -100,7 +100,7 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
true,
batch_count,
m,
n,
nrhs,
k,
x_data,
x_stride,
......@@ -137,14 +137,17 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
// Step 2, solve R^H Z = Y
Tensor trans_r = dito.Transpose(new_x);
Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
Tensor res_r = dito.TrilTriu(slice_r, 0, false);
phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, trans_r, new_y, true, true, false, solution);
phi_dev_ctx, res_r, new_y, true, true, false, solution);
// Step 3, X <- Q Z
BatchedOrgqr<DeviceContext, T>(dev_ctx,
batch_count,
n,
n,
m,
min_mn,
x_data,
n,
......@@ -183,8 +186,6 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
......@@ -192,6 +193,11 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
float* other_working_ptr = &other[i * other_stride];
handle = dev_ctx.cusolver_dn_handle();
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
// compute ormqr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnSormqr(handle,
......@@ -249,8 +255,6 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
......@@ -258,6 +262,11 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
double* other_working_ptr = &other[i * other_stride];
handle = dev_ctx.cusolver_dn_handle();
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
// compute ormqr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDormqr(handle,
......
......@@ -175,6 +175,16 @@ class LinalgLstsqTestCase2(LinalgLstsqTestCase):
self._input_shape_2 = (5, 8)
class LinalgLstsqTestCase3(LinalgLstsqTestCase):
    """Lstsq test configuration: float64, "gels" driver, 3-D input shapes."""

    def init_config(self):
        """Populate the dtype, rcond, driver and input-shape settings."""
        # Both inputs share the same two leading dimensions (10, 7).
        self._input_shape_1, self._input_shape_2 = (10, 7, 3), (10, 7, 6)
        self.driver = "gels"
        self.rcond = 1e-15
        self.dtype = 'float64'
class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):
def init_config(self):
......@@ -192,7 +202,17 @@ class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
self.rcond = None
self.driver = "gels"
self._input_shape_1 = (10, 5)
self._input_shape_2 = (10, 2)
self._input_shape_2 = (10, 8)
class LinalgLstsqTestCaseGelsFloat64(LinalgLstsqTestCase):
    """Lstsq test configuration for the "gels" driver in double precision."""

    def init_config(self):
        """Populate the dtype, rcond, driver and input-shape settings."""
        # The class name promises float64 coverage; the previous value was
        # 'float32', so this case never ran in double precision.
        self.dtype = 'float64'
        self.rcond = None
        self.driver = "gels"
        # 3-D inputs whose leading dimensions (3, 2) match.
        self._input_shape_1 = (3, 2, 8)
        self._input_shape_2 = (3, 2, 15)
class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase):
......@@ -230,9 +250,9 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gelss"
self.driver = "gels"
self._input_shape_1 = (10, 8, 6)
self._input_shape_2 = (10, 8, 2)
self._input_shape_2 = (10, 8, 10)
class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册