From 405103d8ada42aacc42aa326b9381e851f240ffa Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 10 Jan 2022 19:55:49 +0800 Subject: [PATCH] Add gpu kernel for new api : linalg.lstsq (#38621) * add lstsq gpu kernel * update * add docs_en * modify ut * fix bugs * modify example in docs_en * remove lstsq_op.cu from ROCM cmake * modify docs_en * modify docs_en * modify docs_en * remove unneccessary TensorCopy --- cmake/operators.cmake | 1 + paddle/fluid/operators/lstsq_op.cu | 211 ++++++++++++++++++ paddle/fluid/operators/lstsq_op.h | 47 +++- paddle/fluid/operators/qr_op.cu | 52 ++--- paddle/fluid/operators/qr_op.h | 8 + paddle/fluid/platform/dynload/cusolver.h | 4 + .../tests/unittests/test_linalg_lstsq_op.py | 195 +++++++++------- python/paddle/tensor/linalg.py | 75 ++++++- 8 files changed, 465 insertions(+), 128 deletions(-) create mode 100644 paddle/fluid/operators/lstsq_op.cu diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 2d1ce4e834..2d4aa1a815 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -203,6 +203,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu") list(REMOVE_ITEM hip_srcs "qr_op.cu") list(REMOVE_ITEM hip_srcs "eigh_op.cu") + list(REMOVE_ITEM hip_srcs "lstsq_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS} diff --git a/paddle/fluid/operators/lstsq_op.cu b/paddle/fluid/operators/lstsq_op.cu new file mode 100644 index 0000000000..a71b900f14 --- /dev/null +++ b/paddle/fluid/operators/lstsq_op.cu @@ -0,0 +1,211 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include +#include +#include "paddle/fluid/operators/lstsq_op.h" +#include "paddle/fluid/operators/qr_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +template +class LstsqCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor& x = *context.Input("X"); + const Tensor& y = *context.Input("Y"); + auto* solution = context.Output("Solution"); + + auto dito = + math::DeviceIndependenceTensorOperations(context); + auto& dev_ctx = + context.template device_context(); + + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int dim_size = x_dims.size(); + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + int nrhs = y_dims[dim_size - 1]; + int min_mn = std::min(m, n); + int max_mn = std::max(m, n); + int k = min_mn; + + int x_stride = MatrixStride(x); + int y_stride = MatrixStride(y); + int tau_stride = min_mn; + int batch_count = BatchCount(x); + + Tensor new_x, new_y; + new_x.mutable_data(context.GetPlace(), + size_t(batch_count * m * n * sizeof(T))); + new_y.mutable_data(context.GetPlace(), + size_t(batch_count * m * nrhs * sizeof(T))); + framework::TensorCopy(x, context.GetPlace(), &new_x); + framework::TensorCopy(y, context.GetPlace(), &new_y); + + // Prepare tau + auto tau_dims_vec = framework::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + Tensor tau = dito.Fill(tau_dims_vec, 0); + auto tau_data = tau.mutable_data(context.GetPlace()); + + if (m >= n) { + Tensor tmp_x = dito.Transpose(new_x); + Tensor tmp_y = dito.Transpose(new_y); + auto x_data = tmp_x.mutable_data(context.GetPlace()); + auto y_data = tmp_y.mutable_data(context.GetPlace()); + + // step 1, compute QR factorization using geqrf + BatchedGeqrf(dev_ctx, batch_count, m, n, x_data, m, + tau_data, x_stride, tau_stride); + + // Step 2, Y <- Q^H Y + BatchedOrmqr(dev_ctx, true, true, batch_count, m, n, k, + x_data, x_stride, tau_data, tau_stride, + y_data, y_stride); + + Tensor trans_r = dito.Transpose(tmp_x); + Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn}); + Tensor res_r = dito.TrilTriu(slice_r, 0, false); + + Tensor trans_y = dito.Transpose(tmp_y); + Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn}); + + // Step 3, solve R X = Y + triangular_solve(dev_ctx, res_r, slice_y, solution, + true, false, false); + } else { + auto x_data = new_x.mutable_data(context.GetPlace()); + auto y_data = new_y.mutable_data(context.GetPlace()); + + // step 1, compute QR factorization using geqrf + BatchedGeqrf(dev_ctx, batch_count, n, m, x_data, n, + tau_data, x_stride, tau_stride); + + // Step 2, solve R^H Z = Y + Tensor trans_r = dito.Transpose(new_x); + triangular_solve(dev_ctx, trans_r, new_y, solution, + true, true, false); + + // Step 3, X <- Q Z + BatchedOrgqr(dev_ctx, batch_count, n, n, min_mn, x_data, + n, tau_data, x_stride, tau_stride); + Tensor trans_q = dito.Transpose(new_x); + Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m}); + Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false); + framework::TensorCopy(solu_tensor, solution->place(), solution); + } + } +}; + +template <> +void BatchedOrmqr( + const platform::CUDADeviceContext& dev_ctx, bool left, bool transpose, + int batch_size, int m, int n, int k, float* a, int a_stride, float* tau, + int tau_stride, float* other, int other_stride) { + int lwork = 0; + auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; + auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = std::max(1, left ? m : n); + int ldc = std::max(1, m); + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + float* other_working_ptr = &other[i * other_stride]; + // compute ormgr + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr( + handle, side, trans, m, n, k, a_working_ptr, lda, tau_working_ptr, + other_working_ptr, ldc, workspace_ptr, lwork, info_d)); + + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); + } +} + +template <> +void BatchedOrmqr( + const platform::CUDADeviceContext& dev_ctx, bool left, bool transpose, + int batch_size, int m, int n, int k, double* a, int a_stride, double* tau, + int tau_stride, double* other, int other_stride) { + int lwork = 0; + auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; + auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = std::max(1, left ? m : n); + int ldc = std::max(1, m); + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + double* other_working_ptr = &other[i * other_stride]; + // compute ormgr + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr( + handle, side, trans, m, n, k, a_working_ptr, lda, tau_working_ptr, + other_working_ptr, ldc, workspace_ptr, lwork, info_d)); + + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); + } +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + lstsq, ops::LstsqCUDAKernel, + ops::LstsqCUDAKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h index b9c5c87a6a..be41123270 100644 --- a/paddle/fluid/operators/lstsq_op.h +++ b/paddle/fluid/operators/lstsq_op.h @@ -49,7 +49,7 @@ class LstsqCPUKernel : public framework::OpKernel { using ValueType = math::Real; const Tensor& x = *context.Input("X"); - const Tensor& y = *context.Input("Y"); + auto y = context.Input("Y"); auto rcond = context.Attr("rcond"); auto driver_string = context.Attr("driver"); @@ -68,13 +68,15 @@ class LstsqCPUKernel : public framework::OpKernel { math::DeviceIndependenceTensorOperations(context); auto x_dims = x.dims(); - auto y_dims = y.dims(); + auto y_dims = y->dims(); int dim_size = x_dims.size(); int x_stride = MatrixStride(x); - int y_stride = MatrixStride(y); + int y_stride = MatrixStride(*y); int batch_count = BatchCount(x); - auto ori_solution_dim = solution->dims(); + auto solution_dim = solution->dims(); int ori_solu_stride = MatrixStride(*solution); + int max_solu_stride = std::max(y_stride, ori_solu_stride); + int min_solu_stride = std::min(y_stride, ori_solu_stride); // lapack is a column-major storge, transpose make the input to // have a continuous memory layout @@ -88,13 +90,24 @@ class LstsqCPUKernel : public framework::OpKernel { Tensor new_x; new_x.mutable_data(context.GetPlace(), size_t(batch_count * m * n * sizeof(T))); + framework::TensorCopy(x, context.GetPlace(), &new_x); + solution->mutable_data( context.GetPlace(), size_t(batch_count * std::max(m, n) * nrhs * sizeof(T))); - framework::TensorCopy(x, context.GetPlace(), &new_x); - framework::TensorCopy(y, context.GetPlace(), solution); - if (m < n) solution->Resize(UDDim(ori_solution_dim)); + if (m >= n) { + const Tensor& new_y = *context.Input("Y"); + framework::TensorCopy(new_y, context.GetPlace(), solution); + } else { + auto* solu_data = solution->data(); + auto* y_data = y->data(); + for (auto i = 0; i < batch_count; i++) { + for (auto j = 0; j < min_solu_stride; j++) { + solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j]; + } + } + } Tensor input_x_trans = dito.Transpose(new_x); Tensor input_y_trans = dito.Transpose(*solution); @@ -186,10 +199,9 @@ class LstsqCPUKernel : public framework::OpKernel { iwork_data = iwork.mutable_data(context.GetPlace()); } - int solu_stride = std::max(y_stride, ori_solu_stride); for (auto i = 0; i < batch_count; ++i) { auto* x_input = &x_vector[i * x_stride]; - auto* y_input = &y_vector[i * solu_stride]; + auto* y_input = &y_vector[i * max_solu_stride]; rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr; s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr; @@ -221,9 +233,24 @@ class LstsqCPUKernel : public framework::OpKernel { Tensor tmp_s = dito.Transpose(*solution); framework::TensorCopy(tmp_s, solution->place(), solution); - if (m >= n) solution->Resize(UDDim(ori_solution_dim)); + if (m > n) { + auto* solu_data = solution->data(); + for (auto i = 1; i < batch_count; i++) { + for (auto j = 0; j < min_solu_stride; j++) { + solu_data[i * min_solu_stride + j] = + solu_data[i * max_solu_stride + j]; + } + } + } + + solution->Resize(UDDim(solution_dim)); } }; +template +void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, bool transpose, + int batch_size, int m, int n, int k, T* a, int a_stride, + T* tau, int tau_stride, T* other, int other_stride); + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/qr_op.cu b/paddle/fluid/operators/qr_op.cu index 3eb5f72b5b..af5ebdc531 100644 --- a/paddle/fluid/operators/qr_op.cu +++ b/paddle/fluid/operators/qr_op.cu @@ -88,8 +88,8 @@ class QrGPUKernel : public framework::OpKernel { auto qr_data = qr.mutable_data(context.GetPlace()); auto tau_data = tau.mutable_data(context.GetPlace()); - BatchedGeqrf(dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, - tau_stride); + BatchedGeqrf( + dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride); if (reduced_mode) { auto trans_qr = dito.Transpose(qr); @@ -108,8 +108,9 @@ class QrGPUKernel : public framework::OpKernel { // Perform QRGQR for Q using the result from GEQRF // Transpose 'q' to retore the original row-major order if (reduced_mode) { - BatchedOrgqr(dev_ctx, batch_size, m, min_mn, min_mn, qr_data, m, - tau_data, qr_stride, tau_stride); + BatchedOrgqr( + dev_ctx, batch_size, m, min_mn, min_mn, qr_data, m, tau_data, + qr_stride, tau_stride); auto trans_q = dito.Transpose(qr); auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {min_mn}); framework::TensorCopy(sliced_q, q.place(), &q); @@ -128,13 +129,15 @@ class QrGPUKernel : public framework::OpKernel { (qr_data + i * qr_stride), qr_stride * sizeof(math::Real), dev_ctx.stream()); } - BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, new_qr_data, m, - tau_data, new_qr_stride, tau_stride); + BatchedOrgqr( + dev_ctx, batch_size, m, m, min_mn, new_qr_data, m, tau_data, + new_qr_stride, tau_stride); auto trans_q = dito.Transpose(new_qr); framework::TensorCopy(trans_q, q.place(), &q); } else { - BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, qr_data, m, tau_data, - qr_stride, tau_stride); + BatchedOrgqr( + dev_ctx, batch_size, m, m, min_mn, qr_data, m, tau_data, + qr_stride, tau_stride); auto trans_q = dito.Transpose(qr); auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {m}); framework::TensorCopy(sliced_q, q.place(), &q); @@ -142,28 +145,12 @@ class QrGPUKernel : public framework::OpKernel { } } } - - void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size, - int m, int n, float* a, int lda, float* tau, int a_stride, - int tau_stride) const; - - void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size, - int m, int n, double* a, int lda, double* tau, int a_stride, - int tau_stride) const; - - void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size, - int m, int n, int k, float* a, int lda, float* tau, - int a_stride, int tau_stride) const; - - void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size, - int m, int n, int k, double* a, int lda, double* tau, - int a_stride, int tau_stride) const; }; template <> -void QrGPUKernel::BatchedGeqrf( +void BatchedGeqrf( const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, - float* a, int lda, float* tau, int a_stride, int tau_stride) const { + float* a, int lda, float* tau, int a_stride, int tau_stride) { int lwork = 0; auto handle = dev_ctx.cusolver_dn_handle(); @@ -195,9 +182,9 @@ void QrGPUKernel::BatchedGeqrf( } template <> -void QrGPUKernel::BatchedGeqrf( +void BatchedGeqrf( const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, - double* a, int lda, double* tau, int a_stride, int tau_stride) const { + double* a, int lda, double* tau, int a_stride, int tau_stride) { int lwork = 0; auto handle = dev_ctx.cusolver_dn_handle(); @@ -229,9 +216,9 @@ void QrGPUKernel::BatchedGeqrf( } template <> -void QrGPUKernel::BatchedOrgqr( +void BatchedOrgqr( const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, - int k, float* a, int lda, float* tau, int a_stride, int tau_stride) const { + int k, float* a, int lda, float* tau, int a_stride, int tau_stride) { int lwork = 0; auto handle = dev_ctx.cusolver_dn_handle(); @@ -263,10 +250,9 @@ void QrGPUKernel::BatchedOrgqr( } template <> -void QrGPUKernel::BatchedOrgqr( +void BatchedOrgqr( const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, - int k, double* a, int lda, double* tau, int a_stride, - int tau_stride) const { + int k, double* a, int lda, double* tau, int a_stride, int tau_stride) { int lwork = 0; auto handle = dev_ctx.cusolver_dn_handle(); diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h index 65dfb4261e..1731aa9e07 100644 --- a/paddle/fluid/operators/qr_op.h +++ b/paddle/fluid/operators/qr_op.h @@ -250,5 +250,13 @@ class QrGradKernel : public framework::OpKernel { } }; +template +void BatchedGeqrf(const DeviceContext& dev_ctx, int batch_size, int m, int n, + T* a, int lda, T* tau, int a_stride, int tau_stride); + +template +void BatchedOrgqr(const DeviceContext& dev_ctx, int batch_size, int m, int n, + int k, T* a, int lda, T* tau, int a_stride, int tau_stride); + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index f9dc6baea3..63661a93cf 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -81,6 +81,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); __macro(cusolverDnZgeqrf_bufferSize); \ __macro(cusolverDnSorgqr_bufferSize); \ __macro(cusolverDnDorgqr_bufferSize); \ + __macro(cusolverDnSormqr_bufferSize); \ + __macro(cusolverDnDormqr_bufferSize); \ __macro(cusolverDnCungqr_bufferSize); \ __macro(cusolverDnZungqr_bufferSize); \ __macro(cusolverDnDestroyGesvdjInfo); \ @@ -98,6 +100,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); __macro(cusolverDnZgeqrf); \ __macro(cusolverDnSorgqr); \ __macro(cusolverDnDorgqr); \ + __macro(cusolverDnSormqr); \ + __macro(cusolverDnDormqr); \ __macro(cusolverDnCungqr); \ __macro(cusolverDnZungqr); diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index 4c0325a35f..59ac2e2808 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -18,11 +18,15 @@ import unittest import numpy as np import paddle import paddle.fluid as fluid +import paddle.fluid.core as core class LinalgLstsqTestCase(unittest.TestCase): def setUp(self): + self.devices = ["cpu"] self.init_config() + if core.is_compiled_with_cuda() and self.driver == "gels": + self.devices.append("gpu:0") self.generate_input() self.generate_output() @@ -43,104 +47,129 @@ class LinalgLstsqTestCase(unittest.TestCase): if len(self._input_shape_1) == 2: out = np.linalg.lstsq( self._input_data_1, self._input_data_2, rcond=self.rcond) + self._output_solution = out[0] + self._output_residuals = out[1] + self._output_rank = out[2] + self._output_sg_values = out[3] elif len(self._input_shape_1) == 3: - out = np.linalg.lstsq( - self._input_data_1[0], self._input_data_2[0], rcond=self.rcond) - - self._output_solution = out[0] - self._output_residuals = out[1] - self._output_rank = out[2] - self._output_sg_values = out[3] + self._output_solution = [] + self._output_residuals = [] + self._output_rank = [] + self._output_sg_values = [] + for i in range(self._input_shape_1[0]): + out = np.linalg.lstsq( + self._input_data_1[i], + self._input_data_2[i], + rcond=self.rcond) + self._output_solution.append(out[0]) + self._output_residuals.append(out[1]) + self._output_rank.append(out[2]) + self._output_sg_values.append(out[3]) def test_dygraph(self): paddle.disable_static() - paddle.device.set_device("cpu") - place = paddle.CPUPlace() - x = paddle.to_tensor(self._input_data_1, place=place, dtype=self.dtype) - y = paddle.to_tensor(self._input_data_2, place=place, dtype=self.dtype) - results = paddle.linalg.lstsq( - x, y, rcond=self.rcond, driver=self.driver) - - res_solution = results[0].numpy() - res_residuals = results[1].numpy() - res_rank = results[2].numpy() - res_singular_values = results[3].numpy() - - if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]: - if (np.abs(res_residuals - self._output_residuals) < 1e-6).any(): - pass - else: - raise RuntimeError("Check LSTSQ residuals dygraph Failed") - - if self.driver in ("gelsy", "gelsd", "gelss"): - if (np.abs(res_rank - self._output_rank) < 1e-6).any(): - pass - else: - raise RuntimeError("Check LSTSQ rank dygraph Failed") - - if self.driver in ("gelsd", "gelss"): - if (np.abs(res_singular_values - self._output_sg_values) < 1e-6 - ).any(): - pass - else: - raise RuntimeError("Check LSTSQ singular values dygraph Failed") - - def test_static(self): - paddle.enable_static() - paddle.device.set_device("cpu") - place = fluid.CPUPlace() - with fluid.program_guard(fluid.Program(), fluid.Program()): - x = paddle.fluid.data( - name="x", - shape=self._input_shape_1, - dtype=self._input_data_1.dtype) - y = paddle.fluid.data( - name="y", - shape=self._input_shape_2, - dtype=self._input_data_2.dtype) + for dev in self.devices: + paddle.set_device(dev) + place = paddle.CPUPlace() if dev == "cpu" else paddle.CUDAPlace(0) + x = paddle.to_tensor( + self._input_data_1, place=place, dtype=self.dtype) + y = paddle.to_tensor( + self._input_data_2, place=place, dtype=self.dtype) results = paddle.linalg.lstsq( x, y, rcond=self.rcond, driver=self.driver) - exe = fluid.Executor(place) - fetches = exe.run( - fluid.default_main_program(), - feed={"x": self._input_data_1, - "y": self._input_data_2}, - fetch_list=[results]) - - if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]: - if (np.abs(fetches[1] - self._output_residuals) < 1e-6).any(): - pass - else: - raise RuntimeError("Check LSTSQ residuals static Failed") + self._result_solution = results[0].numpy() + self._result_residuals = results[1].numpy() + self._result_rank = results[2].numpy() + self._result_sg_values = results[3].numpy() + self.assert_np_close() + def test_static(self): + paddle.enable_static() + for dev in self.devices: + paddle.set_device(dev) + place = fluid.CPUPlace() if dev == "cpu" else fluid.CUDAPlace(0) + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = paddle.fluid.data( + name="x", + shape=self._input_shape_1, + dtype=self._input_data_1.dtype) + y = paddle.fluid.data( + name="y", + shape=self._input_shape_2, + dtype=self._input_data_2.dtype) + results = paddle.linalg.lstsq( + x, y, rcond=self.rcond, driver=self.driver) + exe = fluid.Executor(place) + fetches = exe.run( + fluid.default_main_program(), + feed={"x": self._input_data_1, + "y": self._input_data_2}, + fetch_list=[results]) + self._result_solution = fetches[0] + self._result_residuals = fetches[1] + self._result_rank = fetches[2] + self._result_sg_values = fetches[3] + self.assert_np_close() + + def assert_np_close(self): + if len(self._input_shape_1) == 2: + np.testing.assert_allclose( + self._result_solution, self._output_solution, rtol=1e-3) + if self._input_shape_1[-2] > self._input_shape_1[ + -1] and self._output_rank == self._input_shape_1[-1]: + np.testing.assert_allclose( + self._result_residuals, self._output_residuals, rtol=1e-5) if self.driver in ("gelsy", "gelsd", "gelss"): - if (np.abs(fetches[2] - self._output_rank) < 1e-6).any(): - pass - else: - raise RuntimeError("Check LSTSQ rank static Failed") - + np.testing.assert_allclose( + self._result_rank, self._output_rank, rtol=1e-5) if self.driver in ("gelsd", "gelss"): - if (np.abs(fetches[3] - self._output_sg_values) < 1e-6).any(): - pass - else: - raise RuntimeError( - "Check LSTSQ singular values static Failed") + np.testing.assert_allclose( + self._result_sg_values, self._output_sg_values, rtol=1e-5) + else: + for i in range(len(self._output_solution)): + np.testing.assert_allclose( + self._result_solution[i], + self._output_solution[i], + rtol=1e-3) + if self._input_shape_1[-2] > self._input_shape_1[ + -1] and self._output_rank[i] == self._input_shape_1[-1]: + np.testing.assert_allclose( + self._result_residuals[i], + self._output_residuals[i], + rtol=1e-5) + if self.driver in ("gelsy", "gelsd", "gelss"): + np.testing.assert_allclose( + self._result_rank[i], self._output_rank[i], rtol=1e-5) + if self.driver in ("gelsd", "gelss"): + np.testing.assert_allclose( + self._result_sg_values[i], + self._output_sg_values[i], + rtol=1e-5) + + +class LinalgLstsqTestCase1(LinalgLstsqTestCase): + def init_config(self): + self.dtype = 'float32' + self.rcond = 1e-15 + self.driver = "gels" + self._input_shape_1 = (9, 9) + self._input_shape_2 = (9, 5) -class LinalgLstsqTestCase(LinalgLstsqTestCase): +class LinalgLstsqTestCase2(LinalgLstsqTestCase): def init_config(self): self.dtype = 'float64' self.rcond = 1e-15 self.driver = "gels" self._input_shape_1 = (5, 10) - self._input_shape_2 = (5, 5) + self._input_shape_2 = (5, 8) class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase): def init_config(self): self.dtype = 'float64' - self.rcond = 0.1 - self.driver = "gels" + self.rcond = 1e-7 + self.driver = "gelsd" self._input_shape_1 = (3, 2) self._input_shape_2 = (3, 3) @@ -148,7 +177,7 @@ class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase): class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase): def init_config(self): self.dtype = 'float32' - self.rcond = 1e-15 + self.rcond = None self.driver = "gels" self._input_shape_1 = (10, 5) self._input_shape_2 = (10, 2) @@ -157,7 +186,7 @@ class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase): class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase): def init_config(self): self.dtype = 'float64' - self.rcond = 1e-15 + self.rcond = None self.driver = "gelss" self._input_shape_1 = (5, 5) self._input_shape_2 = (5, 1) @@ -176,7 +205,7 @@ class LinalgLstsqTestCaseBatch1(LinalgLstsqTestCase): def init_config(self): self.dtype = 'float32' self.rcond = 1e-15 - self.driver = None + self.driver = "gelss" self._input_shape_1 = (2, 3, 10) self._input_shape_2 = (2, 3, 4) @@ -186,8 +215,8 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase): self.dtype = 'float64' self.rcond = 1e-15 self.driver = "gelss" - self._input_shape_1 = (2, 8, 6) - self._input_shape_2 = (2, 8, 2) + self._input_shape_1 = (10, 8, 6) + self._input_shape_2 = (10, 8, 2) class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase): @@ -201,7 +230,7 @@ class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase): class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase): def init_config(self): - self.dtype = 'float32' + self.dtype = 'float64' self.rcond = 1e-15 self.driver = "gelss" self._input_shape_1 = (50, 600) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 5f71606b7d..170889588a 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2816,8 +2816,66 @@ def eigvalsh(x, UPLO='L', name=None): return out_value -def lstsq(x, y, rcond=1e-15, driver=None, name=None): - device = paddle.device.get_device() +def lstsq(x, y, rcond=None, driver=None, name=None): + """ + Computes a solution to + the least squares problem of a system of linear equations. + + Args: + x (Tensor): A tensor with shape ``(*, M, N)`` , the data type of the input Tensor ``x`` + should be one of float32, float64. + y (Tensor): A tensor with shape ``(*, M, K)`` , the data type of the input Tensor ``y`` + should be one of float32, float64. + rcond(float, optional): The default value is None. A float pointing number used to determine + the effective rank of ``x``. If ``rcond`` is None, it will be set to max(M, N) times the + machine precision of x_dtype. + driver(str, optional): The default value is None. The name of LAPACK method to be used. For + CPU inputs the valid values are ‘gels’, ‘gelsy’, ‘gelsd, ‘gelss’. For CUDA input, the only + valid driver is ‘gels’. If ``driver`` is None, ‘gelsy’ is used for CPU inputs and ‘gels’ + for CUDA inputs. + name(str, optional): The default value is None. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tuple: A tuple of 4 Tensors which is (``solution``, ``residuals``, ``rank``, ``singular_values``). + ``solution`` is a tensor with shape ``(*, N, K)``, meaning the least squares solution. ``residuals`` + is a tensor with shape ``(*, K)``, meaning the squared residuals of the solutions, which is computed + when M > N and every matrix in ``x`` is full-rank, otherwise return an empty tensor. ``rank`` is a tensor + with shape ``(*)``, meaning the ranks of the matrices in ``x``, which is computed when ``driver`` in + (‘gelsy’, ‘gelsd’, ‘gelss’), otherwise return an empty tensor. ``singular_values`` is a tensor with + shape ``(*, min(M, N))``, meaning singular values of the matrices in ``x``, which is computed when + ``driver`` in (‘gelsd’, ‘gelss’), otherwise return an empty tensor. + + Examples: + .. code-block:: python + + import paddle + + paddle.set_device("cpu") + x = paddle.to_tensor([[1, 3], [3, 2], [5, 6.]]) + y = paddle.to_tensor([[3, 4, 6], [5, 3, 4], [1, 2, 1.]]) + results = paddle.linalg.lstsq(x, y, driver="gelsd") + print(results[0]) + # [[ 0.78350395, -0.22165027, -0.62371236], + # [-0.11340097, 0.78866047, 1.14948535]] + print(results[1]) + # [19.81443405, 10.43814468, 30.56185532]) + print(results[2]) + # 2 + print(results[3]) + # [9.03455734, 1.54167950] + + x = paddle.to_tensor([[10, 2, 3], [3, 10, 5], [5, 6, 12.]]) + y = paddle.to_tensor([[4, 2, 9], [2, 0, 3], [2, 5, 3.]]) + results = paddle.linalg.lstsq(x, y, driver="gels") + print(results[0]) + # [[ 0.39386186, 0.10230173, 0.93606132], + # [ 0.10741687, -0.29028133, 0.11892585], + # [-0.05115091, 0.51918161, -0.19948854]] + print(results[1]) + # [] + """ + device = paddle.get_device() if device == "cpu": if driver not in (None, "gels", "gelss", "gelsd", "gelsy"): raise ValueError( @@ -2833,6 +2891,19 @@ def lstsq(x, y, rcond=1e-15, driver=None, name=None): else: raise RuntimeError("Only support lstsq api for CPU or CUDA device.") + if x.dtype == y.dtype and x.dtype in (paddle.float32, paddle.float64): + pass + else: + raise ValueError( + "Only support x and y have the same dtype such as 'float32' and 'float64'." + ) + + if rcond is None: + if x.dtype == paddle.float32: + rcond = 1e-7 * max(x.shape[-2], x.shape[-1]) + elif x.dtype == paddle.float64: + rcond = 1e-15 * max(x.shape[-2], x.shape[-1]) + if in_dygraph_mode(): solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond, "driver", driver) -- GitLab