From 2cf2e78637bc58160dda2720fa913e7372c23117 Mon Sep 17 00:00:00 2001
From: Yulong Ao
Date: Tue, 2 Aug 2022 17:18:38 +0800
Subject: [PATCH] [Phi] Move QR to Phi (#44742)

* [Phi] Move Qr to the Phi
* [Phi] Register the cpu grad kernel for qr
* [Phi] Share the cuda kernels to lstsq
* [Phi] Remove some improper include files
* [Phi] Modify codes based on the reviews
* [Phi] Remove unnecessary files and add the cuda_only comment
* [Phi] Remove the unnecessary include file
* [Phi] Remove qr_op.cu and lstsq_op.cu
---
 paddle/fluid/operators/lstsq_op.cu            | 308 -------------
 paddle/fluid/operators/lstsq_op.h             | 345 ---------------
 paddle/fluid/operators/qr_op.cc               |   6 -
 paddle/fluid/operators/qr_op.cu               | 392 -----------------
 paddle/fluid/operators/qr_op.h                | 204 ---------
 paddle/phi/api/yaml/legacy_api.yaml           |   2 +-
 paddle/phi/api/yaml/legacy_backward.yaml      |  10 +
 paddle/phi/kernels/cpu/qr_grad_kernel.cc      |  23 +
 paddle/phi/kernels/gpu/qr_grad_kernel.cu      |  20 +
 paddle/phi/kernels/gpu/qr_kernel.cu           | 410 ++++++++++++++++++
 paddle/phi/kernels/impl/qr_grad_kernel_impl.h | 182 ++++++++
 paddle/phi/kernels/impl/qr_kernel_impl.h      | 219 ----------
 paddle/phi/kernels/qr_grad_kernel.h           |  31 ++
 paddle/phi/kernels/triangular_solve_kernel.h  |  16 +
 paddle/phi/kernels/tril_triu_kernel.h         |  13 +
 paddle/phi/ops/compat/qr_sig.cc               |   6 +
 .../fluid/tests/unittests/test_qr_op.py       |   8 +-
 python/paddle/tensor/linalg.py                |   8 +-
 18 files changed, 724 insertions(+), 1479 deletions(-)
 delete mode 100644 paddle/fluid/operators/lstsq_op.cu
 delete mode 100644 paddle/fluid/operators/lstsq_op.h
 delete mode 100644 paddle/fluid/operators/qr_op.cu
 delete mode 100644 paddle/fluid/operators/qr_op.h
 create mode 100644 paddle/phi/kernels/cpu/qr_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/qr_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/qr_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/qr_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/qr_grad_kernel.h

diff --git a/paddle/fluid/operators/lstsq_op.cu b/paddle/fluid/operators/lstsq_op.cu
deleted file mode 100644
index e9d1a6a136..0000000000
--- a/paddle/fluid/operators/lstsq_op.cu
+++ /dev/null
@@ -1,308 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
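
For orientation: the deleted lstsq_op.cu below solves least squares on the GPU through a QR factorization (geqrf), application of Q^H to the right-hand side (ormqr), and a triangular solve. The NumPy sketch below only illustrates that m >= n algorithm; the helper name is made up for illustration and is not part of Paddle or of this patch.

```python
import numpy as np

def lstsq_via_qr(a, b):
    # A = Q R with R upper triangular, then solve R x = Q^H b.
    q, r = np.linalg.qr(a, mode='reduced')
    # np.linalg.solve does not exploit triangularity; fine for a correctness sketch.
    return np.linalg.solve(r, q.conj().T @ b)

a = np.random.rand(6, 4)
b = np.random.rand(6, 2)
print(np.allclose(lstsq_via_qr(a, b), np.linalg.lstsq(a, b, rcond=None)[0]))  # True
```
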
- -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver - -#include -#include - -#include "paddle/fluid/framework/phi_utils.h" -#include "paddle/fluid/operators/lstsq_op.h" -#include "paddle/fluid/operators/qr_op.h" -#include "paddle/fluid/platform/dynload/cusolver.h" -#include "paddle/phi/kernels/triangular_solve_kernel.h" - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; - -template -class LstsqCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor& x = *context.Input("X"); - const Tensor& y = *context.Input("Y"); - auto* solution = context.Output("Solution"); - - auto dito = - math::DeviceIndependenceTensorOperations(context); - auto& dev_ctx = context.template device_context(); - - auto x_dims = x.dims(); - auto y_dims = y.dims(); - int dim_size = x_dims.size(); - int m = x_dims[dim_size - 2]; - int n = x_dims[dim_size - 1]; - int nrhs = y_dims[dim_size - 1]; - int min_mn = std::min(m, n); - int max_mn = std::max(m, n); - int k = min_mn; - - int x_stride = MatrixStride(x); - int y_stride = MatrixStride(y); - int tau_stride = min_mn; - int batch_count = BatchCount(x); - - Tensor new_x, new_y; - new_x.mutable_data(context.GetPlace(), - size_t(batch_count * m * n * sizeof(T))); - new_y.mutable_data(context.GetPlace(), - size_t(batch_count * m * nrhs * sizeof(T))); - framework::TensorCopy(x, context.GetPlace(), &new_x); - framework::TensorCopy(y, context.GetPlace(), &new_y); - - // Prepare tau - auto tau_dims_vec = phi::vectorize(x_dims); - tau_dims_vec.pop_back(); - tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; - Tensor tau = dito.Fill(tau_dims_vec, 0); - auto tau_data = tau.mutable_data(context.GetPlace()); - - using Context = - typename framework::ConvertToPhiContext::TYPE; - auto& phi_dev_ctx = static_cast(dev_ctx); - - if (m >= n) { - Tensor tmp_x = dito.Transpose(new_x); - Tensor tmp_y = dito.Transpose(new_y); - auto x_data = tmp_x.mutable_data(context.GetPlace()); - auto y_data = tmp_y.mutable_data(context.GetPlace()); - - // step 1, compute QR factorization using geqrf - BatchedGeqrf(dev_ctx, - batch_count, - m, - n, - x_data, - m, - tau_data, - x_stride, - tau_stride); - - // Step 2, Y <- Q^H Y - BatchedOrmqr(dev_ctx, - true, - true, - batch_count, - m, - nrhs, - k, - x_data, - x_stride, - tau_data, - tau_stride, - y_data, - y_stride); - - Tensor trans_r = dito.Transpose(tmp_x); - Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn}); - Tensor res_r = dito.TrilTriu(slice_r, 0, false); - - Tensor trans_y = dito.Transpose(tmp_y); - Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn}); - - // Step 3, solve R X = Y - phi::TriangularSolveKernel( - phi_dev_ctx, res_r, slice_y, true, false, false, solution); - - } else { - auto x_data = new_x.mutable_data(context.GetPlace()); - auto y_data = new_y.mutable_data(context.GetPlace()); - - // step 1, compute QR factorization using geqrf - BatchedGeqrf(dev_ctx, - batch_count, - n, - m, - x_data, - n, - tau_data, - x_stride, - tau_stride); - - // Step 2, solve R^H Z = Y - Tensor trans_r = dito.Transpose(new_x); - Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn}); - Tensor res_r = dito.TrilTriu(slice_r, 0, false); - - phi::TriangularSolveKernel( - phi_dev_ctx, res_r, new_y, true, true, false, solution); - - // Step 3, X <- Q Z - BatchedOrgqr(dev_ctx, - batch_count, - n, - m, - min_mn, - x_data, - n, - tau_data, - x_stride, - tau_stride); - Tensor trans_q = dito.Transpose(new_x); - Tensor slice_q = 
dito.Slice(trans_q, {-1}, {0}, {m}); - Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false); - framework::TensorCopy(solu_tensor, context.GetPlace(), solution); - } - } -}; - -template <> -void BatchedOrmqr(const phi::GPUContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - float* a, - int a_stride, - float* tau, - int tau_stride, - float* other, - int other_stride) { - int lwork = 0; - auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; - auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - int lda = std::max(1, left ? m : n); - int ldc = std::max(1, m); - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize( - handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - float* a_working_ptr = &a[i * a_stride]; - float* tau_working_ptr = &tau[i * tau_stride]; - float* other_working_ptr = &other[i * other_stride]; - - handle = dev_ctx.cusolver_dn_handle(); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - - // compute ormgr - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnSormqr(handle, - side, - trans, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - other_working_ptr, - ldc, - workspace_ptr, - lwork, - info_d)); - - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); - } -} - -template <> -void BatchedOrmqr(const phi::GPUContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - double* a, - int a_stride, - double* tau, - int tau_stride, - double* other, - int other_stride) { - int lwork = 0; - auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; - auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - int lda = std::max(1, left ? 
m : n); - int ldc = std::max(1, m); - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize( - handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork)); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - double* a_working_ptr = &a[i * a_stride]; - double* tau_working_ptr = &tau[i * tau_stride]; - double* other_working_ptr = &other[i * other_stride]; - - handle = dev_ctx.cusolver_dn_handle(); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - - // compute ormgr - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDormqr(handle, - side, - trans, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - other_working_ptr, - ldc, - workspace_ptr, - lwork, - info_d)); - - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver info is not zero but [%d]", i, info_h)); - } -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(lstsq, - ops::LstsqCUDAKernel, - ops::LstsqCUDAKernel); - -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h deleted file mode 100644 index 7e71d17364..0000000000 --- a/paddle/fluid/operators/lstsq_op.h +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
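
The CPU path in lstsq_op.h below dispatches to the LAPACK gels/gelsy/gelsd/gelss drivers and also returns a rank and singular values, with rcond as the cutoff. As a rough functional reference only (NumPy, not the Paddle API), np.linalg.lstsq, which calls LAPACK gelsd, exposes the same quantities:

```python
import numpy as np

a = np.random.rand(5, 3)
b = np.random.rand(5, 2)

# Solution, residuals, rank and singular values, analogous to the op's
# Solution / Rank / SingularValues outputs; rcond is the same cutoff knob.
x, residuals, rank, singular_values = np.linalg.lstsq(a, b, rcond=None)
print(x.shape, rank, singular_values.shape)  # (3, 2) 3 (3,)
```
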
- -#pragma once - -#include - -#include -#include - -#include "paddle/fluid/operators/eig_op.h" -#include "paddle/fluid/operators/math/eigen_values_vectors.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/matrix_solve.h" - -#define EPSILON 1e-6 - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; -enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss }; - -using DDim = framework::DDim; -static DDim UDDim(const DDim& x_dim) { - auto x_vec = vectorize(x_dim); - return phi::make_ddim(x_vec); -} - -template -class LstsqCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using ValueType = phi::dtype::Real; - - const Tensor& x = *context.Input("X"); - auto y = context.Input("Y"); - auto rcond = context.Attr("rcond"); - auto driver_string = context.Attr("driver"); - - static auto driver_type = std::unordered_map( - {{"gels", LapackDriverType::Gels}, - {"gelsy", LapackDriverType::Gelsy}, - {"gelsd", LapackDriverType::Gelsd}, - {"gelss", LapackDriverType::Gelss}}); - auto driver = driver_type[driver_string]; - - auto solution = context.Output("Solution"); - auto* rank = context.Output("Rank"); - auto* singular_values = context.Output("SingularValues"); - - auto dito = - math::DeviceIndependenceTensorOperations(context); - - auto x_dims = x.dims(); - auto y_dims = y->dims(); - int dim_size = x_dims.size(); - int x_stride = MatrixStride(x); - int y_stride = MatrixStride(*y); - int batch_count = BatchCount(x); - auto solution_dim = solution->dims(); - int ori_solu_stride = MatrixStride(*solution); - int max_solu_stride = std::max(y_stride, ori_solu_stride); - int min_solu_stride = std::min(y_stride, ori_solu_stride); - - // lapack is a column-major storge, transpose make the input to - // have a continuous memory layout - int info = 0; - int m = x_dims[dim_size - 2]; - int n = x_dims[dim_size - 1]; - int nrhs = y_dims[dim_size - 1]; - int lda = std::max(m, 1); - int ldb = std::max(1, std::max(m, n)); - - Tensor new_x; - new_x.mutable_data(context.GetPlace(), - size_t(batch_count * m * n * sizeof(T))); - framework::TensorCopy(x, context.GetPlace(), &new_x); - - solution->mutable_data( - context.GetPlace(), - size_t(batch_count * std::max(m, n) * nrhs * sizeof(T))); - - if (m >= n) { - const Tensor& new_y = *context.Input("Y"); - framework::TensorCopy(new_y, context.GetPlace(), solution); - } else { - auto* solu_data = solution->data(); - auto* y_data = y->data(); - for (auto i = 0; i < batch_count; i++) { - for (auto j = 0; j < min_solu_stride; j++) { - solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j]; - } - } - } - - Tensor input_x_trans = dito.Transpose(new_x); - Tensor input_y_trans = dito.Transpose(*solution); - framework::TensorCopy(input_x_trans, context.GetPlace(), &new_x); - framework::TensorCopy(input_y_trans, context.GetPlace(), solution); - - auto* x_vector = new_x.data(); - auto* y_vector = solution->data(); - - // "gels" divers does not need to compute rank - int rank_32 = 0; - int* rank_data = nullptr; - int* rank_working_ptr = nullptr; - if (driver != LapackDriverType::Gels) { - rank_data = rank->mutable_data(context.GetPlace()); - rank_working_ptr = rank_data; - } 
- - // "gelsd" and "gelss" divers need to compute singular values - ValueType* s_data = nullptr; - ValueType* s_working_ptr = nullptr; - int s_stride = 0; - if (driver == LapackDriverType::Gelsd || - driver == LapackDriverType::Gelss) { - s_data = singular_values->mutable_data(context.GetPlace()); - s_working_ptr = s_data; - auto s_dims = singular_values->dims(); - s_stride = s_dims[s_dims.size() - 1]; - } - - // "jpvt" is only used for "gelsy" driver - Tensor jpvt; - int* jpvt_data = nullptr; - if (driver == LapackDriverType::Gelsy) { - jpvt.Resize(phi::make_ddim({std::max(1, n)})); - jpvt_data = jpvt.mutable_data(context.GetPlace()); - } - - // run once the driver, first to get the optimal workspace size - int lwork = -1; - T wkopt; - ValueType rwkopt; - int iwkopt = 0; - - if (driver == LapackDriverType::Gels) { - phi::funcs::lapackGels( - 'N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt, lwork, &info); - } else if (driver == LapackDriverType::Gelsd) { - phi::funcs::lapackGelsd(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &iwkopt, - &info); - } else if (driver == LapackDriverType::Gelsy) { - phi::funcs::lapackGelsy(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - jpvt_data, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &info); - } else if (driver == LapackDriverType::Gelss) { - phi::funcs::lapackGelss(m, - n, - nrhs, - x_vector, - lda, - y_vector, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - &wkopt, - lwork, - &rwkopt, - &info); - } - - lwork = std::max(1, static_cast(phi::dtype::Real(wkopt))); - Tensor work; - work.Resize(phi::make_ddim({lwork})); - T* work_data = work.mutable_data(context.GetPlace()); - - // "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers - Tensor rwork; - ValueType* rwork_data = nullptr; - if (framework::IsComplexType(framework::TransToProtoVarType(x.dtype())) && - driver != LapackDriverType::Gels) { - int rwork_len = 0; - if (driver == LapackDriverType::Gelsy) { - rwork_len = std::max(1, 2 * n); - } else if (driver == LapackDriverType::Gelss) { - rwork_len = std::max(1, 5 * std::min(m, n)); - } else if (driver == LapackDriverType::Gelsd) { - rwork_len = std::max(1, rwkopt); - } - rwork.Resize(phi::make_ddim({rwork_len})); - rwork_data = rwork.mutable_data(context.GetPlace()); - } - - // "iwork" workspace array is relavant only for "gelsd" driver - Tensor iwork; - int* iwork_data = nullptr; - if (driver == LapackDriverType::Gelsd) { - iwork.Resize(phi::make_ddim({std::max(1, iwkopt)})); - iwork_data = iwork.mutable_data(context.GetPlace()); - } - - for (auto i = 0; i < batch_count; ++i) { - auto* x_input = &x_vector[i * x_stride]; - auto* y_input = &y_vector[i * max_solu_stride]; - rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr; - s_working_ptr = s_working_ptr ? 
&s_data[i * s_stride] : nullptr; - - if (driver == LapackDriverType::Gels) { - phi::funcs::lapackGels('N', - m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - work_data, - lwork, - &info); - } else if (driver == LapackDriverType::Gelsd) { - phi::funcs::lapackGelsd(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - iwork_data, - &info); - } else if (driver == LapackDriverType::Gelsy) { - phi::funcs::lapackGelsy(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - jpvt_data, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - &info); - } else if (driver == LapackDriverType::Gelss) { - phi::funcs::lapackGelss(m, - n, - nrhs, - x_input, - lda, - y_input, - ldb, - s_working_ptr, - static_cast(rcond), - &rank_32, - work_data, - lwork, - rwork_data, - &info); - } - - PADDLE_ENFORCE_EQ( - info, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: Lapack info is not zero but [%d]", i, info)); - - if (rank_working_ptr) *rank_working_ptr = static_cast(rank_32); - } - - Tensor tmp_s = dito.Transpose(*solution); - framework::TensorCopy(tmp_s, context.GetPlace(), solution); - - if (m > n) { - auto* solu_data = solution->data(); - for (auto i = 1; i < batch_count; i++) { - for (auto j = 0; j < min_solu_stride; j++) { - solu_data[i * min_solu_stride + j] = - solu_data[i * max_solu_stride + j]; - } - } - } - - solution->Resize(UDDim(solution_dim)); - } -}; - -template -void BatchedOrmqr(const DeviceContext& dev_ctx, - bool left, - bool transpose, - int batch_size, - int m, - int n, - int k, - T* a, - int a_stride, - T* tau, - int tau_stride, - T* other, - int other_stride); - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc index 90ace1ba77..e939ec7be2 100644 --- a/paddle/fluid/operators/qr_op.cc +++ b/paddle/fluid/operators/qr_op.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/qr_op.h" - #include #include #include @@ -123,7 +121,3 @@ REGISTER_OPERATOR(qr, QrInferShapeFunctor); REGISTER_OPERATOR(qr_grad, ops::QrGradOp); - -REGISTER_OP_CPU_KERNEL(qr_grad, - ops::QrGradKernel, - ops::QrGradKernel); diff --git a/paddle/fluid/operators/qr_op.cu b/paddle/fluid/operators/qr_op.cu deleted file mode 100644 index 8ae18a5632..0000000000 --- a/paddle/fluid/operators/qr_op.cu +++ /dev/null @@ -1,392 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver - -#include - -#include -#include - -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/qr_op.h" -#include "paddle/fluid/platform/dynload/cusolver.h" - -// Reuse some helper functions from svd -#include "paddle/fluid/operators/svd_helper.h" - -namespace paddle { -namespace operators { - -template -class QrGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - bool compute_q; - bool reduced_mode; - auto& dev_ctx = context.template device_context(); - const Tensor& x = *context.Input("X"); - Tensor& q = *context.Output("Q"); - Tensor& r = *context.Output("R"); - const std::string mode = context.Attr("mode"); - std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); - - auto numel = x.numel(); - PADDLE_ENFORCE_GT( - numel, - 0, - platform::errors::PreconditionNotMet("The input of QR is empty.")); - auto x_dims = x.dims(); - int x_rank = x_dims.size(); - int m = x_dims[x_rank - 2]; - int n = x_dims[x_rank - 1]; - int min_mn = std::min(m, n); - int k = reduced_mode ? min_mn : m; - int batch_size = numel / (m * n); - int qr_stride = m * n; - int tau_stride = min_mn; - - if (compute_q) { - q.mutable_data>( - context.GetPlace(), - size_t(batch_size * m * k * sizeof(phi::dtype::Real))); - } - r.mutable_data>( - context.GetPlace(), - size_t(batch_size * k * n * sizeof(phi::dtype::Real))); - - auto dito = - math::DeviceIndependenceTensorOperations(context); - - // Note: allocate temporary tensors because of lacking in-place operatios. - // Prepare qr - Tensor qr; - qr.mutable_data>( - context.GetPlace(), - size_t(batch_size * m * n * sizeof(phi::dtype::Real))); - // BatchedGeqrf performs computation in-place and 'qr' must be a copy of - // input - paddle::framework::TensorCopy(x, context.GetPlace(), &qr); - - // Prepare tau - auto tau_dims_vec = phi::vectorize(x_dims); - tau_dims_vec.pop_back(); - tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; - Tensor tau = dito.Fill(tau_dims_vec, 0); - - // Transpose 'qr' to conform the column-major order - auto tmp_qr = dito.Transpose(qr); - framework::TensorCopy(tmp_qr, qr.place(), &qr); - auto qr_data = qr.mutable_data(context.GetPlace()); - auto tau_data = tau.mutable_data(context.GetPlace()); - - BatchedGeqrf( - dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride); - - if (reduced_mode) { - auto trans_qr = dito.Transpose(qr); - auto sliced_qr = dito.Slice(trans_qr, {-2}, {0}, {min_mn}); - auto tmp_r = dito.TrilTriu(sliced_qr, 0, false); - // Transpose 'tmp_r' to retore the original row-major order - framework::TensorCopy(tmp_r, r.place(), &r); - } else { - auto trans_qr = dito.Transpose(qr); - auto tmp_r = dito.TrilTriu(trans_qr, 0, false); - // Transpose 'tmp_r' to retore the original row-major order - framework::TensorCopy(tmp_r, r.place(), &r); - } - - if (compute_q) { - // Perform QRGQR for Q using the result from GEQRF - // Transpose 'q' to retore the original row-major order - if (reduced_mode) { - BatchedOrgqr(dev_ctx, - batch_size, - m, - min_mn, - min_mn, - qr_data, - m, - tau_data, - qr_stride, - tau_stride); - auto trans_q = dito.Transpose(qr); - auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {min_mn}); - framework::TensorCopy(sliced_q, q.place(), &q); - } else { - if (m > n) { - auto new_qr_dims_vec = phi::vectorize(x_dims); - new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; - Tensor new_qr = dito.Fill(new_qr_dims_vec, 0); - auto new_qr_data = 
new_qr.mutable_data(context.GetPlace()); - auto new_qr_stride = m * m; - for (int i = 0; i < batch_size; ++i) { - memory::Copy(dev_ctx.GetPlace(), - (new_qr_data + i * new_qr_stride), - dev_ctx.GetPlace(), - (qr_data + i * qr_stride), - qr_stride * sizeof(phi::dtype::Real), - dev_ctx.stream()); - } - BatchedOrgqr(dev_ctx, - batch_size, - m, - m, - min_mn, - new_qr_data, - m, - tau_data, - new_qr_stride, - tau_stride); - auto trans_q = dito.Transpose(new_qr); - framework::TensorCopy(trans_q, q.place(), &q); - } else { - BatchedOrgqr(dev_ctx, - batch_size, - m, - m, - min_mn, - qr_data, - m, - tau_data, - qr_stride, - tau_stride); - auto trans_q = dito.Transpose(qr); - auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {m}); - framework::TensorCopy(sliced_q, q.place(), &q); - } - } - } - } -}; - -template <> -void BatchedGeqrf(const phi::GPUContext& dev_ctx, - int batch_size, - int m, - int n, - float* a, - int lda, - float* tau, - int a_stride, - int tau_stride) { - int lwork = 0; - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgeqrf_bufferSize( - handle, m, n, a, lda, &lwork)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - float* a_working_ptr = &a[i * a_stride]; - float* tau_working_ptr = &tau[i * tau_stride]; - // compute geqrf - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnSgeqrf(handle, - m, - n, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); - // Do we need synchronized here? - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); - } -} - -template <> -void BatchedGeqrf(const phi::GPUContext& dev_ctx, - int batch_size, - int m, - int n, - double* a, - int lda, - double* tau, - int a_stride, - int tau_stride) { - int lwork = 0; - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgeqrf_bufferSize( - handle, m, n, a, lda, &lwork)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - double* a_working_ptr = &a[i * a_stride]; - double* tau_working_ptr = &tau[i * tau_stride]; - // compute geqrf - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDgeqrf(handle, - m, - n, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); - // Do we need synchronized here? - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver geqrf is not zero. 
[%d]", i, info_h)); - } -} - -template <> -void BatchedOrgqr(const phi::GPUContext& dev_ctx, - int batch_size, - int m, - int n, - int k, - float* a, - int lda, - float* tau, - int a_stride, - int tau_stride) { - int lwork = 0; - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSorgqr_bufferSize( - handle, m, n, k, a, lda, tau, &lwork)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - float* a_working_ptr = &a[i * a_stride]; - float* tau_working_ptr = &tau[i * tau_stride]; - // compute orggr - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnSorgqr(handle, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); - // Do we need synchronized here? - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); - } -} - -template <> -void BatchedOrgqr(const phi::GPUContext& dev_ctx, - int batch_size, - int m, - int n, - int k, - double* a, - int lda, - double* tau, - int a_stride, - int tau_stride) { - int lwork = 0; - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDorgqr_bufferSize( - handle, m, n, k, a, lda, tau, &lwork)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - auto info = memory::Alloc(dev_ctx, sizeof(int)); - int* info_d = reinterpret_cast(info->ptr()); - - for (int i = 0; i < batch_size; ++i) { - double* a_working_ptr = &a[i * a_stride]; - double* tau_working_ptr = &tau[i * tau_stride]; - // compute orggr - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDorgqr(handle, - m, - n, - k, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); - // Do we need synchronized here? - // check the error info - int info_h; - memory::Copy(platform::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); - } -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(qr, ops::QrGPUKernel, ops::QrGPUKernel); -REGISTER_OP_CUDA_KERNEL(qr_grad, - ops::QrGradKernel, - ops::QrGradKernel); - -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h deleted file mode 100644 index dd63e197e9..0000000000 --- a/paddle/fluid/operators/qr_op.h +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; -using DDim = framework::DDim; - -static inline std::tuple _parse_qr_mode(std::string mode) { - bool compute_q; - bool reduced; - if (mode == "reduced") { - compute_q = true; - reduced = true; - } else if (mode == "complete") { - compute_q = true; - reduced = false; - } else if (mode == "r") { - compute_q = false; - reduced = true; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "QR received unrecognized mode '%s'" - " but expected one of 'reduced' (default), 'r', or 'complete'", - mode)); - } - return std::make_tuple(compute_q, reduced); -} - -template -class QrGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor& Q = *ctx.Input("Q"); - const framework::Tensor& R = *ctx.Input("R"); - // Use a different name A instead of X - const framework::Tensor& A = *ctx.Input("X"); - const framework::Tensor& dQ = - *ctx.Input(framework::GradVarName("Q")); - const framework::Tensor& dR = - *ctx.Input(framework::GradVarName("R")); - // Use a different name dA instead of dX - framework::Tensor& dA = - *ctx.Output(framework::GradVarName("X")); - dA.mutable_data>(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - phi::funcs::SetConstant()(dev_ctx, &dA, T(0)); - - auto dito = math::DeviceIndependenceTensorOperations(ctx); - - std::string mode = ctx.Attr("mode"); - bool compute_q, reduced; - std::tie(compute_q, reduced) = _parse_qr_mode(mode); - if (!compute_q) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The derivative of qr is not implemented when mode='r'.")); - } - - auto a_dims = A.dims(); - int a_rank = a_dims.size(); - int m = a_dims[a_rank - 2]; - int n = a_dims[a_rank - 1]; - - if ((m > n) && (!reduced)) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The derivative of qr is not implemented when mode='complete' and " - "nrows > ncols.")); - } - - // m >= n case - auto m_gt_n_case = - [](const framework::ExecutionContext& ctx, - math::DeviceIndependenceTensorOperations& dito, - const Tensor& dQ, - const Tensor& dR, - const Tensor& A, - const Tensor& Q, - const Tensor& R) -> framework::Tensor { - // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable - // Programming Tensor Networks. - // https://arxiv.org/abs/1903.09650 Section 3. 
QR factorization - - // dR^H - framework::Tensor R_term; - if (ctx.HasInput(framework::GradVarName("R"))) { - R_term = dito.Matmul(R, dito.Transpose(dR)); - } else { - R_term = dito.Fill(phi::vectorize(R.dims()), 0); - } - - // dQ^H * Q - framework::Tensor Q_term; - if (ctx.HasInput(framework::GradVarName("Q"))) { - Q_term = dito.Matmul(dito.Transpose(dQ), Q); - } else { - Q_term = dito.Fill(phi::vectorize(R.dims()), 0); - } - - framework::Tensor M_tmp1 = dito.Sub(R_term, Q_term); - - // Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity - framework::Tensor M_tril_0 = dito.TrilTriu(M_tmp1, 0, true); - framework::Tensor M_tril_1 = dito.TrilTriu(M_tmp1, -1, true); - framework::Tensor M = dito.Add(M_tril_0, dito.Transpose(M_tril_1)); - - framework::Tensor rhs_term; - if (ctx.HasInput(framework::GradVarName("Q"))) { - rhs_term = dito.Add(dQ, dito.Matmul(Q, M)); - } else { - rhs_term = dito.Matmul(Q, M); - } - - // dA * R^H = rhs_term - auto dA = - dito.TriangularSolve(dito.Transpose(dito.Conj(dito.Transpose(R))), - dito.Transpose(rhs_term), - /*upper=*/true, - /*transpose=*/false, - /*unitriangular=*/false); - - return dito.Transpose(dA); - }; - - if (m >= n) { - auto dA_tmp = m_gt_n_case(ctx, dito, dQ, dR, A, Q, R); - framework::TensorCopy(dA_tmp, dA.place(), &dA); - } else { - // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V] - // Calculate dX and dY individually and concatenate them to get dA - dA.mutable_data>(ctx.GetPlace()); - - auto Y = dito.Slice(A, {-1}, {m}, {n}); - auto U = dito.Slice(R, {-1}, {0}, {m}); - framework::Tensor dY, dX, dV, dR_tmp, dQ_prime; - - if (ctx.HasInput(framework::GradVarName("R"))) { - dV = dito.Slice(dR, {-1}, {m}, {n}); - dR_tmp = dito.Slice(dR, {-1}, {0}, {m}); - // Y * dV^H - dQ_prime = dito.Matmul(Y, dito.Transpose(dV)); - } else { - dV = dito.Fill(phi::vectorize(Y.dims()), 0); - dQ_prime = dito.Fill(phi::vectorize(Q.dims()), 0); - } - - if (ctx.HasInput(framework::GradVarName("Q"))) { - dQ_prime = dito.Add(dQ_prime, dQ); - } - dX = m_gt_n_case(ctx, dito, dQ_prime, dR_tmp, A, Q, U); - dY = dito.Matmul(Q, dV); - // Concatenate dX and dY to get dA. 
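
The algebra implemented by m_gt_n_case above follows Liao et al. (2019), Section 3: M = R dR^H - dQ^H Q and dA = (dQ + Q copyltu(M)) R^{-H}, where copyltu mirrors the lower triangle onto the upper one. A self-contained NumPy sketch of that m >= n case (helper names are illustrative only):

```python
import numpy as np

def copyltu(m):
    # tril(M) + tril(M, -1)^H: keep the lower triangle, mirror it upward.
    return np.tril(m) + np.tril(m, -1).conj().T

def qr_backward_tall(q, r, dq, dr):
    # m >= n case of the QR gradient.
    m_mat = r @ dr.conj().T - dq.conj().T @ q
    rhs = dq + q @ copyltu(m_mat)
    # dA R^H = rhs  =>  dA = solve(conj(R), rhs^T)^T, avoiding an explicit inverse.
    return np.linalg.solve(r.conj(), rhs.T).T
```
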
- auto dA_tmp = dito.ConcatTwoTensors(dX, dY, -1); - framework::TensorCopy(dA_tmp, dA.place(), &dA); - } - } -}; - -template -void BatchedGeqrf(const DeviceContext& dev_ctx, - int batch_size, - int m, - int n, - T* a, - int lda, - T* tau, - int a_stride, - int tau_stride); - -template -void BatchedOrgqr(const DeviceContext& dev_ctx, - int batch_size, - int m, - int n, - int k, - T* a, - int lda, - T* tau, - int a_stride, - int tau_stride); - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index a0565f6b64..4822994fef 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1857,7 +1857,7 @@ func : QrInferMeta kernel : func : qr - # backward : qr_grad + backward : qr_grad - api : randint args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={}) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 4dbfb15661..909f248838 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1709,6 +1709,16 @@ kernel : func : put_along_axis_grad +- backward_api : qr_grad + forward : qr (Tensor x, str mode) -> Tensor(q), Tensor(r) + args : (Tensor x, Tensor q, Tensor r, Tensor q_grad, Tensor r_grad, str mode) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : qr_grad + - backward_api : real_grad forward : real (Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/kernels/cpu/qr_grad_kernel.cc b/paddle/phi/kernels/cpu/qr_grad_kernel.cc new file mode 100644 index 0000000000..3e9f8453d8 --- /dev/null +++ b/paddle/phi/kernels/cpu/qr_grad_kernel.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/qr_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(qr_grad, CPU, ALL_LAYOUT, phi::QrGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/qr_grad_kernel.cu b/paddle/phi/kernels/gpu/qr_grad_kernel.cu new file mode 100644 index 0000000000..9f59ee53c1 --- /dev/null +++ b/paddle/phi/kernels/gpu/qr_grad_kernel.cu @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/qr_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(qr_grad, GPU, ALL_LAYOUT, phi::QrGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/qr_kernel.cu b/paddle/phi/kernels/gpu/qr_kernel.cu new file mode 100644 index 0000000000..99752ac486 --- /dev/null +++ b/paddle/phi/kernels/gpu/qr_kernel.cu @@ -0,0 +1,410 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP // HIP not support cusolver + +#include +#include +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/parse_qr_mode.h" +#include "paddle/phi/kernels/impl/qr_kernel_impl.h" +#include "paddle/phi/kernels/qr_kernel.h" +#include "paddle/phi/kernels/slice_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/kernels/tril_triu_kernel.h" + +namespace phi { + +template +static DenseTensor Fill(const Context& ctx, + std::vector shape, + float fill_value) { + DenseTensor ret; + ret.Resize(make_ddim(shape)); + ctx.template Alloc(&ret); + funcs::SetConstant()(ctx, &ret, T(fill_value)); + return ret; +} + +template +void QrKernel(const Context& ctx, + const DenseTensor& x, + const std::string& mode, + DenseTensor* q, + DenseTensor* r) { + bool compute_q; + bool reduced_mode; + std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); + auto numel = x.numel(); + PADDLE_ENFORCE_GT( + numel, 0, errors::PreconditionNotMet("The input of QR is empty.")); + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int batch_size = numel / (m * n); + int qr_stride = m * n; + int tau_stride = min_mn; + + if (compute_q) { + ctx.template Alloc>( + q, batch_size * m * k * sizeof(phi::dtype::Real)); + } + ctx.template Alloc>( + r, batch_size * k * n * sizeof(phi::dtype::Real)); + + // Note: allocate temporary tensors because of lacking in-place operatios. 
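
The mode handling in the new QrKernel boils down to a small mapping from the mode string to (compute_q, reduced), with k = min(m, n) in reduced mode and k = m in complete mode. A Python sketch of that mapping and of the resulting Q/R shapes (the helper is written from the surrounding code, not taken from the Phi sources):

```python
import numpy as np

def parse_qr_mode(mode):
    # (compute_q, reduced) as derived from the mode string.
    return {'reduced': (True, True),
            'complete': (True, False),
            'r': (False, True)}[mode]

a = np.random.rand(7, 5)                  # m = 7, n = 5, min(m, n) = 5
q, r = np.linalg.qr(a, mode='reduced')
print(q.shape, r.shape)                   # (7, 5) (5, 5): Q is m x k, R is k x n, k = min(m, n)
q, r = np.linalg.qr(a, mode='complete')
print(q.shape, r.shape)                   # (7, 7) (7, 5): k = m in complete mode
r_only = np.linalg.qr(a, mode='r')        # compute_q is False: only R is produced
```
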
+ // Prepare qr + DenseTensor qr; + ctx.template Alloc>( + &qr, size_t(batch_size * m * n * sizeof(phi::dtype::Real))); + // BatchedGeqrf performs computation in-place and 'qr' must be a copy of + // input + phi::Copy(ctx, x, ctx.GetPlace(), false, &qr); + + // Prepare tau + auto tau_dims_vec = phi::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + DenseTensor tau = Fill(ctx, tau_dims_vec, 0); + + // Transpose 'qr' to conform the column-major order + auto tmp_qr = TransposeLast2Dim(ctx, qr); + phi::Copy(ctx, tmp_qr, qr.place(), false, &qr); + auto qr_data = ctx.template Alloc>(&qr); + auto tau_data = ctx.template Alloc>(&tau); + + BatchedGeqrf( + ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride); + + if (reduced_mode) { + auto trans_qr = TransposeLast2Dim(ctx, qr); + auto sliced_qr = SliceKernel( + ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}, {1}, {}); + auto tmp_r = TrilTriu(ctx, sliced_qr, 0, false); + // Transpose 'tmp_r' to retore the original row-major order + phi::Copy(ctx, tmp_r, r->place(), false, r); + } else { + auto trans_qr = TransposeLast2Dim(ctx, qr); + auto tmp_r = TrilTriu(ctx, trans_qr, 0, false); + // Transpose 'tmp_r' to retore the original row-major order + phi::Copy(ctx, tmp_r, r->place(), false, r); + } + + if (compute_q) { + // Perform QRGQR for Q using the result from GEQRF + // Transpose 'q' to retore the original row-major order + if (reduced_mode) { + BatchedOrgqr(ctx, + batch_size, + m, + min_mn, + min_mn, + qr_data, + m, + tau_data, + qr_stride, + tau_stride); + auto trans_q = TransposeLast2Dim(ctx, qr); + auto sliced_q = SliceKernel( + ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}, {1}, {}); + phi::Copy(ctx, sliced_q, q->place(), false, q); + } else { + if (m > n) { + auto new_qr_dims_vec = phi::vectorize(x_dims); + new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; + DenseTensor new_qr = Fill(ctx, new_qr_dims_vec, 0); + auto new_qr_data = ctx.template Alloc>(&new_qr); + auto new_qr_stride = m * m; + for (int i = 0; i < batch_size; ++i) { + paddle::memory::Copy(ctx.GetPlace(), + (new_qr_data + i * new_qr_stride), + ctx.GetPlace(), + (qr_data + i * qr_stride), + qr_stride * sizeof(phi::dtype::Real), + ctx.stream()); + } + BatchedOrgqr(ctx, + batch_size, + m, + m, + min_mn, + new_qr_data, + m, + tau_data, + new_qr_stride, + tau_stride); + auto trans_q = TransposeLast2Dim(ctx, new_qr); + phi::Copy(ctx, trans_q, q->place(), false, q); + } else { + BatchedOrgqr(ctx, + batch_size, + m, + m, + min_mn, + qr_data, + m, + tau_data, + qr_stride, + tau_stride); + auto trans_q = TransposeLast2Dim(ctx, qr); + auto sliced_q = SliceKernel( + ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}, {1}, {}); + phi::Copy(ctx, sliced_q, q->place(), false, q); + } + } + } +} + +template <> +void BatchedGeqrf(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + float* a, + int lda, + float* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * 
tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedGeqrf(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + double* a, + int lda, + double* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgeqrf(handle, + m, + n, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + float* a, + int lda, + float* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(make_ddim({lwork})); + float* workspace_ptr = dev_ctx.template Alloc(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. 
[%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr(const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + double* a, + int lda, + double* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(make_ddim({lwork})); + double* workspace_ptr = dev_ctx.template Alloc(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr(handle, + m, + n, + k, + a_working_ptr, + lda, + tau_working_ptr, + workspace_ptr, + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + paddle::memory::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(qr, // cuda_only + GPU, + ALL_LAYOUT, + phi::QrKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/impl/qr_grad_kernel_impl.h b/paddle/phi/kernels/impl/qr_grad_kernel_impl.h new file mode 100644 index 0000000000..5c04d9bb90 --- /dev/null +++ b/paddle/phi/kernels/impl/qr_grad_kernel_impl.h @@ -0,0 +1,182 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/concat_kernel.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/parse_qr_mode.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/slice_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" +#include "paddle/phi/kernels/triangular_solve_kernel.h" +#include "paddle/phi/kernels/tril_triu_kernel.h" + +namespace phi { + +template +static DenseTensor Fill(const Context& ctx, + std::vector shape, + float fill_value) { + DenseTensor ret; + ret.Resize(make_ddim(shape)); + ctx.template Alloc(&ret); + funcs::SetConstant()(ctx, &ret, T(fill_value)); + return ret; +} + +template +void QrGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& q, + const DenseTensor& r, + const DenseTensor& q_grad, + const DenseTensor& r_grad, + const std::string& mode, + DenseTensor* x_grad) { + // Using alias names + const DenseTensor& A = x; + const DenseTensor& Q = q; + const DenseTensor& R = r; + const DenseTensor& dQ = q_grad; + const DenseTensor& dR = r_grad; + DenseTensor& dA = *x_grad; + + ctx.template Alloc>(&dA); + phi::funcs::SetConstant()(ctx, &dA, T(0)); + + bool compute_q, reduced; + std::tie(compute_q, reduced) = phi::funcs::ParseQrMode(mode); + if (!compute_q) { + PADDLE_THROW(errors::InvalidArgument( + "The derivative of qr is not implemented when mode='%s'.", mode)); + } + + auto a_dims = A.dims(); + int a_rank = a_dims.size(); + int m = a_dims[a_rank - 2]; + int n = a_dims[a_rank - 1]; + + if ((m > n) && (!reduced)) { + PADDLE_THROW(errors::InvalidArgument( + "The derivative of qr is not implemented when mode='complete' and " + "%d > %d.", + m, + n)); + } + + // m >= n case + auto m_gt_n_case = [](const Context& ctx, + const DenseTensor& dQ, + const DenseTensor& dR, + const DenseTensor& A, + const DenseTensor& Q, + const DenseTensor& R) -> DenseTensor { + // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable + // Programming Tensor Networks. + // https://arxiv.org/abs/1903.09650 Section 3. 
QR factorization + + // dR^H + DenseTensor R_term; + if (dR.initialized()) { + R_term = + Matmul(ctx, R, TransposeLast2Dim(ctx, dR)); + } else { + R_term = Fill(ctx, phi::vectorize(R.dims()), 0); + } + + // dQ^H * Q + DenseTensor Q_term; + if (dQ.initialized()) { + Q_term = + Matmul(ctx, TransposeLast2Dim(ctx, dQ), Q); + } else { + Q_term = Fill(ctx, phi::vectorize(R.dims()), 0); + } + + DenseTensor M_tmp1 = Subtract(ctx, R_term, Q_term); + + // Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity + DenseTensor M_tril_0 = TrilTriu(ctx, M_tmp1, 0, true); + DenseTensor M_tril_1 = TrilTriu(ctx, M_tmp1, -1, true); + DenseTensor M = Add( + ctx, M_tril_0, TransposeLast2Dim(ctx, M_tril_1)); + + DenseTensor rhs_term; + if (dQ.initialized()) { + rhs_term = Add(ctx, dQ, Matmul(ctx, Q, M)); + } else { + rhs_term = Matmul(ctx, Q, M); + } + + // dA * R^H = rhs_term + auto dA = TriangularSolve( + ctx, + TransposeLast2Dim( + ctx, Conj(ctx, TransposeLast2Dim(ctx, R))), + TransposeLast2Dim(ctx, rhs_term), + /*upper=*/true, + /*transpose=*/false, + /*unitriangular=*/false); + + return TransposeLast2Dim(ctx, dA); + }; + + if (m >= n) { + auto dA_tmp = m_gt_n_case(ctx, dQ, dR, A, Q, R); + phi::Copy(ctx, dA_tmp, dA.place(), false, &dA); + } else { + // If m < n for input matrices A, we partition A = [X|Y] and R = [U|V] + // Calculate dX and dY individually and concatenate them to get dA + ctx.template Alloc>(&dA); + + auto Y = SliceKernel( + ctx, A, {A.dims().size() - 1}, {m}, {n}, {1}, {}); + auto U = SliceKernel( + ctx, R, {R.dims().size() - 1}, {0}, {m}, {1}, {}); + DenseTensor dY, dX, dV, dR_tmp, dQ_prime; + + if (dR.initialized()) { + dV = SliceKernel( + ctx, dR, {dR.dims().size() - 1}, {m}, {n}, {1}, {}); + dR_tmp = SliceKernel( + ctx, dR, {dR.dims().size() - 1}, {0}, {m}, {1}, {}); + // Y * dV^H + dQ_prime = + Matmul(ctx, Y, TransposeLast2Dim(ctx, dV)); + } else { + dV = Fill(ctx, phi::vectorize(Y.dims()), 0); + dQ_prime = Fill(ctx, phi::vectorize(Q.dims()), 0); + } + + if (dQ.initialized()) { + dQ_prime = Add(ctx, dQ_prime, dQ); + } + dX = m_gt_n_case(ctx, dQ_prime, dR_tmp, A, Q, U); + dY = Matmul(ctx, Q, dV); + // Concatenate dX and dY to get dA. 
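
The wide (m < n) branch above partitions A = [X | Y] and R = [U | V] at column m, reduces dX to the tall case with dQ' = dQ + Y dV^H, and sets dY = Q dV before concatenating. A self-contained NumPy sketch of that branch, with hypothetical helper names:

```python
import numpy as np

def copyltu(m):
    return np.tril(m) + np.tril(m, -1).conj().T

def qr_backward_tall(q, r, dq, dr):
    # m >= n case: M = R dR^H - dQ^H Q, dA = (dQ + Q copyltu(M)) R^{-H}.
    m_mat = r @ dr.conj().T - dq.conj().T @ q
    rhs = dq + q @ copyltu(m_mat)
    return np.linalg.solve(r.conj(), rhs.T).T

def qr_backward_wide(a, q, r, dq, dr):
    # m < n case: A = [X | Y], R = [U | V], split at column m.
    m = a.shape[0]
    y, u = a[:, m:], r[:, :m]
    du, dv = dr[:, :m], dr[:, m:]
    dq_prime = dq + y @ dv.conj().T
    dx = qr_backward_tall(q, u, dq_prime, du)   # tall case on the square block
    dy = q @ dv
    return np.concatenate([dx, dy], axis=1)
```
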
+ auto dA_tmp = Concat(ctx, {&dX, &dY}, -1); + phi::Copy(ctx, dA_tmp, dA.place(), false, &dA); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/qr_kernel_impl.h b/paddle/phi/kernels/impl/qr_kernel_impl.h index 1d64117922..924676cc4c 100644 --- a/paddle/phi/kernels/impl/qr_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_kernel_impl.h @@ -50,225 +50,6 @@ void BatchedOrgqr(const DeviceContext& dev_ctx, int a_stride, int tau_stride); -template <> -void BatchedGeqrf(const GPUContext& dev_ctx, - int batch_size, - int m, - int n, - float* a, - int lda, - float* tau, - int a_stride, - int tau_stride) { - int lwork = 0; - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cusolverDnSgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); - - DenseTensor* workspace = new DenseTensor(); - workspace->Resize(make_ddim({lwork})); - float* workspace_ptr = dev_ctx.template Alloc(workspace); - - DenseTensor* info = new DenseTensor(); - info->Resize(make_ddim({1})); - int* info_d = dev_ctx.template Alloc(info); - - for (int i = 0; i < batch_size; ++i) { - float* a_working_ptr = &a[i * a_stride]; - float* tau_working_ptr = &tau[i * tau_stride]; - // compute geqrf - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSgeqrf(handle, - m, - n, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); - // Do we need synchronized here? - // check the error info - int info_h; - paddle::memory::Copy(phi::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - phi::errors::PreconditionNotMet( - "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); - } -} - -template <> -void BatchedGeqrf(const GPUContext& dev_ctx, - int batch_size, - int m, - int n, - double* a, - int lda, - double* tau, - int a_stride, - int tau_stride) { - int lwork = 0; - - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - phi::dynload::cusolverDnDgeqrf_bufferSize(handle, m, n, a, lda, &lwork)); - - DenseTensor* workspace = new DenseTensor(); - workspace->Resize(make_ddim({lwork})); - double* workspace_ptr = dev_ctx.template Alloc(workspace); - - DenseTensor* info = new DenseTensor(); - info->Resize(make_ddim({1})); - int* info_d = dev_ctx.template Alloc(info); - - for (int i = 0; i < batch_size; ++i) { - double* a_working_ptr = &a[i * a_stride]; - double* tau_working_ptr = &tau[i * tau_stride]; - // compute geqrf - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDgeqrf(handle, - m, - n, - a_working_ptr, - lda, - tau_working_ptr, - workspace_ptr, - lwork, - info_d)); - // Do we need synchronized here? - // check the error info - int info_h; - paddle::memory::Copy(phi::CPUPlace(), - &info_h, - dev_ctx.GetPlace(), - info_d, - sizeof(int), - dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - info_h, - 0, - phi::errors::PreconditionNotMet( - "For batch [%d]: CUSolver geqrf is not zero. 
-  }
-}
-
-template <>
-void BatchedOrgqr<GPUContext, float>(const GPUContext& dev_ctx,
-                                     int batch_size,
-                                     int m,
-                                     int n,
-                                     int k,
-                                     float* a,
-                                     int lda,
-                                     float* tau,
-                                     int a_stride,
-                                     int tau_stride) {
-  int lwork = 0;
-
-  auto handle = dev_ctx.cusolver_dn_handle();
-  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr_bufferSize(
-      handle, m, n, k, a, lda, tau, &lwork));
-
-  DenseTensor* workspace = new DenseTensor();
-  workspace->Resize(make_ddim({lwork}));
-  float* workspace_ptr = dev_ctx.template Alloc<float>(workspace);
-
-  DenseTensor* info = new DenseTensor();
-  info->Resize(make_ddim({1}));
-  int* info_d = dev_ctx.template Alloc<int>(info);
-
-  for (int i = 0; i < batch_size; ++i) {
-    float* a_working_ptr = &a[i * a_stride];
-    float* tau_working_ptr = &tau[i * tau_stride];
-    // compute orggr
-    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnSorgqr(handle,
-                                                              m,
-                                                              n,
-                                                              k,
-                                                              a_working_ptr,
-                                                              lda,
-                                                              tau_working_ptr,
-                                                              workspace_ptr,
-                                                              lwork,
-                                                              info_d));
-    // Do we need synchronized here?
-    // check the error info
-    int info_h;
-    paddle::memory::Copy(phi::CPUPlace(),
-                         &info_h,
-                         dev_ctx.GetPlace(),
-                         info_d,
-                         sizeof(int),
-                         dev_ctx.stream());
-    PADDLE_ENFORCE_EQ(
-        info_h,
-        0,
-        phi::errors::PreconditionNotMet(
-            "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
-  }
-}
-
-template <>
-void BatchedOrgqr<GPUContext, double>(const GPUContext& dev_ctx,
-                                      int batch_size,
-                                      int m,
-                                      int n,
-                                      int k,
-                                      double* a,
-                                      int lda,
-                                      double* tau,
-                                      int a_stride,
-                                      int tau_stride) {
-  int lwork = 0;
-
-  auto handle = dev_ctx.cusolver_dn_handle();
-  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr_bufferSize(
-      handle, m, n, k, a, lda, tau, &lwork));
-
-  DenseTensor* workspace = new DenseTensor();
-  workspace->Resize(make_ddim({lwork}));
-  double* workspace_ptr = dev_ctx.template Alloc<double>(workspace);
-
-  DenseTensor* info = new DenseTensor();
-  info->Resize(make_ddim({1}));
-  int* info_d = dev_ctx.template Alloc<int>(info);
-
-  for (int i = 0; i < batch_size; ++i) {
-    double* a_working_ptr = &a[i * a_stride];
-    double* tau_working_ptr = &tau[i * tau_stride];
-    // compute orggr
-    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDorgqr(handle,
-                                                              m,
-                                                              n,
-                                                              k,
-                                                              a_working_ptr,
-                                                              lda,
-                                                              tau_working_ptr,
-                                                              workspace_ptr,
-                                                              lwork,
-                                                              info_d));
-    // Do we need synchronized here?
-    // check the error info
-    int info_h;
-    paddle::memory::Copy(phi::CPUPlace(),
-                         &info_h,
-                         dev_ctx.GetPlace(),
-                         info_d,
-                         sizeof(int),
-                         dev_ctx.stream());
-    PADDLE_ENFORCE_EQ(
-        info_h,
-        0,
-        phi::errors::PreconditionNotMet(
-            "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h));
-  }
-}
 #endif
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/qr_grad_kernel.h b/paddle/phi/kernels/qr_grad_kernel.h
new file mode 100644
index 0000000000..51d24dc533
--- /dev/null
+++ b/paddle/phi/kernels/qr_grad_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void QrGradKernel(const Context& ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& q,
+                  const DenseTensor& r,
+                  const DenseTensor& q_grad,
+                  const DenseTensor& r_grad,
+                  const std::string& mode,
+                  DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/triangular_solve_kernel.h b/paddle/phi/kernels/triangular_solve_kernel.h
index 833de3f843..3aa5b93124 100644
--- a/paddle/phi/kernels/triangular_solve_kernel.h
+++ b/paddle/phi/kernels/triangular_solve_kernel.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace phi {
 
@@ -27,4 +28,19 @@ void TriangularSolveKernel(const Context& dev_ctx,
                            bool unitriangular,
                            DenseTensor* out);
 
+template <typename T, typename Context>
+DenseTensor TriangularSolve(const Context& ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& y,
+                            bool upper,
+                            bool transpose,
+                            bool unitriangular) {
+  DenseTensor dense_out;
+  MetaTensor meta_out(&dense_out);
+  TriangularSolveInferMeta(x, y, upper, transpose, unitriangular, &meta_out);
+  TriangularSolveKernel<T, Context>(
+      ctx, x, y, upper, transpose, unitriangular, &dense_out);
+  return dense_out;
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/tril_triu_kernel.h b/paddle/phi/kernels/tril_triu_kernel.h
index 4daa84e25c..8d4c44c5b3 100644
--- a/paddle/phi/kernels/tril_triu_kernel.h
+++ b/paddle/phi/kernels/tril_triu_kernel.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace phi {
 
@@ -25,4 +26,16 @@ void TrilTriuKernel(const Context& ctx,
                     bool lower,
                     DenseTensor* out);
 
+template <typename T, typename Context>
+DenseTensor TrilTriu(const Context& ctx,
+                     const DenseTensor& x,
+                     int diagonal,
+                     bool lower) {
+  DenseTensor dense_out;
+  MetaTensor meta_out(&dense_out);
+  TrilTriuInferMeta(x, diagonal, lower, &meta_out);
+  TrilTriuKernel<T, Context>(ctx, x, diagonal, lower, &dense_out);
+  return dense_out;
+}
+
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/qr_sig.cc b/paddle/phi/ops/compat/qr_sig.cc
index dd424d590e..dbe1cd8643 100644
--- a/paddle/phi/ops/compat/qr_sig.cc
+++ b/paddle/phi/ops/compat/qr_sig.cc
@@ -20,6 +20,12 @@ KernelSignature QrOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("qr", {"X"}, {"mode"}, {"Q", "R"});
 }
 
+KernelSignature QrGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "qr_grad", {"X", "Q", "R", "Q@GRAD", "R@GRAD"}, {"mode"}, {"X@GRAD"});
+}
+
 }  // namespace phi
 
 PD_REGISTER_ARG_MAPPING_FN(qr, phi::QrOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(qr_grad, phi::QrGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py
index 338b08d1aa..290ce39151 100644
--- a/python/paddle/fluid/tests/unittests/test_qr_op.py
+++ b/python/paddle/fluid/tests/unittests/test_qr_op.py
@@ -28,6 +28,7 @@ class TestQrOp(OpTest):
 
     def setUp(self):
         paddle.enable_static()
+        self.python_api = paddle.linalg.qr
         np.random.seed(7)
         self.op_type = "qr"
         a, q, r = self.get_input_and_output()
@@ -72,10 +73,11 @@ class TestQrOp(OpTest):
         return a, q, r
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
         self.check_grad(['X'], ['Q', 'R'],
+                        check_eager=True,
                         numeric_grad_delta=1e-5,
                         max_relative_error=1e-6)
 
@@ -175,7 +177,7 @@ class TestQrAPI(unittest.TestCase):
         tensor_shapes = [
             (3, 5),
             (5, 5),
-            (5, 3),  # 2-dim Tensors
+            (5, 3),  # 2-dim Tensors
             (2, 3, 5),
             (3, 5, 5),
             (4, 5, 3),  # 3-dim Tensors
@@ -253,7 +255,7 @@ class TestQrAPI(unittest.TestCase):
         tensor_shapes = [
             (3, 5),
             (5, 5),
-            (5, 3),  # 2-dim Tensors
+            (5, 3),  # 2-dim Tensors
             (2, 3, 5),
             (3, 5, 5),
             (4, 5, 3),  # 3-dim Tensors
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index ef557e783a..4931f3ffbb 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1998,7 +1998,13 @@ def qr(x, mode="reduced", name=None):
             # one can verify : X = Q * R ;
 
     """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        q, r = _C_ops.final_state_qr(x, mode)
+        if mode == "r":
+            return r
+        else:
+            return q, r
+    if _in_legacy_dygraph():
         q, r = _C_ops.qr(x, 'mode', mode)
         if mode == "r":
             return r
-- 
GitLab
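
A minimal usage sketch of the paddle.linalg.qr API exercised by the new dynamic-graph branch above. This is illustrative only: it assumes a Paddle build that includes this patch, and the input values and shapes are arbitrary.

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
x.stop_gradient = False

# mode="reduced" (the default) returns both factors.
q, r = paddle.linalg.qr(x)

# mode="r" returns only the R factor, matching the early return in the new branch.
r_only = paddle.linalg.qr(x, mode="r")

# The qr_grad kernel registered by this patch makes QR differentiable.
(q.sum() + r.sum()).backward()
print(x.grad.shape)  # [3, 2]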