Unverified · Commit 405103d8 · Authored by Haohongxiang · Committed by GitHub

Add GPU kernel for new API: linalg.lstsq (#38621)

* add lstsq gpu kernel

* update

* add docs_en

* modify ut

* fix bugs

* modify example in docs_en

* remove lstsq_op.cu from ROCM cmake

* modify docs_en

* modify docs_en

* modify docs_en

* remove unnecessary TensorCopy
Parent c50c22b0
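In brief: paddle.linalg.lstsq previously ran only on CPU via LAPACK; this commit adds a cuSOLVER-based CUDA kernel so the op also runs on GPU with the "gels" driver. A minimal usage sketch (assuming a CUDA build of Paddle; shapes are illustrative):

    import paddle

    paddle.set_device("gpu:0")  # new in this commit; "cpu" keeps the LAPACK path
    x = paddle.rand([6, 4])
    y = paddle.rand([6, 2])
    # "gels" is the only driver supported on CUDA (see the docs added below)
    solution, residuals, rank, singular_values = paddle.linalg.lstsq(
        x, y, driver="gels")
    print(solution.shape)  # [4, 2]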
@@ -203,6 +203,7 @@ function(op_library TARGET)
     list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu")
     list(REMOVE_ITEM hip_srcs "qr_op.cu")
     list(REMOVE_ITEM hip_srcs "eigh_op.cu")
+    list(REMOVE_ITEM hip_srcs "lstsq_op.cu")
     list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
     list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
     hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP does not support cuSOLVER
#include <string>
#include <vector>
#include "paddle/fluid/operators/lstsq_op.h"
#include "paddle/fluid/operators/qr_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
template <typename DeviceContext, typename T>
class LstsqCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
auto* solution = context.Output<Tensor>("Solution");
auto dito =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
T>(context);
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto x_dims = x.dims();
auto y_dims = y.dims();
int dim_size = x_dims.size();
int m = x_dims[dim_size - 2];
int n = x_dims[dim_size - 1];
int nrhs = y_dims[dim_size - 1];
int min_mn = std::min(m, n);
int max_mn = std::max(m, n);
int k = min_mn;
int x_stride = MatrixStride(x);
int y_stride = MatrixStride(y);
int tau_stride = min_mn;
int batch_count = BatchCount(x);
Tensor new_x, new_y;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
new_y.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * nrhs * sizeof(T)));
framework::TensorCopy(x, context.GetPlace(), &new_x);
framework::TensorCopy(y, context.GetPlace(), &new_y);
// Prepare tau
auto tau_dims_vec = framework::vectorize<int>(x_dims);
tau_dims_vec.pop_back();
tau_dims_vec[tau_dims_vec.size() - 1] = min_mn;
Tensor tau = dito.Fill(tau_dims_vec, 0);
auto tau_data = tau.mutable_data<T>(context.GetPlace());
if (m >= n) {
Tensor tmp_x = dito.Transpose(new_x);
Tensor tmp_y = dito.Transpose(new_y);
auto x_data = tmp_x.mutable_data<T>(context.GetPlace());
auto y_data = tmp_y.mutable_data<T>(context.GetPlace());
// Step 1, compute QR factorization using geqrf
BatchedGeqrf<DeviceContext, T>(dev_ctx, batch_count, m, n, x_data, m,
tau_data, x_stride, tau_stride);
// Step 2, Y <- Q^H Y
BatchedOrmqr<DeviceContext, T>(dev_ctx, true, true, batch_count, m, n, k,
x_data, x_stride, tau_data, tau_stride,
y_data, y_stride);
Tensor trans_r = dito.Transpose(tmp_x);
Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
Tensor res_r = dito.TrilTriu(slice_r, 0, false);
Tensor trans_y = dito.Transpose(tmp_y);
Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn});
// Step 3, solve R X = Y
triangular_solve<DeviceContext, T>(dev_ctx, res_r, slice_y, solution,
true, false, false);
} else {
auto x_data = new_x.mutable_data<T>(context.GetPlace());
auto y_data = new_y.mutable_data<T>(context.GetPlace());
// Step 1, compute QR factorization using geqrf
BatchedGeqrf<DeviceContext, T>(dev_ctx, batch_count, n, m, x_data, n,
tau_data, x_stride, tau_stride);
// Step 2, solve R^H Z = Y
Tensor trans_r = dito.Transpose(new_x);
triangular_solve<DeviceContext, T>(dev_ctx, trans_r, new_y, solution,
true, true, false);
// Step 3, X <- Q Z
BatchedOrgqr<DeviceContext, T>(dev_ctx, batch_count, n, n, min_mn, x_data,
n, tau_data, x_stride, tau_stride);
Tensor trans_q = dito.Transpose(new_x);
Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m});
Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false);
framework::TensorCopy(solu_tensor, solution->place(), solution);
}
}
};
template <>
void BatchedOrmqr<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& dev_ctx, bool left, bool transpose,
int batch_size, int m, int n, int k, float* a, int a_stride, float* tau,
int tau_stride, float* other, int other_stride) {
int lwork = 0;
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
int lda = std::max<int>(1, left ? m : n);
int ldc = std::max<int>(1, m);
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
float* other_working_ptr = &other[i * other_stride];
// compute ormqr
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr(
handle, side, trans, m, n, k, a_working_ptr, lda, tau_working_ptr,
other_working_ptr, ldc, workspace_ptr, lwork, info_d));
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(), &info_h,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
info_d, sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver info is not zero but [%d]", i, info_h));
}
}
template <>
void BatchedOrmqr<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& dev_ctx, bool left, bool transpose,
int batch_size, int m, int n, int k, double* a, int a_stride, double* tau,
int tau_stride, double* other, int other_stride) {
int lwork = 0;
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
int lda = std::max<int>(1, left ? m : n);
int ldc = std::max<int>(1, m);
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());
for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
double* other_working_ptr = &other[i * other_stride];
// compute ormqr
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr(
handle, side, trans, m, n, k, a_working_ptr, lda, tau_working_ptr,
other_working_ptr, ldc, workspace_ptr, lwork, info_d));
// check the error info
int info_h;
memory::Copy(platform::CPUPlace(), &info_h,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
info_d, sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver info is not zero but [%d]", i, info_h));
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lstsq, ops::LstsqCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LstsqCUDAKernel<paddle::platform::CUDADeviceContext, double>);
#endif // not PADDLE_WITH_HIP
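For orientation, the two branches of LstsqCUDAKernel::Compute implement the standard QR routes to a least-squares solution: for tall problems (m >= n), factor A = QR and solve R x = Q^H y; for wide problems (m < n), factor A^H = QR and return the minimum-norm solution x = Q R^{-H} y. A NumPy sketch of the same math, with np.linalg.qr standing in for geqrf/ormqr/orgqr (an illustrative substitution, not the kernel itself):

    import numpy as np

    def lstsq_qr(a, b):
        m, n = a.shape
        if m >= n:
            # BatchedGeqrf + BatchedOrmqr + triangular_solve above
            q, r = np.linalg.qr(a)            # reduced QR: q is m x n, r is n x n
            return np.linalg.solve(r, q.T @ b)
        else:
            # BatchedGeqrf on A^H + triangular_solve + BatchedOrgqr + Matmul above
            q, r = np.linalg.qr(a.T)          # q is n x m, r is m x m
            z = np.linalg.solve(r.T, b)       # solve R^H z = b
            return q @ z                      # minimum-norm solution

    a = np.random.rand(6, 4)
    b = np.random.rand(6, 2)
    assert np.allclose(lstsq_qr(a, b), np.linalg.lstsq(a, b, rcond=None)[0])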
@@ -49,7 +49,7 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
     using ValueType = math::Real<T>;

     const Tensor& x = *context.Input<Tensor>("X");
-    const Tensor& y = *context.Input<Tensor>("Y");
+    auto y = context.Input<Tensor>("Y");
     auto rcond = context.Attr<float>("rcond");
     auto driver_string = context.Attr<std::string>("driver");

@@ -68,13 +68,15 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
         math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);

     auto x_dims = x.dims();
-    auto y_dims = y.dims();
+    auto y_dims = y->dims();
     int dim_size = x_dims.size();
     int x_stride = MatrixStride(x);
-    int y_stride = MatrixStride(y);
+    int y_stride = MatrixStride(*y);
     int batch_count = BatchCount(x);
-    auto ori_solution_dim = solution->dims();
+    auto solution_dim = solution->dims();
     int ori_solu_stride = MatrixStride(*solution);
+    int max_solu_stride = std::max(y_stride, ori_solu_stride);
+    int min_solu_stride = std::min(y_stride, ori_solu_stride);

     // LAPACK uses column-major storage; transposing gives the input a
     // contiguous memory layout

@@ -88,13 +90,24 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
     Tensor new_x;
     new_x.mutable_data<T>(context.GetPlace(),
                           size_t(batch_count * m * n * sizeof(T)));
-    framework::TensorCopy(x, context.GetPlace(), &new_x);
     solution->mutable_data<T>(
         context.GetPlace(),
         size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));
-    framework::TensorCopy(y, context.GetPlace(), solution);
-    if (m < n) solution->Resize(UDDim(ori_solution_dim));
+    framework::TensorCopy(x, context.GetPlace(), &new_x);
+    if (m >= n) {
+      const Tensor& new_y = *context.Input<Tensor>("Y");
+      framework::TensorCopy(new_y, context.GetPlace(), solution);
+    } else {
+      auto* solu_data = solution->data<T>();
+      auto* y_data = y->data<T>();
+      for (auto i = 0; i < batch_count; i++) {
+        for (auto j = 0; j < min_solu_stride; j++) {
+          solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j];
+        }
+      }
+    }

     Tensor input_x_trans = dito.Transpose(new_x);
     Tensor input_y_trans = dito.Transpose(*solution);

@@ -186,10 +199,9 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
       iwork_data = iwork.mutable_data<int>(context.GetPlace());
     }

-    int solu_stride = std::max(y_stride, ori_solu_stride);
     for (auto i = 0; i < batch_count; ++i) {
       auto* x_input = &x_vector[i * x_stride];
-      auto* y_input = &y_vector[i * solu_stride];
+      auto* y_input = &y_vector[i * max_solu_stride];
       rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr;
       s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr;

@@ -221,9 +233,24 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
     Tensor tmp_s = dito.Transpose(*solution);
     framework::TensorCopy(tmp_s, solution->place(), solution);

-    if (m >= n) solution->Resize(UDDim(ori_solution_dim));
+    if (m > n) {
+      auto* solu_data = solution->data<T>();
+      for (auto i = 1; i < batch_count; i++) {
+        for (auto j = 0; j < min_solu_stride; j++) {
+          solu_data[i * min_solu_stride + j] =
+              solu_data[i * max_solu_stride + j];
+        }
+      }
+    }
+
+    solution->Resize(UDDim(solution_dim));
   }
 };

+template <typename DeviceContext, typename T>
+void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, bool transpose,
+                  int batch_size, int m, int n, int k, T* a, int a_stride,
+                  T* tau, int tau_stride, T* other, int other_stride);
+
 }  // namespace operators
 }  // namespace paddle
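The copy loops above exist because LAPACK's least-squares drivers overwrite the right-hand-side buffer in place and require its leading dimension to be at least max(M, N), so the solution comes back at a different stride than the input Y. A NumPy sketch of the packing idea (shapes illustrative, not Paddle code):

    import numpy as np

    batch, m, n, nrhs = 2, 3, 5, 2            # underdetermined: m < n
    y = np.random.rand(batch, m, nrhs)

    max_mn = max(m, n)
    work = np.zeros((batch, max_mn, nrhs))    # the oversized 'solution' buffer
    work[:, :m, :] = y                        # copy-in at the smaller stride (m < n branch)

    # ... each LAPACK call would overwrite work[i, :n, :] with the i-th solution ...

    solution = work[:, :n, :]                 # compact back to (batch, n, nrhs)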
@@ -88,8 +88,8 @@ class QrGPUKernel : public framework::OpKernel<T> {
     auto qr_data = qr.mutable_data<T>(context.GetPlace());
     auto tau_data = tau.mutable_data<T>(context.GetPlace());

-    BatchedGeqrf(dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride,
-                 tau_stride);
+    BatchedGeqrf<platform::CUDADeviceContext, T>(
+        dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride);

     if (reduced_mode) {
       auto trans_qr = dito.Transpose(qr);

@@ -108,8 +108,9 @@ class QrGPUKernel : public framework::OpKernel<T> {
       // Perform ORGQR for Q using the result from GEQRF
       // Transpose 'q' to restore the original row-major order
       if (reduced_mode) {
-        BatchedOrgqr(dev_ctx, batch_size, m, min_mn, min_mn, qr_data, m,
-                     tau_data, qr_stride, tau_stride);
+        BatchedOrgqr<platform::CUDADeviceContext, T>(
+            dev_ctx, batch_size, m, min_mn, min_mn, qr_data, m, tau_data,
+            qr_stride, tau_stride);
         auto trans_q = dito.Transpose(qr);
         auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {min_mn});
         framework::TensorCopy(sliced_q, q.place(), &q);

@@ -128,13 +129,15 @@ class QrGPUKernel : public framework::OpKernel<T> {
               (qr_data + i * qr_stride), qr_stride * sizeof(math::Real<T>),
               dev_ctx.stream());
         }
-        BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, new_qr_data, m,
-                     tau_data, new_qr_stride, tau_stride);
+        BatchedOrgqr<platform::CUDADeviceContext, T>(
+            dev_ctx, batch_size, m, m, min_mn, new_qr_data, m, tau_data,
+            new_qr_stride, tau_stride);
         auto trans_q = dito.Transpose(new_qr);
         framework::TensorCopy(trans_q, q.place(), &q);
       } else {
-        BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, qr_data, m, tau_data,
-                     qr_stride, tau_stride);
+        BatchedOrgqr<platform::CUDADeviceContext, T>(
+            dev_ctx, batch_size, m, m, min_mn, qr_data, m, tau_data,
+            qr_stride, tau_stride);
         auto trans_q = dito.Transpose(qr);
         auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {m});
         framework::TensorCopy(sliced_q, q.place(), &q);

@@ -142,28 +145,12 @@ class QrGPUKernel : public framework::OpKernel<T> {
       }
     }
   }
-
-  void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size,
-                    int m, int n, float* a, int lda, float* tau, int a_stride,
-                    int tau_stride) const;
-
-  void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size,
-                    int m, int n, double* a, int lda, double* tau, int a_stride,
-                    int tau_stride) const;
-
-  void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size,
-                    int m, int n, int k, float* a, int lda, float* tau,
-                    int a_stride, int tau_stride) const;
-
-  void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size,
-                    int m, int n, int k, double* a, int lda, double* tau,
-                    int a_stride, int tau_stride) const;
 };

 template <>
-void QrGPUKernel<float>::BatchedGeqrf(
+void BatchedGeqrf<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n,
-    float* a, int lda, float* tau, int a_stride, int tau_stride) const {
+    float* a, int lda, float* tau, int a_stride, int tau_stride) {
   int lwork = 0;
   auto handle = dev_ctx.cusolver_dn_handle();

@@ -195,9 +182,9 @@ void QrGPUKernel<float>::BatchedGeqrf(
 }

 template <>
-void QrGPUKernel<double>::BatchedGeqrf(
+void BatchedGeqrf<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n,
-    double* a, int lda, double* tau, int a_stride, int tau_stride) const {
+    double* a, int lda, double* tau, int a_stride, int tau_stride) {
   int lwork = 0;
   auto handle = dev_ctx.cusolver_dn_handle();

@@ -229,9 +216,9 @@ void QrGPUKernel<double>::BatchedGeqrf(
 }

 template <>
-void QrGPUKernel<float>::BatchedOrgqr(
+void BatchedOrgqr<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n,
-    int k, float* a, int lda, float* tau, int a_stride, int tau_stride) const {
+    int k, float* a, int lda, float* tau, int a_stride, int tau_stride) {
   int lwork = 0;
   auto handle = dev_ctx.cusolver_dn_handle();

@@ -263,10 +250,9 @@ void QrGPUKernel<float>::BatchedOrgqr(
 }

 template <>
-void QrGPUKernel<double>::BatchedOrgqr(
+void BatchedOrgqr<platform::CUDADeviceContext, double>(
     const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n,
-    int k, double* a, int lda, double* tau, int a_stride,
-    int tau_stride) const {
+    int k, double* a, int lda, double* tau, int a_stride, int tau_stride) {
   int lwork = 0;
   auto handle = dev_ctx.cusolver_dn_handle();
......
@@ -250,5 +250,13 @@ class QrGradKernel : public framework::OpKernel<T> {
   }
 };

+template <typename DeviceContext, typename T>
+void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, int m, int n,
+                  T* a, int lda, T* tau, int a_stride, int tau_stride);
+
+template <typename DeviceContext, typename T>
+void BatchedOrgqr(const DeviceContext& dev_ctx, int batch_size, int m, int n,
+                  int k, T* a, int lda, T* tau, int a_stride, int tau_stride);
+
 }  // namespace operators
 }  // namespace paddle
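BatchedGeqrf/BatchedOrgqr are now declared as free function templates here (instead of the QrGPUKernel member functions removed above) so that lstsq_op.cu can reuse the cuSOLVER specializations. As a CPU reference for what the geqrf/orgqr pair computes, SciPy exposes the same LAPACK routines (a sketch, not Paddle code):

    import numpy as np
    from scipy.linalg import lapack

    a = np.asfortranarray(np.random.rand(6, 4))
    qr, tau, _, info = lapack.dgeqrf(a)   # packed Householder factors + tau
    r = np.triu(qr[:4, :4])               # R is the upper triangle of the result
    q, _, info = lapack.dorgqr(qr, tau)   # assemble the explicit (reduced) Q
    assert np.allclose(q @ r, a)          # Q R reproduces A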
@@ -81,6 +81,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
   __macro(cusolverDnZgeqrf_bufferSize); \
   __macro(cusolverDnSorgqr_bufferSize); \
   __macro(cusolverDnDorgqr_bufferSize); \
+  __macro(cusolverDnSormqr_bufferSize); \
+  __macro(cusolverDnDormqr_bufferSize); \
   __macro(cusolverDnCungqr_bufferSize); \
   __macro(cusolverDnZungqr_bufferSize); \
   __macro(cusolverDnDestroyGesvdjInfo); \

@@ -98,6 +100,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
   __macro(cusolverDnZgeqrf); \
   __macro(cusolverDnSorgqr); \
   __macro(cusolverDnDorgqr); \
+  __macro(cusolverDnSormqr); \
+  __macro(cusolverDnDormqr); \
   __macro(cusolverDnCungqr); \
   __macro(cusolverDnZungqr);
......
@@ -18,11 +18,15 @@ import unittest

 import numpy as np
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core


 class LinalgLstsqTestCase(unittest.TestCase):
     def setUp(self):
+        self.devices = ["cpu"]
         self.init_config()
+        if core.is_compiled_with_cuda() and self.driver == "gels":
+            self.devices.append("gpu:0")
         self.generate_input()
         self.generate_output()

@@ -43,104 +47,129 @@ class LinalgLstsqTestCase(unittest.TestCase):
         if len(self._input_shape_1) == 2:
             out = np.linalg.lstsq(
                 self._input_data_1, self._input_data_2, rcond=self.rcond)
+            self._output_solution = out[0]
+            self._output_residuals = out[1]
+            self._output_rank = out[2]
+            self._output_sg_values = out[3]
         elif len(self._input_shape_1) == 3:
-            out = np.linalg.lstsq(
-                self._input_data_1[0], self._input_data_2[0], rcond=self.rcond)
-
-        self._output_solution = out[0]
-        self._output_residuals = out[1]
-        self._output_rank = out[2]
-        self._output_sg_values = out[3]
+            self._output_solution = []
+            self._output_residuals = []
+            self._output_rank = []
+            self._output_sg_values = []
+            for i in range(self._input_shape_1[0]):
+                out = np.linalg.lstsq(
+                    self._input_data_1[i],
+                    self._input_data_2[i],
+                    rcond=self.rcond)
+                self._output_solution.append(out[0])
+                self._output_residuals.append(out[1])
+                self._output_rank.append(out[2])
+                self._output_sg_values.append(out[3])

     def test_dygraph(self):
         paddle.disable_static()
-        paddle.device.set_device("cpu")
-        place = paddle.CPUPlace()
-        x = paddle.to_tensor(self._input_data_1, place=place, dtype=self.dtype)
-        y = paddle.to_tensor(self._input_data_2, place=place, dtype=self.dtype)
-        results = paddle.linalg.lstsq(
-            x, y, rcond=self.rcond, driver=self.driver)
-
-        res_solution = results[0].numpy()
-        res_residuals = results[1].numpy()
-        res_rank = results[2].numpy()
-        res_singular_values = results[3].numpy()
-
-        if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
-            if (np.abs(res_residuals - self._output_residuals) < 1e-6).any():
-                pass
-            else:
-                raise RuntimeError("Check LSTSQ residuals dygraph Failed")
-
-        if self.driver in ("gelsy", "gelsd", "gelss"):
-            if (np.abs(res_rank - self._output_rank) < 1e-6).any():
-                pass
-            else:
-                raise RuntimeError("Check LSTSQ rank dygraph Failed")
-
-        if self.driver in ("gelsd", "gelss"):
-            if (np.abs(res_singular_values - self._output_sg_values) < 1e-6
-                ).any():
-                pass
-            else:
-                raise RuntimeError("Check LSTSQ singular values dygraph Failed")
+        for dev in self.devices:
+            paddle.set_device(dev)
+            place = paddle.CPUPlace() if dev == "cpu" else paddle.CUDAPlace(0)
+            x = paddle.to_tensor(
+                self._input_data_1, place=place, dtype=self.dtype)
+            y = paddle.to_tensor(
+                self._input_data_2, place=place, dtype=self.dtype)
+            results = paddle.linalg.lstsq(
+                x, y, rcond=self.rcond, driver=self.driver)
+            self._result_solution = results[0].numpy()
+            self._result_residuals = results[1].numpy()
+            self._result_rank = results[2].numpy()
+            self._result_sg_values = results[3].numpy()
+            self.assert_np_close()

     def test_static(self):
         paddle.enable_static()
-        paddle.device.set_device("cpu")
-        place = fluid.CPUPlace()
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            x = paddle.fluid.data(
-                name="x",
-                shape=self._input_shape_1,
-                dtype=self._input_data_1.dtype)
-            y = paddle.fluid.data(
-                name="y",
-                shape=self._input_shape_2,
-                dtype=self._input_data_2.dtype)
-            results = paddle.linalg.lstsq(
-                x, y, rcond=self.rcond, driver=self.driver)
-            exe = fluid.Executor(place)
-            fetches = exe.run(
-                fluid.default_main_program(),
-                feed={"x": self._input_data_1,
-                      "y": self._input_data_2},
-                fetch_list=[results])
-
-            if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
-                if (np.abs(fetches[1] - self._output_residuals) < 1e-6).any():
-                    pass
-                else:
-                    raise RuntimeError("Check LSTSQ residuals static Failed")
-
-            if self.driver in ("gelsy", "gelsd", "gelss"):
-                if (np.abs(fetches[2] - self._output_rank) < 1e-6).any():
-                    pass
-                else:
-                    raise RuntimeError("Check LSTSQ rank static Failed")
-
-            if self.driver in ("gelsd", "gelss"):
-                if (np.abs(fetches[3] - self._output_sg_values) < 1e-6).any():
-                    pass
-                else:
-                    raise RuntimeError(
-                        "Check LSTSQ singular values static Failed")
+        for dev in self.devices:
+            paddle.set_device(dev)
+            place = fluid.CPUPlace() if dev == "cpu" else fluid.CUDAPlace(0)
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                x = paddle.fluid.data(
+                    name="x",
+                    shape=self._input_shape_1,
+                    dtype=self._input_data_1.dtype)
+                y = paddle.fluid.data(
+                    name="y",
+                    shape=self._input_shape_2,
+                    dtype=self._input_data_2.dtype)
+                results = paddle.linalg.lstsq(
+                    x, y, rcond=self.rcond, driver=self.driver)
+                exe = fluid.Executor(place)
+                fetches = exe.run(
+                    fluid.default_main_program(),
+                    feed={"x": self._input_data_1,
+                          "y": self._input_data_2},
+                    fetch_list=[results])
+                self._result_solution = fetches[0]
+                self._result_residuals = fetches[1]
+                self._result_rank = fetches[2]
+                self._result_sg_values = fetches[3]
+                self.assert_np_close()
+
+    def assert_np_close(self):
+        if len(self._input_shape_1) == 2:
+            np.testing.assert_allclose(
+                self._result_solution, self._output_solution, rtol=1e-3)
+            if self._input_shape_1[-2] > self._input_shape_1[
+                    -1] and self._output_rank == self._input_shape_1[-1]:
+                np.testing.assert_allclose(
+                    self._result_residuals, self._output_residuals, rtol=1e-5)
+            if self.driver in ("gelsy", "gelsd", "gelss"):
+                np.testing.assert_allclose(
+                    self._result_rank, self._output_rank, rtol=1e-5)
+            if self.driver in ("gelsd", "gelss"):
+                np.testing.assert_allclose(
+                    self._result_sg_values, self._output_sg_values, rtol=1e-5)
+        else:
+            for i in range(len(self._output_solution)):
+                np.testing.assert_allclose(
+                    self._result_solution[i],
+                    self._output_solution[i],
+                    rtol=1e-3)
+                if self._input_shape_1[-2] > self._input_shape_1[
+                        -1] and self._output_rank[i] == self._input_shape_1[-1]:
+                    np.testing.assert_allclose(
+                        self._result_residuals[i],
+                        self._output_residuals[i],
+                        rtol=1e-5)
+                if self.driver in ("gelsy", "gelsd", "gelss"):
+                    np.testing.assert_allclose(
+                        self._result_rank[i], self._output_rank[i], rtol=1e-5)
+                if self.driver in ("gelsd", "gelss"):
+                    np.testing.assert_allclose(
+                        self._result_sg_values[i],
+                        self._output_sg_values[i],
+                        rtol=1e-5)
+
+
+class LinalgLstsqTestCase1(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float32'
+        self.rcond = 1e-15
+        self.driver = "gels"
+        self._input_shape_1 = (9, 9)
+        self._input_shape_2 = (9, 5)


-class LinalgLstsqTestCase(LinalgLstsqTestCase):
+class LinalgLstsqTestCase2(LinalgLstsqTestCase):
     def init_config(self):
         self.dtype = 'float64'
         self.rcond = 1e-15
         self.driver = "gels"
         self._input_shape_1 = (5, 10)
-        self._input_shape_2 = (5, 5)
+        self._input_shape_2 = (5, 8)


 class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):
     def init_config(self):
         self.dtype = 'float64'
-        self.rcond = 0.1
-        self.driver = "gels"
+        self.rcond = 1e-7
+        self.driver = "gelsd"
         self._input_shape_1 = (3, 2)
         self._input_shape_2 = (3, 3)

@@ -148,7 +177,7 @@ class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):
 class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
     def init_config(self):
         self.dtype = 'float32'
-        self.rcond = 1e-15
+        self.rcond = None
         self.driver = "gels"
         self._input_shape_1 = (10, 5)
         self._input_shape_2 = (10, 2)

@@ -157,7 +186,7 @@ class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
 class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase):
     def init_config(self):
         self.dtype = 'float64'
-        self.rcond = 1e-15
+        self.rcond = None
         self.driver = "gelss"
         self._input_shape_1 = (5, 5)
         self._input_shape_2 = (5, 1)

@@ -176,7 +205,7 @@ class LinalgLstsqTestCaseBatch1(LinalgLstsqTestCase):
     def init_config(self):
         self.dtype = 'float32'
         self.rcond = 1e-15
-        self.driver = None
+        self.driver = "gelss"
         self._input_shape_1 = (2, 3, 10)
         self._input_shape_2 = (2, 3, 4)

@@ -186,8 +215,8 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
         self.dtype = 'float64'
         self.rcond = 1e-15
         self.driver = "gelss"
-        self._input_shape_1 = (2, 8, 6)
-        self._input_shape_2 = (2, 8, 2)
+        self._input_shape_1 = (10, 8, 6)
+        self._input_shape_2 = (10, 8, 2)


 class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
@@ -201,7 +230,7 @@ class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
 class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase):
     def init_config(self):
-        self.dtype = 'float32'
+        self.dtype = 'float64'
         self.rcond = 1e-15
         self.driver = "gelss"
         self._input_shape_1 = (50, 600)
......
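Each case above can also be run standalone; a sketch, assuming the test file follows Paddle's naming convention (test_linalg_lstsq_op.py, hypothetical here):

    import unittest
    from test_linalg_lstsq_op import LinalgLstsqTestCaseGelsFloat32

    suite = unittest.TestLoader().loadTestsFromTestCase(
        LinalgLstsqTestCaseGelsFloat32)
    unittest.TextTestRunner().run(suite)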
@@ -2816,8 +2816,66 @@ def eigvalsh(x, UPLO='L', name=None):
     return out_value


-def lstsq(x, y, rcond=1e-15, driver=None, name=None):
-    device = paddle.device.get_device()
+def lstsq(x, y, rcond=None, driver=None, name=None):
+    """
+    Computes a solution to the least squares problem of a system of linear
+    equations.
+
+    Args:
+        x (Tensor): A tensor with shape ``(*, M, N)``; the data type of the
+            input Tensor ``x`` should be one of float32, float64.
+        y (Tensor): A tensor with shape ``(*, M, K)``; the data type of the
+            input Tensor ``y`` should be one of float32, float64.
+        rcond (float, optional): The default value is None. A floating point
+            number used to determine the effective rank of ``x``. If ``rcond``
+            is None, it will be set to max(M, N) times the machine precision
+            of x_dtype.
+        driver (str, optional): The default value is None. The name of the
+            LAPACK method to be used. For CPU inputs the valid values are
+            'gels', 'gelsy', 'gelsd', 'gelss'. For CUDA inputs, the only
+            valid driver is 'gels'. If ``driver`` is None, 'gelsy' is used
+            for CPU inputs and 'gels' for CUDA inputs.
+        name (str, optional): The default value is None. Normally there is no
+            need for the user to set this property. For more information,
+            please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tuple: A tuple of 4 Tensors: (``solution``, ``residuals``, ``rank``,
+        ``singular_values``). ``solution`` is a tensor with shape
+        ``(*, N, K)``, the least squares solution. ``residuals`` is a tensor
+        with shape ``(*, K)``, the squared residuals of the solutions; it is
+        computed when M > N and every matrix in ``x`` is full-rank, otherwise
+        an empty tensor is returned. ``rank`` is a tensor with shape ``(*)``,
+        the ranks of the matrices in ``x``; it is computed when ``driver`` is
+        one of ('gelsy', 'gelsd', 'gelss'), otherwise an empty tensor is
+        returned. ``singular_values`` is a tensor with shape
+        ``(*, min(M, N))``, the singular values of the matrices in ``x``; it
+        is computed when ``driver`` is one of ('gelsd', 'gelss'), otherwise an
+        empty tensor is returned.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            paddle.set_device("cpu")
+            x = paddle.to_tensor([[1, 3], [3, 2], [5, 6.]])
+            y = paddle.to_tensor([[3, 4, 6], [5, 3, 4], [1, 2, 1.]])
+            results = paddle.linalg.lstsq(x, y, driver="gelsd")
+            print(results[0])
+            # [[ 0.78350395, -0.22165027, -0.62371236],
+            #  [-0.11340097,  0.78866047,  1.14948535]]
+            print(results[1])
+            # [19.81443405, 10.43814468, 30.56185532]
+            print(results[2])
+            # 2
+            print(results[3])
+            # [9.03455734, 1.54167950]
+
+            x = paddle.to_tensor([[10, 2, 3], [3, 10, 5], [5, 6, 12.]])
+            y = paddle.to_tensor([[4, 2, 9], [2, 0, 3], [2, 5, 3.]])
+            results = paddle.linalg.lstsq(x, y, driver="gels")
+            print(results[0])
+            # [[ 0.39386186,  0.10230173,  0.93606132],
+            #  [ 0.10741687, -0.29028133,  0.11892585],
+            #  [-0.05115091,  0.51918161, -0.19948854]]
+            print(results[1])
+            # []
+    """
+    device = paddle.get_device()
     if device == "cpu":
         if driver not in (None, "gels", "gelss", "gelsd", "gelsy"):
             raise ValueError(

@@ -2833,6 +2891,19 @@ def lstsq(x, y, rcond=1e-15, driver=None, name=None):
     else:
         raise RuntimeError("Only support lstsq api for CPU or CUDA device.")

+    if x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64):
+        pass
+    else:
+        raise ValueError(
+            "Only support x and y have the same dtype such as 'float32' and 'float64'."
+        )
+
+    if rcond is None:
+        if x.dtype == paddle.float32:
+            rcond = 1e-7 * max(x.shape[-2], x.shape[-1])
+        elif x.dtype == paddle.float64:
+            rcond = 1e-15 * max(x.shape[-2], x.shape[-1])
+
     if in_dygraph_mode():
         solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond,
                                                        "driver", driver)
......
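Since the default rcond now mirrors NumPy's convention (a factor of roughly machine epsilon times max(M, N)), the docstring's first example can be cross-checked directly against np.linalg.lstsq (a sketch; values should match the printed results above to display precision):

    import numpy as np

    x = np.array([[1, 3], [3, 2], [5, 6.]])
    y = np.array([[3, 4, 6], [5, 3, 4], [1, 2, 1.]])
    solution, residuals, rank, sg_values = np.linalg.lstsq(x, y, rcond=None)
    print(solution)   # ~ results[0] from the gelsd example
    print(residuals)  # ~ [19.8144..., 10.4381..., 30.5618...]
    print(rank)       # 2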