From 34d785c22803db1d45148f8dfd175cbaae05a485 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 19 Oct 2021 14:10:27 +0800 Subject: [PATCH] [paddle.linalg.qr] Add the Qr Operator (#35742) * Add QR decomposition op * Change codes to adapt to new svd_helper * Update linalg.py Restore the deleted comma * Restore the deleted line * Update linalg.py * Update linalg.py * Improve the qr code by reviews * Update QR based on CI results * Update qr doc, test=document_fix * Change unsafe and ill-formed codes --- cmake/operators.cmake | 1 + paddle/fluid/operators/qr_op.cc | 152 +++++++++ paddle/fluid/operators/qr_op.cu | 309 ++++++++++++++++++ paddle/fluid/operators/qr_op.h | 135 ++++++++ paddle/fluid/operators/svd_helper.h | 13 + paddle/fluid/platform/dynload/cusolver.h | 18 +- .../fluid/tests/unittests/test_qr_op.py | 173 ++++++++++ python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 66 +++- 10 files changed, 869 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/qr_op.cc create mode 100644 paddle/fluid/operators/qr_op.cu create mode 100644 paddle/fluid/operators/qr_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_qr_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 228da9f777..5eecbefa2f 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -185,6 +185,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "cholesky_op.cu") list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") + list(REMOVE_ITEM hip_srcs "qr_op.cu") list(REMOVE_ITEM hip_srcs "eigh_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc new file mode 100644 index 0000000000..f612bb9e31 --- /dev/null +++ b/paddle/fluid/operators/qr_op.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/qr_op.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { +using DDim = framework::DDim; + +class QrOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "qr"); + OP_INOUT_CHECK(ctx->HasOutput("Q"), "Output", "Q", "qr"); + OP_INOUT_CHECK(ctx->HasOutput("R"), "Output", "R", "qr"); + + auto x_dims = ctx->GetInputDim("X"); + int x_rank = x_dims.size(); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "the rank of input must greater than 2")); + bool compute_q; + bool reduced_mode; + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + std::string mode = ctx->Attrs().Get("mode"); + std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); + + if (compute_q) { + int k = reduced_mode ? min_mn : m; + auto q_dims_vec = framework::vectorize(x_dims); + q_dims_vec[q_dims_vec.size() - 1] = k; + ctx->SetOutputDim("Q", framework::make_ddim(q_dims_vec)); + } else { + ctx->SetOutputDim("Q", framework::make_ddim({0})); + } + + int k = reduced_mode ? min_mn : m; + auto r_dims_vec = framework::vectorize(x_dims); + r_dims_vec[r_dims_vec.size() - 2] = k; + r_dims_vec[r_dims_vec.size() - 1] = n; + ctx->SetOutputDim("R", framework::make_ddim(r_dims_vec)); + + ctx->ShareLoD("X", /*->*/ "Q"); + ctx->ShareLoD("X", /*->*/ "R"); + } +}; + +class QrOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of qr op."); + AddOutput("Q", "(Tensor), The output Q tensor of qr op."); + AddOutput("R", "(Tensor), The output R tensor of qr op."); + AddAttr( + "mode", + "(string, default \"reduced\"). " + "If mode is \"reduced\", Qr op will return reduced Q and R matrices. " + "If mode is \"complete\", Qr op will return complete Q and R matrices. " + "If mode is \"r\", Qr op will only return reduced R matrix.") + .SetDefault("reduced"); + AddComment(R"DOC( +Qr Operator. + +This operator is used to perform QR operation for batched matrics $X$. +$$Q, R = qr(X)$$ + +)DOC"); + } +}; + +class QrGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Q")), "Input", + "Q@Grad", "QrGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("R")), "Input", + "R@Grad", "QrGrad"); + OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "QrGrad"); + OP_INOUT_CHECK(ctx->HasInput("R"), "Input", "R", "QrGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "QrGrad"); + + auto x_dims = ctx->GetInputDim(("X")); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class QrGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("qr_grad"); + retv->SetInput(framework::GradVarName("Q"), this->OutputGrad("Q")); + retv->SetInput(framework::GradVarName("R"), this->OutputGrad("R")); + retv->SetInput("Q", this->Output("Q")); + retv->SetInput("R", this->Output("R")); + retv->SetInput("X", this->Input("X")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker, + ops::QrGradMaker, + ops::QrGradMaker); + +REGISTER_OPERATOR(qr_grad, ops::QrGradOp); + +REGISTER_OP_CPU_KERNEL(qr, ops::QrCPUKernel, ops::QrCPUKernel); + +REGISTER_OP_CPU_KERNEL( + qr_grad, ops::QrGradKernel, + ops::QrGradKernel); diff --git a/paddle/fluid/operators/qr_op.cu b/paddle/fluid/operators/qr_op.cu new file mode 100644 index 0000000000..992df172ac --- /dev/null +++ b/paddle/fluid/operators/qr_op.cu @@ -0,0 +1,309 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/qr_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +// Reuse some helper functions from svd +#include "paddle/fluid/operators/svd_helper.h" + +namespace paddle { +namespace operators { + +template +class QrGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool compute_q; + bool reduced_mode; + auto& dev_ctx = + context.template device_context(); + const Tensor& x = *context.Input("X"); + Tensor& q = *context.Output("Q"); + Tensor& r = *context.Output("R"); + const std::string mode = context.Attr("mode"); + std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); + + auto numel = x.numel(); + PADDLE_ENFORCE_GT(numel, 0, platform::errors::PreconditionNotMet( + "The input of QR is empty.")); + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int batch_size = numel / (m * n); + int qr_stride = m * n; + int tau_stride = min_mn; + + if (compute_q) { + q.mutable_data>( + context.GetPlace(), + size_t(batch_size * m * k * sizeof(math::Real))); + } + r.mutable_data>( + context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real))); + + auto dito = + math::DeviceIndependenceTensorOperations(context); + + // Note: allocate temporary tensors because of lacking in-place operatios. + // Prepare qr + Tensor qr; + qr.mutable_data>( + context.GetPlace(), size_t(batch_size * m * n * sizeof(math::Real))); + // BatchedGeqrf performs computation in-place and 'qr' must be a copy of + // input + TensorCopy(x, context.GetPlace(), &qr); + + // Prepare tau + auto tau_dims_vec = framework::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + Tensor tau = dito.Fill(tau_dims_vec, 0); + + // Transpose 'qr' to conform the column-major order + auto tmp_qr = dito.Transpose(qr); + framework::TensorCopy(tmp_qr, qr.place(), &qr); + auto qr_data = qr.mutable_data(context.GetPlace()); + auto tau_data = tau.mutable_data(context.GetPlace()); + + BatchedGeqrf(dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, + tau_stride); + + if (reduced_mode) { + auto trans_qr = dito.Transpose(qr); + auto sliced_qr = dito.Slice(trans_qr, {-2}, {0}, {min_mn}); + auto tmp_r = dito.TrilTriu(sliced_qr, 0, false); + // Transpose 'tmp_r' to retore the original row-major order + framework::TensorCopy(tmp_r, r.place(), &r); + } else { + auto trans_qr = dito.Transpose(qr); + auto tmp_r = dito.TrilTriu(trans_qr, 0, false); + // Transpose 'tmp_r' to retore the original row-major order + framework::TensorCopy(tmp_r, r.place(), &r); + } + + if (compute_q) { + // Perform QRGQR for Q using the result from GEQRF + // Transpose 'q' to retore the original row-major order + if (reduced_mode) { + BatchedOrgqr(dev_ctx, batch_size, m, min_mn, min_mn, qr_data, m, + tau_data, qr_stride, tau_stride); + auto trans_q = dito.Transpose(qr); + auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {min_mn}); + framework::TensorCopy(sliced_q, q.place(), &q); + } else { + if (m > n) { + auto new_qr_dims_vec = framework::vectorize(x_dims); + new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; + Tensor new_qr = dito.Fill(new_qr_dims_vec, 0); + auto new_qr_data = new_qr.mutable_data(context.GetPlace()); + auto new_qr_stride = m * m; + for (int i = 0; i < batch_size; ++i) { + memory::Copy( + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + (new_qr_data + i * new_qr_stride), + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + (qr_data + i * qr_stride), qr_stride * sizeof(math::Real), + dev_ctx.stream()); + } + BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, new_qr_data, m, + tau_data, new_qr_stride, tau_stride); + auto trans_q = dito.Transpose(new_qr); + framework::TensorCopy(trans_q, q.place(), &q); + } else { + BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, qr_data, m, tau_data, + qr_stride, tau_stride); + auto trans_q = dito.Transpose(qr); + auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {m}); + framework::TensorCopy(sliced_q, q.place(), &q); + } + } + } + } + + void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, float* a, int lda, float* tau, int a_stride, + int tau_stride) const; + + void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, double* a, int lda, double* tau, int a_stride, + int tau_stride) const; + + void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, int k, float* a, int lda, float* tau, + int a_stride, int tau_stride) const; + + void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, int k, double* a, int lda, double* tau, + int a_stride, int tau_stride) const; +}; + +template <> +void QrGPUKernel::BatchedGeqrf( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + float* a, int lda, float* tau, int a_stride, int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgeqrf_bufferSize( + handle, m, n, a, lda, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgeqrf( + handle, m, n, a_working_ptr, lda, tau_working_ptr, workspace_ptr, lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void QrGPUKernel::BatchedGeqrf( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + double* a, int lda, double* tau, int a_stride, int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgeqrf_bufferSize( + handle, m, n, a, lda, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgeqrf( + handle, m, n, a_working_ptr, lda, tau_working_ptr, workspace_ptr, lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void QrGPUKernel::BatchedOrgqr( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + int k, float* a, int lda, float* tau, int a_stride, int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSorgqr( + handle, m, n, k, a_working_ptr, lda, tau_working_ptr, workspace_ptr, + lwork, info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +template <> +void QrGPUKernel::BatchedOrgqr( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + int k, double* a, int lda, double* tau, int a_stride, + int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDorgqr( + handle, m, n, k, a_working_ptr, lda, tau_working_ptr, workspace_ptr, + lwork, info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(qr, ops::QrGPUKernel, ops::QrGPUKernel); +REGISTER_OP_CUDA_KERNEL( + qr_grad, ops::QrGradKernel, + ops::QrGradKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h new file mode 100644 index 0000000000..73ba52f590 --- /dev/null +++ b/paddle/fluid/operators/qr_op.h @@ -0,0 +1,135 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +static inline std::tuple _parse_qr_mode(std::string mode) { + bool compute_q; + bool reduced; + if (mode == "reduced") { + compute_q = true; + reduced = true; + } else if (mode == "complete") { + compute_q = true; + reduced = false; + } else if (mode == "r") { + compute_q = false; + reduced = true; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "QR received unrecognized mode '%s'" + " but expected one of 'reduced' (default), 'r', or 'complete'", + mode)); + } + return std::make_tuple(compute_q, reduced); +} + +template +class QrCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool compute_q; + bool reduced_mode; + const Tensor& x = *context.Input("X"); + Tensor& q = *context.Output("Q"); + Tensor& r = *context.Output("R"); + std::string mode = context.Attr("mode"); + std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); + + auto numel = x.numel(); + PADDLE_ENFORCE_GT(numel, 0, platform::errors::PreconditionNotMet( + "The input of QR is empty.")); + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int batch_size = numel / (m * n); + int x_stride = m * n; + int q_stride = m * k; + int r_stride = k * n; + + auto* x_data = x.data>(); + T* q_data = nullptr; + if (compute_q) { + q_data = q.mutable_data>( + context.GetPlace(), + size_t(batch_size * m * k * sizeof(math::Real))); + } + auto* r_data = r.mutable_data>( + context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real))); + + // Implement QR by calling Eigen + for (int i = 0; i < batch_size; ++i) { + const T* x_matrix_ptr = x_data + i * x_stride; + T* r_matrix_ptr = r_data + i * r_stride; + using EigenDynamicMatrix = + Eigen::Matrix; + auto x_matrix = Eigen::Map(x_matrix_ptr, m, n); + Eigen::HouseholderQR qr(x_matrix); + if (reduced_mode) { + auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n); + auto r_matrix_view = + qr_top_matrix.template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + } else { + auto r_matrix_view = + qr.matrixQR().template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + } + + if (compute_q) { + T* q_matrix_ptr = q_data + i * q_stride; + if (reduced_mode) { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); + } else { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, m); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); + } + } + } + } +}; + +template +class QrGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + PADDLE_THROW(platform::errors::InvalidArgument( + "QR doesn't have the backward kernel now and will be supported soon.")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 9ba7c9a306..6b25846822 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -502,6 +502,19 @@ struct DeviceIndependenceTensorOperations { return ret; } + framework::Tensor TrilTriu(const framework::Tensor& x, int diagonal, + bool lower) { + framework::AttributeMap attrs; + attrs["diagonal"] = diagonal; + attrs["lower"] = lower; + NameInTensorMap inputs({{"X", {&x}}}); + int x_rank = x.dims().size(); + PADDLE_ENFORCE_GE(x_rank, 2, platform::errors::InvalidArgument( + "Rank must be at least 2.")); + std::vector out_shape = framework::vectorize(x.dims()); + return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape); + } + Tensor Conj(const Tensor& x) { Tensor out; auto* out_data = out.mutable_data(x.dims(), context.GetPlace()); diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index a8ce1cc9d3..4c018908b5 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -65,11 +65,27 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); __macro(cusolverDnSpotrfBatched); \ __macro(cusolverDnDpotrfBatched); \ __macro(cusolverDnSgesvdj_bufferSize); \ + __macro(cusolverDnSgeqrf_bufferSize); \ + __macro(cusolverDnDgeqrf_bufferSize); \ + __macro(cusolverDnCgeqrf_bufferSize); \ + __macro(cusolverDnZgeqrf_bufferSize); \ + __macro(cusolverDnSorgqr_bufferSize); \ + __macro(cusolverDnDorgqr_bufferSize); \ + __macro(cusolverDnCungqr_bufferSize); \ + __macro(cusolverDnZungqr_bufferSize); \ __macro(cusolverDnDestroyGesvdjInfo); \ __macro(cusolverDnCreateGesvdjInfo); \ __macro(cusolverDnDgesvdj_bufferSize); \ __macro(cusolverDnSgesvdj); \ - __macro(cusolverDnDgesvdj); + __macro(cusolverDnDgesvdj); \ + __macro(cusolverDnSgeqrf); \ + __macro(cusolverDnDgeqrf); \ + __macro(cusolverDnCgeqrf); \ + __macro(cusolverDnZgeqrf); \ + __macro(cusolverDnSorgqr); \ + __macro(cusolverDnDorgqr); \ + __macro(cusolverDnCungqr); \ + __macro(cusolverDnZungqr); CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) #endif diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py new file mode 100644 index 0000000000..ea2aaf3f00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_qr_op.py @@ -0,0 +1,173 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import itertools +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core + + +class TestQrAPI(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + + def run_qr_dygraph(shape, mode, dtype): + if dtype == "float32": + np_dtype = np.float32 + elif dtype == "float64": + np_dtype = np.float64 + a = np.random.rand(*shape).astype(np_dtype) + m = a.shape[-2] + n = a.shape[-1] + min_mn = min(m, n) + if mode == "reduced" or mode == "r": + k = min_mn + else: + k = m + np_q_shape = list(a.shape[:-2]) + np_q_shape.extend([m, k]) + np_r_shape = list(a.shape[:-2]) + np_r_shape.extend([k, n]) + np_q = np.zeros(np_q_shape).astype(np_dtype) + np_r = np.zeros(np_r_shape).astype(np_dtype) + places = [] + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + batch_size = a.size // (a.shape[-1] * a.shape[-2]) + for i in range(batch_size): + coord = np.unravel_index(i, a.shape[:-2]) + if mode == "r": + tmp_r = np.linalg.qr(a[coord], mode=mode) + np_r[coord] = tmp_r + else: + tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) + np_q[coord] = tmp_q + np_r[coord] = tmp_r + + x = paddle.to_tensor(a, dtype=dtype) + if mode == "r": + r = paddle.linalg.qr(x, mode=mode) + self.assertTrue(np.allclose(r, np_r, atol=1e-5)) + else: + q, r = paddle.linalg.qr(x, mode=mode) + self.assertTrue(np.allclose(q, np_q, atol=1e-5)) + self.assertTrue(np.allclose(r, np_r, atol=1e-5)) + + tensor_shapes = [ + (3, 5), + (5, 5), + (5, 3), # 2-dim Tensors + (2, 3, 5), + (3, 5, 5), + (4, 5, 3), # 3-dim Tensors + (2, 5, 3, 5), + (3, 5, 5, 5), + (4, 5, 5, 3) # 4-dim Tensors + ] + modes = ["reduced", "complete", "r"] + dtypes = ["float32", "float64"] + for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, + dtypes): + run_qr_dygraph(tensor_shape, mode, dtype) + + def test_static(self): + paddle.enable_static() + + def run_qr_static(shape, mode, dtype): + if dtype == "float32": + np_dtype = np.float32 + elif dtype == "float64": + np_dtype = np.float64 + a = np.random.rand(*shape).astype(np_dtype) + m = a.shape[-2] + n = a.shape[-1] + min_mn = min(m, n) + if mode == "reduced" or mode == "r": + k = min_mn + else: + k = m + np_q_shape = list(a.shape[:-2]) + np_q_shape.extend([m, k]) + np_r_shape = list(a.shape[:-2]) + np_r_shape.extend([k, n]) + np_q = np.zeros(np_q_shape).astype(np_dtype) + np_r = np.zeros(np_r_shape).astype(np_dtype) + places = [] + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + batch_size = a.size // (a.shape[-1] * a.shape[-2]) + for i in range(batch_size): + coord = np.unravel_index(i, a.shape[:-2]) + if mode == "r": + tmp_r = np.linalg.qr(a[coord], mode=mode) + np_r[coord] = tmp_r + else: + tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) + np_q[coord] = tmp_q + np_r[coord] = tmp_r + x = paddle.fluid.data( + name="input", shape=shape, dtype=dtype) + if mode == "r": + r = paddle.linalg.qr(x, mode=mode) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": a}, + fetch_list=[r]) + self.assertTrue( + np.allclose( + fetches[0], np_r, atol=1e-5)) + else: + q, r = paddle.linalg.qr(x, mode=mode) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": a}, + fetch_list=[q, r]) + self.assertTrue( + np.allclose( + fetches[0], np_q, atol=1e-5)) + self.assertTrue( + np.allclose( + fetches[1], np_r, atol=1e-5)) + + tensor_shapes = [ + (3, 5), + (5, 5), + (5, 3), # 2-dim Tensors + (2, 3, 5), + (3, 5, 5), + (4, 5, 3), # 3-dim Tensors + (2, 5, 3, 5), + (3, 5, 5, 5), + (4, 5, 5, 3) # 4-dim Tensors + ] + modes = ["reduced", "complete", "r"] + dtypes = ["float32", "float64"] + for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, + dtypes): + run_qr_static(tensor_shape, mode, dtype) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 726355379e..06b512150c 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -23,6 +23,7 @@ from .tensor.linalg import eigvals # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd +from .tensor.linalg import qr from .tensor.linalg import eigh # noqa: F401 from .tensor.linalg import det from .tensor.linalg import slogdet @@ -38,6 +39,7 @@ __all__ = [ 'multi_dot', 'matrix_rank', 'svd', + 'qr', 'matrix_power', 'det', 'slogdet', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index c8f897c216..b898b60fe4 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -47,6 +47,7 @@ from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import eig # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import qr # noqa: F401 from .linalg import eigvals # noqa: F401 from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 @@ -237,6 +238,7 @@ tensor_method_func = [ #noqa 'histogram', 'mv', 'matrix_power', + 'qr', 'eigvals', 'abs', 'acos', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index f112603fbb..6853d904ad 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1594,6 +1594,70 @@ def matrix_power(x, n, name=None): return out +def qr(x, mode="reduced", name=None): + r""" + Computes the QR decomposition of one matrix or batches of matrice (backward is unsupported now). + + Args: + x (Tensor): The input tensor. Its shape should be `[..., M, N]`, + where ... is zero or more batch dimensions. M and N can be arbitrary + positive number. The data type of x should be float32 or float64. + mode (str, optional): A flag to control the behavior of qr, the default is "reduced". + Suppose x's shape is `[..., M, N]` and denoting `K = min(M, N)`: + If mode = "reduced", qr op will return reduced Q and R matrices, + which means Q's shape is `[..., M, K]` and R's shape is `[..., K, N]`. + If mode = "complete", qr op will return complete Q and R matrices, + which means Q's shape is `[..., M, M]` and R's shape is `[..., M, N]`. + If mode = "r", qr op will only return reduced R matrix, which means + R's shape is `[..., K, N]`. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + If mode = "reduced" or mode = "complete", qr will return a two tensor-tuple, which represents Q and R. + If mode = "r", qr will return a tensor which represents R. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]).astype('float64') + q, r = paddle.linalg.qr(x) + print (q) + print (r) + + # Q = [[-0.16903085, 0.89708523], + # [-0.50709255, 0.27602622], + # [-0.84515425, -0.34503278]]) + + # R = [[-5.91607978, -7.43735744], + # [ 0. , 0.82807867]]) + + # one can verify : X = Q * R ; + """ + if in_dygraph_mode(): + q, r = _C_ops.qr(x, 'mode', mode) + if mode == "r": + return r + else: + return q, r + check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'qr') + check_type(mode, 'mode', str, 'qr') + helper = LayerHelper('qr', **locals()) + q = helper.create_variable_for_type_inference(dtype=x.dtype) + r = helper.create_variable_for_type_inference(dtype=x.dtype) + attrs = dict() + attrs['mode'] = mode + helper.append_op( + type='qr', inputs={'X': [x]}, outputs={'Q': q, + 'R': r}, attrs=attrs) + if mode == "r": + return r + else: + return q, r + + def eig(x, name=None): """ This API performs the eigenvalue decomposition of a square matrix or a batch of square matrices. @@ -1674,7 +1738,7 @@ def eigvals(x, name=None): Its data type should be float32, float64, complex64, or complex128. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - + Returns: Tensor: A tensor containing the unsorted eigenvalues which has the same batch dimensions with `x`. The eigenvalues are complex-valued even when `x` is real. -- GitLab