From ab2aaf8b5c806ea555b61e6e38a2670b5782f91d Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Fri, 29 Jul 2022 18:57:53 +0800 Subject: [PATCH] [API/OP] Migrate Lstsq op into phi (#44318) * migrate lstsq op * update * fix bugs for CIs * update * fix bugs * add uts * update * update * update * fix bugs of jip * fix bugs of hip * update * update according to review * update * update * update * update --- paddle/fluid/operators/lstsq_op.cc | 105 ++---- paddle/fluid/pybind/op_function_generator.h | 1 + paddle/phi/api/yaml/legacy_api.yaml | 9 + paddle/phi/infermeta/binary.cc | 84 +++++ paddle/phi/infermeta/binary.h | 9 + paddle/phi/kernels/cpu/lstsq_kernel.cc | 304 ++++++++++++++++++ paddle/phi/kernels/gpu/lstsq_kernel.cu | 177 ++++++++++ paddle/phi/kernels/impl/lstsq_kernel_impl.h | 240 ++++++++++++++ paddle/phi/kernels/impl/qr_kernel_impl.h | 274 ++++++++++++++++ paddle/phi/kernels/lstsq_kernel.h | 32 ++ paddle/phi/ops/compat/lstsq_sig.cc | 28 ++ .../tests/unittests/test_linalg_lstsq_op.py | 25 +- python/paddle/tensor/linalg.py | 55 +--- 13 files changed, 1211 insertions(+), 132 deletions(-) create mode 100644 paddle/phi/kernels/cpu/lstsq_kernel.cc create mode 100644 paddle/phi/kernels/gpu/lstsq_kernel.cu create mode 100644 paddle/phi/kernels/impl/lstsq_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/qr_kernel_impl.h create mode 100644 paddle/phi/kernels/lstsq_kernel.h create mode 100644 paddle/phi/ops/compat/lstsq_sig.cc diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc index 70ce5082ce..b02a2fe13a 100644 --- a/paddle/fluid/operators/lstsq_op.cc +++ b/paddle/fluid/operators/lstsq_op.cc @@ -12,12 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/lstsq_op.h" - -#include -#include - +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -25,79 +23,6 @@ namespace operators { class LstsqOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp"); - - OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp"); - OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp"); - OP_INOUT_CHECK(ctx->HasOutput("SingularValues"), - "Output", - "SingularValues", - "LstsqOp"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - int x_rank = x_dims.size(); - int y_rank = y_dims.size(); - - PADDLE_ENFORCE_GE(x_rank, - 2, - platform::errors::InvalidArgument( - "Expects input tensor x to be not less than " - "2 dimentions, but got dimention %d", - x_rank)); - PADDLE_ENFORCE_GE(y_rank, - 2, - platform::errors::InvalidArgument( - "Expects input tensor y to be not less than " - "2 dimentions, but got dimention %d", - y_rank)); - - PADDLE_ENFORCE_EQ( - x_rank, - y_rank, - platform::errors::InvalidArgument( - "Expects input tensor x and y to have the same dimension " - "but got x's dimention [%d] and y's dimention [%d]", - x_rank, - y_rank)); - - std::vector batch_dims_vec{}; - for (int i = 0; i < x_rank - 2; ++i) { - PADDLE_ENFORCE_EQ( - x_dims[i], - y_dims[i], - platform::errors::InvalidArgument( - "Expects input tensor x and y to have the same batch " - "dimension, but got x's batch dimention [%d] and " - "y's batch dimention [%d] in %d-th dim", - x_dims[i], - y_dims[i], - i)); - batch_dims_vec.emplace_back(x_dims[i]); - } - - PADDLE_ENFORCE_EQ( - x_dims[x_rank - 2], - y_dims[y_rank - 2], - platform::errors::InvalidArgument( - "Expects input tensor x and y to have the same row dimension " - "of the inner-most 2-dims matrix, " - "but got x's row dimention [%d] and y's row dimention [%d]", - x_dims[x_rank - 2], - y_dims[y_rank - 2])); - - ctx->SetOutputDim("Rank", phi::make_ddim(batch_dims_vec)); - - batch_dims_vec.emplace_back( - std::min(x_dims[x_rank - 2], x_dims[x_rank - 1])); - ctx->SetOutputDim("SingularValues", phi::make_ddim(batch_dims_vec)); - - batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1]; - batch_dims_vec.emplace_back(y_dims[x_rank - 1]); - ctx->SetOutputDim("Solution", phi::make_ddim(batch_dims_vec)); - } protected: // The output of lstsq is always complex-valued even for real-valued inputs @@ -133,6 +58,9 @@ class LstsqOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault("gels"); AddOutput("Solution", "(Tensor), The output Solution tensor with shape (*, n, k)."); + AddOutput("Residuals", + "(Tensor), The output Residuals tensor with shape (*, k).") + .AsDispensable(); AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*)."); AddOutput( "SingularValues", @@ -148,8 +76,21 @@ This API processes Lstsq functor for general matrices. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker) -REGISTER_OP_CPU_KERNEL(lstsq, - ops::LstsqCPUKernel, - ops::LstsqCPUKernel); +DECLARE_INFER_SHAPE_FUNCTOR(lstsq, + LstsqInferShapeFunctor, + PD_INFER_META(phi::LstsqInferMeta)); + +REGISTER_OPERATOR(lstsq, + ops::LstsqOp, + ops::LstsqOpMaker, + LstsqInferShapeFunctor); + +REGISTER_OP_VERSION(lstsq).AddCheckpoint( + R"ROC( + Upgrade lstsq, add 1 outputs [Residuals]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewOutput( + "Residuals", + "Output tensor of lstsq operator, " + "meaning the squared residuals of the calculated solutions.")); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 590d9d2f83..8f66d258ed 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -245,6 +245,7 @@ std::map> op_outs_map = { "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"lstsq", {"Solution", "Residuals", "Rank", "SingularValues"}}, {"inplace_abn", {"Y", "MeanOut", diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 0e01074f0a..ff80469eed 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1425,6 +1425,15 @@ func : logsumexp backward : logsumexp_grad +- api : lstsq + args : (Tensor x, Tensor y, Scalar rcond, str driver) + output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values) + infer_meta : + func : LstsqInferMeta + dtype : x + kernel : + func : lstsq + - api : lu args : (Tensor x, bool pivot) output : Tensor(out), Tensor(pivots), Tensor(infos) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 1463296664..4d72e1b60d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2007,6 +2007,90 @@ void TriangularSolveInferMeta(const MetaTensor& x, out->share_lod(y); } +void LstsqInferMeta(const MetaTensor& x, + const MetaTensor& y, + const Scalar& rcond, + const std::string& driver, + MetaTensor* solution, + MetaTensor* residuals, + MetaTensor* rank, + MetaTensor* singular_values) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int x_rank = x_dims.size(); + int y_rank = y_dims.size(); + + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int nrhs = y_dims[x_rank - 1]; + + PADDLE_ENFORCE_GE( + x_rank, + 2, + phi::errors::InvalidArgument("Expects input tensor x to be not less than " + "2 dimentions, but got dimention %d", + x_rank)); + PADDLE_ENFORCE_GE( + y_rank, + 2, + phi::errors::InvalidArgument("Expects input tensor y to be not less than " + "2 dimentions, but got dimention %d", + y_rank)); + + PADDLE_ENFORCE_EQ( + x_rank, + y_rank, + phi::errors::InvalidArgument( + "Expects input tensor x and y to have the same dimension " + "but got x's dimention [%d] and y's dimention [%d]", + x_rank, + y_rank)); + + std::vector batch_dims_vec{}; + for (int i = 0; i < x_rank - 2; ++i) { + PADDLE_ENFORCE_EQ(x_dims[i], + y_dims[i], + phi::errors::InvalidArgument( + "Expects input tensor x and y to have the same batch " + "dimension, but got x's batch dimention [%d] and " + "y's batch dimention [%d] in %d-th dim", + x_dims[i], + y_dims[i], + i)); + batch_dims_vec.emplace_back(x_dims[i]); + } + + PADDLE_ENFORCE_EQ( + m, + y_dims[y_rank - 2], + phi::errors::InvalidArgument( + "Expects input tensor x and y to have the same row dimension " + "of the inner-most 2-dims matrix, " + "but got x's row dimention [%d] and y's row dimention [%d]", + m, + y_dims[y_rank - 2])); + + rank->set_dims(phi::make_ddim(batch_dims_vec)); + + if (m > n) { + batch_dims_vec.emplace_back(nrhs); + residuals->set_dims(phi::make_ddim(batch_dims_vec)); + batch_dims_vec.pop_back(); + } else { + residuals->set_dims(phi::make_ddim({0})); + } + residuals->set_dtype(y.dtype()); + + batch_dims_vec.emplace_back(std::min(m, n)); + singular_values->set_dims(phi::make_ddim(batch_dims_vec)); + singular_values->set_dtype(y.dtype()); + + batch_dims_vec[x_rank - 2] = n; + batch_dims_vec.emplace_back(nrhs); + solution->set_dims(phi::make_ddim(batch_dims_vec)); + solution->set_dtype(y.dtype()); +} + void YoloBoxInferMeta(const MetaTensor& x, const MetaTensor& img_size, const std::vector& anchors, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 85851ee705..53d6c12e88 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -288,6 +288,15 @@ void TriangularSolveInferMeta(const MetaTensor& x, bool unitriangular, MetaTensor* out); +void LstsqInferMeta(const MetaTensor& x, + const MetaTensor& y, + const Scalar& rcond, + const std::string& driver, + MetaTensor* solution, + MetaTensor* residuals, + MetaTensor* rank, + MetaTensor* singular_values); + void YoloBoxInferMeta(const MetaTensor& x, const MetaTensor& img_size, const std::vector& anchors, diff --git a/paddle/phi/kernels/cpu/lstsq_kernel.cc b/paddle/phi/kernels/cpu/lstsq_kernel.cc new file mode 100644 index 0000000000..5542c2ba6e --- /dev/null +++ b/paddle/phi/kernels/cpu/lstsq_kernel.cc @@ -0,0 +1,304 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h" +#include "paddle/phi/kernels/lstsq_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; + +template +void LstsqKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& rcond_scaler, + const std::string& driver_string, + DenseTensor* solution, + DenseTensor* residuals, + DenseTensor* rank, + DenseTensor* singular_values) { + using ValueType = phi::dtype::Real; + + static auto driver_type = std::unordered_map( + {{"gels", LapackDriverType::Gels}, + {"gelsy", LapackDriverType::Gelsy}, + {"gelsd", LapackDriverType::Gelsd}, + {"gelss", LapackDriverType::Gelss}}); + auto driver = driver_type[driver_string]; + T rcond = rcond_scaler.to(); + + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int dim_size = x_dims.size(); + int x_stride = phi::GetMatrixStride(x_dims); + int y_stride = phi::GetMatrixStride(y_dims); + int batch_count = phi::GetBatchCount(x_dims); + auto solution_dim = solution->dims(); + int ori_solu_stride = phi::GetMatrixStride(solution_dim); + int max_solu_stride = std::max(y_stride, ori_solu_stride); + int min_solu_stride = std::min(y_stride, ori_solu_stride); + + // lapack is a column-major storge, transpose make the input to + // have a continuous memory layout + int info = 0; + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + int nrhs = y_dims[dim_size - 1]; + int lda = std::max(m, 1); + int ldb = std::max(1, std::max(m, n)); + + DenseTensor* new_x = new DenseTensor(); + new_x->Resize(phi::make_ddim({batch_count, m, n})); + dev_ctx.template Alloc(new_x); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), true, new_x); + + solution->Resize(phi::make_ddim({batch_count, std::max(m, n), nrhs})); + dev_ctx.template Alloc(solution); + + if (m >= n) { + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), true, solution); + } else { + auto* solu_data = solution->data(); + auto* y_data = y.data(); + for (auto i = 0; i < batch_count; i++) { + for (auto j = 0; j < min_solu_stride; j++) { + solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j]; + } + } + } + + DenseTensor input_x_trans = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor input_y_trans = phi::TransposeLast2Dim(dev_ctx, *solution); + phi::Copy(dev_ctx, input_x_trans, dev_ctx.GetPlace(), true, new_x); + phi::Copy( + dev_ctx, input_y_trans, dev_ctx.GetPlace(), true, solution); + + auto* x_vector = new_x->data(); + auto* y_vector = solution->data(); + + // "gels" divers does not need to compute rank + int rank_32 = 0; + int* rank_data = nullptr; + int* rank_working_ptr = nullptr; + if (driver != LapackDriverType::Gels) { + rank_data = dev_ctx.template Alloc(rank); + rank_working_ptr = rank_data; + } + + // "gelsd" and "gelss" divers need to compute singular values + ValueType* s_data = nullptr; + ValueType* s_working_ptr = nullptr; + int s_stride = 0; + if (driver == LapackDriverType::Gelsd || driver == LapackDriverType::Gelss) { + s_data = dev_ctx.template Alloc(singular_values); + s_working_ptr = s_data; + auto s_dims = singular_values->dims(); + s_stride = s_dims[s_dims.size() - 1]; + } + + // "jpvt" is only used for "gelsy" driver + DenseTensor* jpvt = new DenseTensor(); + int* jpvt_data = nullptr; + if (driver == LapackDriverType::Gelsy) { + jpvt->Resize(phi::make_ddim({std::max(1, n)})); + jpvt_data = dev_ctx.template Alloc(jpvt); + } + + // run once the driver, first to get the optimal workspace size + int lwork = -1; + T wkopt; + ValueType rwkopt; + int iwkopt = 0; + + if (driver == LapackDriverType::Gels) { + phi::funcs::lapackGels( + 'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info); + } else if (driver == LapackDriverType::Gelsd) { + phi::funcs::lapackGelsd(m, + n, + nrhs, + x_vector, + lda, + y_vector, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + &wkopt, + lwork, + &rwkopt, + &iwkopt, + &info); + } else if (driver == LapackDriverType::Gelsy) { + phi::funcs::lapackGelsy(m, + n, + nrhs, + x_vector, + lda, + y_vector, + ldb, + jpvt_data, + static_cast(rcond), + &rank_32, + &wkopt, + lwork, + &rwkopt, + &info); + } else if (driver == LapackDriverType::Gelss) { + phi::funcs::lapackGelss(m, + n, + nrhs, + x_vector, + lda, + y_vector, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + &wkopt, + lwork, + &rwkopt, + &info); + } + + lwork = std::max(1, static_cast(phi::dtype::Real(wkopt))); + DenseTensor* work = new DenseTensor(); + work->Resize(phi::make_ddim({lwork})); + T* work_data = dev_ctx.template Alloc(work); + + // "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers + DenseTensor* rwork = new DenseTensor(); + ValueType* rwork_data = nullptr; + if (IsComplexDtype(x.dtype()) && driver != LapackDriverType::Gels) { + int rwork_len = 0; + if (driver == LapackDriverType::Gelsy) { + rwork_len = std::max(1, 2 * n); + } else if (driver == LapackDriverType::Gelss) { + rwork_len = std::max(1, 5 * std::min(m, n)); + } else if (driver == LapackDriverType::Gelsd) { + rwork_len = std::max(1, rwkopt); + } + rwork->Resize(phi::make_ddim({rwork_len})); + rwork_data = dev_ctx.template Alloc(rwork); + } + + // "iwork" workspace array is relavant only for "gelsd" driver + DenseTensor* iwork = new DenseTensor(); + int* iwork_data = nullptr; + if (driver == LapackDriverType::Gelsd) { + iwork->Resize(phi::make_ddim({std::max(1, iwkopt)})); + iwork_data = dev_ctx.template Alloc(iwork); + } + + for (auto i = 0; i < batch_count; ++i) { + auto* x_input = &x_vector[i * x_stride]; + auto* y_input = &y_vector[i * max_solu_stride]; + rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr; + s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr; + + if (driver == LapackDriverType::Gels) { + phi::funcs::lapackGels( + 'N', m, n, nrhs, x_input, lda, y_input, ldb, work_data, lwork, &info); + } else if (driver == LapackDriverType::Gelsd) { + phi::funcs::lapackGelsd(m, + n, + nrhs, + x_input, + lda, + y_input, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + work_data, + lwork, + rwork_data, + iwork_data, + &info); + } else if (driver == LapackDriverType::Gelsy) { + phi::funcs::lapackGelsy(m, + n, + nrhs, + x_input, + lda, + y_input, + ldb, + jpvt_data, + static_cast(rcond), + &rank_32, + work_data, + lwork, + rwork_data, + &info); + } else if (driver == LapackDriverType::Gelss) { + phi::funcs::lapackGelss(m, + n, + nrhs, + x_input, + lda, + y_input, + ldb, + s_working_ptr, + static_cast(rcond), + &rank_32, + work_data, + lwork, + rwork_data, + &info); + } + + PADDLE_ENFORCE_EQ( + info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: Lapack info is not zero but [%d]", i, info)); + + if (rank_working_ptr) *rank_working_ptr = static_cast(rank_32); + } + + DenseTensor tmp_s = phi::TransposeLast2Dim(dev_ctx, *solution); + phi::Copy(dev_ctx, tmp_s, dev_ctx.GetPlace(), true, solution); + + if (m > n) { + auto* solu_data = solution->data(); + for (auto i = 1; i < batch_count; i++) { + for (auto j = 0; j < min_solu_stride; j++) { + solu_data[i * min_solu_stride + j] = solu_data[i * max_solu_stride + j]; + } + } + } + + if (batch_count > 1) { + solution->Resize(solution_dim); + } else { + solution->Resize(phi::make_ddim({n, nrhs})); + } + + GetResidualsTensor(dev_ctx, x, y, solution, residuals); +} + +} // namespace phi + +PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/lstsq_kernel.cu b/paddle/phi/kernels/gpu/lstsq_kernel.cu new file mode 100644 index 0000000000..adb0ca09d8 --- /dev/null +++ b/paddle/phi/kernels/gpu/lstsq_kernel.cu @@ -0,0 +1,177 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP // HIP not support cusolver + +#include +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/slice.h" +#include "paddle/phi/kernels/impl/lstsq_kernel_impl.h" +#include "paddle/phi/kernels/impl/qr_kernel_impl.h" +#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" +#include "paddle/phi/kernels/lstsq_kernel.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/kernels/triangular_solve_kernel.h" + +namespace phi { + +enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; + +template +void LstsqKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& rcond_scalar, + const std::string& driver_string, + DenseTensor* solution, + DenseTensor* residuals, + DenseTensor* rank, + DenseTensor* singular_values) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + int dim_size = x_dims.size(); + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + int nrhs = y_dims[dim_size - 1]; + int min_mn = std::min(m, n); + int max_mn = std::max(m, n); + int k = min_mn; + + int x_stride = phi::GetMatrixStride(x_dims); + int y_stride = phi::GetMatrixStride(y_dims); + int tau_stride = min_mn; + int batch_count = phi::GetBatchCount(x_dims); + + T rcond = rcond_scalar.to(); + + DenseTensor* new_x = new DenseTensor(); + new_x->Resize(phi::make_ddim({batch_count, m, n})); + dev_ctx.template Alloc(new_x); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), true, new_x); + + DenseTensor* new_y = new DenseTensor(); + new_y->Resize(phi::make_ddim({batch_count, m, nrhs})); + dev_ctx.template Alloc(new_y); + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), true, new_y); + + // Prepare tau + auto tau_dims_vec = phi::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + + DenseTensor* tau = new DenseTensor(); + tau->Resize(phi::make_ddim(tau_dims_vec)); + auto tau_data = dev_ctx.template Alloc(tau); + + if (m >= n) { + DenseTensor tmp_x = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor tmp_y = phi::TransposeLast2Dim(dev_ctx, *new_y); + auto x_data = tmp_x.data(); + auto y_data = tmp_y.data(); + + // step 1, compute QR factorization using geqrf + BatchedGeqrf( + dev_ctx, batch_count, m, n, x_data, m, tau_data, x_stride, tau_stride); + + // Step 2, Y <- Q^H Y + BatchedOrmqr(dev_ctx, + true, + true, + batch_count, + m, + nrhs, + k, + x_data, + x_stride, + tau_data, + tau_stride, + y_data, + y_stride); + + DenseTensor trans_r = phi::TransposeLast2Dim(dev_ctx, tmp_x); + DenseTensor slice_r = + phi::funcs::Slice(dev_ctx, trans_r, {-2}, {0}, {min_mn}); + DenseTensor* res_r = new DenseTensor(); + res_r->Resize(phi::make_ddim({batch_count, min_mn, min_mn})); + dev_ctx.template Alloc(res_r); + phi::TrilTriuKernel(dev_ctx, slice_r, 0, false, res_r); + + DenseTensor trans_y = phi::TransposeLast2Dim(dev_ctx, tmp_y); + DenseTensor slice_y = + phi::funcs::Slice(dev_ctx, trans_y, {-2}, {0}, {min_mn}); + + // Step 3, solve R X = Y + phi::TriangularSolveKernel( + dev_ctx, *res_r, slice_y, true, false, false, solution); + + } else { + auto x_data = dev_ctx.template Alloc(new_x); + auto y_data = dev_ctx.template Alloc(new_y); + + // step 1, compute QR factorization using geqrf + BatchedGeqrf( + dev_ctx, batch_count, n, m, x_data, n, tau_data, x_stride, tau_stride); + + // Step 2, solve R^H Z = Y + DenseTensor trans_r = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor slice_r = + phi::funcs::Slice(dev_ctx, trans_r, {-2}, {0}, {min_mn}); + DenseTensor* res_r = new DenseTensor(); + res_r->Resize(phi::make_ddim({batch_count, min_mn, min_mn})); + dev_ctx.template Alloc(res_r); + phi::TrilTriuKernel(dev_ctx, slice_r, 0, false, res_r); + + phi::TriangularSolveKernel( + dev_ctx, *res_r, *new_y, true, true, false, solution); + + // Step 3, X <- Q Z + BatchedOrgqr(dev_ctx, + batch_count, + n, + m, + min_mn, + x_data, + n, + tau_data, + x_stride, + tau_stride); + + DenseTensor trans_q = phi::TransposeLast2Dim(dev_ctx, *new_x); + DenseTensor slice_q = + phi::funcs::Slice(dev_ctx, trans_q, {-1}, {0}, {m}); + DenseTensor solu_tensor = + phi::Matmul(dev_ctx, slice_q, *solution, false, false); + phi::Copy( + dev_ctx, solu_tensor, dev_ctx.GetPlace(), true, solution); + } + + if (batch_count == 1) solution->Resize(phi::make_ddim({n, nrhs})); + GetResidualsTensor(dev_ctx, x, y, solution, residuals); +} + +} // namespace phi + +PD_REGISTER_KERNEL(lstsq, // cuda_only + GPU, + ALL_LAYOUT, + phi::LstsqKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h new file mode 100644 index 0000000000..73ba954614 --- /dev/null +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -0,0 +1,240 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/utils/optional.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/impl/activation_impl.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +#if defined(PADDLE_WITH_CUDA) +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif + +namespace phi { + +inline int GetBatchCount(const DDim& dims) { + int count = 1; + int num_dims = dims.size(); + for (int i = 0; i < num_dims - 2; ++i) { + count *= dims[i]; + } + return count; +} + +inline int GetMatrixStride(const DDim& dims) { + int num_dims = dims.size(); + return dims[num_dims - 1] * dims[num_dims - 2]; +} + +inline bool IsComplexDtype(const DataType& type) { + return (type == DataType::COMPLEX64 || type == DataType::COMPLEX128); +} + +template +inline void GetResidualsTensor(const DeviceContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* solution, + DenseTensor* residuals) { + auto x_dims = x.dims(); + int dim_size = x_dims.size(); + int m = x_dims[dim_size - 2]; + int n = x_dims[dim_size - 1]; + + if (m > n) { + DenseTensor matmul_tensor = + phi::Matmul(dev_ctx, x, *solution, false, false); + DenseTensor sub_tensor = phi::Subtract(dev_ctx, matmul_tensor, y); + DenseTensor* pow_tensor = new DenseTensor(); + pow_tensor->Resize(sub_tensor.dims()); + dev_ctx.template Alloc(pow_tensor); + phi::PowKernel(dev_ctx, sub_tensor, Scalar(2), pow_tensor); + + auto sum_tensor = + phi::Sum(dev_ctx, *pow_tensor, {-2}, pow_tensor->dtype(), false); + phi::Copy( + dev_ctx, sum_tensor, dev_ctx.GetPlace(), true, residuals); + } else { + IntArray empty_shape({0}); + DenseTensor empty_tensor = + phi::Empty(dev_ctx, empty_shape); + phi::Copy( + dev_ctx, empty_tensor, dev_ctx.GetPlace(), true, residuals); + } +} + +#if defined(PADDLE_WITH_CUDA) +template +inline void BatchedOrmqr(const DeviceContext& dev_ctx, + bool left, + bool transpose, + int batch_size, + int m, + int n, + int k, + T* a, + int a_stride, + T* tau, + int tau_stride, + T* other, + int other_stride); + +template <> +inline void BatchedOrmqr(const GPUContext& dev_ctx, + bool left, + bool transpose, + int batch_size, + int m, + int n, + int k, + float* a, + int a_stride, + float* tau, + int tau_stride, + float* other, + int other_stride) { + int lwork = 0; + auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; + auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = std::max(1, left ? m : n); + int ldc = std::max(1, m); + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + float* other_working_ptr = &other[i * other_stride]; + + handle = dev_ctx.cusolver_dn_handle(); + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(workspace); + + // compute ormgr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSormqr(handle, + side, + trans, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + other_working_ptr, + ldc, + workspace_ptr, + lwork, + info_d)); + + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); + } +} + +template <> +inline void BatchedOrmqr(const GPUContext& dev_ctx, + bool left, + bool transpose, + int batch_size, + int m, + int n, + int k, + double* a, + int a_stride, + double* tau, + int tau_stride, + double* other, + int other_stride) { + int lwork = 0; + auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; + auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + int lda = std::max(1, left ? m : n); + int ldc = std::max(1, m); + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDormqr_bufferSize( + handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + double* other_working_ptr = &other[i * other_stride]; + + handle = dev_ctx.cusolver_dn_handle(); + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(workspace); + + // compute ormgr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDormqr(handle, + side, + trans, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + other_working_ptr, + ldc, + workspace_ptr, + lwork, + info_d)); + + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); + } +} +#endif + +} // namespace phi diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h new file mode 100644 index 0000000000..1d64117922 --- /dev/null +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -0,0 +1,274 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/utils/optional.h" + +#if defined(PADDLE_WITH_CUDA) +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif + +namespace phi { + +#if defined(PADDLE_WITH_CUDA) +template +void BatchedGeqrf(const DeviceContext& dev_ctx, + int batch_size, + int m, + int n, + T* a, + int lda, + T* tau, + int a_stride, + int tau_stride); + +template +void BatchedOrgqr(const DeviceContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + T* a, + int lda, + T* tau, + int a_stride, + int tau_stride); + +template <> +void BatchedGeqrf(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + float* a, + int lda, + float* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedGeqrf(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + double* a, + int lda, + double* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + float* a, + int lda, + float* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + double* a, + int lda, + double* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + + DenseTensor* workspace = new DenseTensor(); + workspace->Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(workspace); + + DenseTensor* info = new DenseTensor(); + info->Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} +#endif + +} // namespace phi diff --git a/paddle/phi/kernels/lstsq_kernel.h b/paddle/phi/kernels/lstsq_kernel.h new file mode 100644 index 0000000000..1ad58615b4 --- /dev/null +++ b/paddle/phi/kernels/lstsq_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LstsqKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& rcond, + const std::string& driver, + DenseTensor* solution, + DenseTensor* residuals, + DenseTensor* rank, + DenseTensor* singular_values); +} // namespace phi diff --git a/paddle/phi/ops/compat/lstsq_sig.cc b/paddle/phi/ops/compat/lstsq_sig.cc new file mode 100644 index 0000000000..f36dfb1917 --- /dev/null +++ b/paddle/phi/ops/compat/lstsq_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LstsqOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("lstsq", + {"X", "Y"}, + {"rcond", "driver"}, + {"Solution", "Residuals", "Rank", "SingularValues"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(lstsq, phi::LstsqOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py index b283c80adf..60acfd414f 100644 --- a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py +++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py @@ -68,8 +68,31 @@ class LinalgLstsqTestCase(unittest.TestCase): self._output_rank.append(out[2]) self._output_sg_values.append(out[3]) - def test_dygraph(self): + def test_eager_dygraph(self): paddle.disable_static() + paddle.fluid.framework._disable_legacy_dygraph() + for dev in self.devices: + paddle.set_device(dev) + place = paddle.CPUPlace() if dev == "cpu" else paddle.CUDAPlace(0) + x = paddle.to_tensor(self._input_data_1, + place=place, + dtype=self.dtype) + y = paddle.to_tensor(self._input_data_2, + place=place, + dtype=self.dtype) + results = paddle.linalg.lstsq(x, + y, + rcond=self.rcond, + driver=self.driver) + self._result_solution = results[0].numpy() + self._result_residuals = results[1].numpy() + self._result_rank = results[2].numpy() + self._result_sg_values = results[3].numpy() + self.assert_np_close() + + def test_legacy_dygraph(self): + paddle.disable_static() + paddle.fluid.framework._enable_legacy_dygraph() for dev in self.devices: paddle.set_device(dev) place = paddle.CPUPlace() if dev == "cpu" else paddle.CUDAPlace(0) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 7001b6ec57..8f6a4f8e11 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3174,21 +3174,12 @@ def lstsq(x, y, rcond=None, driver=None, name=None): rcond = 1e-15 * max(x.shape[-2], x.shape[-1]) if _non_static_mode(): - solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond, - "driver", driver) - if x.shape[-2] > x.shape[-1]: - matmul_out = _varbase_creator(dtype=x.dtype) - _C_ops.matmul(x, solution, matmul_out, 'trans_x', False, 'trans_y', - False) - minus_out = _C_ops.elementwise_sub(matmul_out, y) - pow_out = _C_ops.pow(minus_out, 'factor', 2) - if in_dygraph_mode(): - residuals = _C_ops.final_state_sum(pow_out, [-2], None, False) - else: - residuals = _C_ops.reduce_sum(pow_out, 'dim', [-2], 'keepdim', - False, 'reduce_all', False) + if in_dygraph_mode(): + solution, residuals, rank, singular_values = _C_ops.final_state_lstsq( + x, y, rcond, driver) else: - residuals = paddle.empty(shape=[0], dtype=x.dtype) + solution, residuals, rank, singular_values = _C_ops.lstsq( + x, y, 'rcond', rcond, 'driver', driver) if driver == "gels": rank = paddle.empty(shape=[0], dtype=paddle.int32) @@ -3218,6 +3209,7 @@ def lstsq(x, y, rcond=None, driver=None, name=None): }, outputs={ 'Solution': solution, + 'Residuals': residuals, 'Rank': rank, 'SingularValues': singular_values }, @@ -3226,41 +3218,6 @@ def lstsq(x, y, rcond=None, driver=None, name=None): 'driver': driver }) - matmul_out = helper.create_variable_for_type_inference(dtype=x.dtype) - minus_out = helper.create_variable_for_type_inference(dtype=x.dtype) - pow_out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op(type='matmul_v2', - inputs={ - 'X': x, - 'Y': solution - }, - outputs={'Out': matmul_out}, - attrs={ - 'trans_x': False, - 'trans_y': False, - }) - - helper.append_op(type='elementwise_sub', - inputs={ - 'X': matmul_out, - 'Y': y - }, - outputs={'Out': minus_out}) - - helper.append_op(type='pow', - inputs={'X': minus_out}, - outputs={'Out': pow_out}, - attrs={'factor': 2}) - - helper.append_op(type='reduce_sum', - inputs={'X': pow_out}, - outputs={'Out': residuals}, - attrs={ - 'dim': [-2], - 'keep_dim': False, - 'reduce_all': False - }) - if driver == "gels": rank = paddle.static.data(name='rank', shape=[0]) singular_values = paddle.static.data(name='singular_values', shape=[0]) -- GitLab