From b9d4285b7eea3abe1eaafada9560401c6df92362 Mon Sep 17 00:00:00 2001
From: crystal <62974595+Zjq9409@users.noreply.github.com>
Date: Mon, 14 Mar 2022 16:11:05 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90phi=E3=80=91migrate=20matrix=5Frank=20?=
 =?UTF-8?q?to=20phi=20(#40074)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* migrate matrix_rank to phi
* migrate eigh and matrix_rank to phi
* fix matrix_rank
* optimize code
* move matrix_rank to phi
* add max functor
* migrate matrix_rank to phi
* optimize code
---
 paddle/fluid/operators/matrix_rank_op.cc      | 139 +-----
 paddle/fluid/operators/matrix_rank_op.cu      | 316 -------------
 paddle/phi/kernels/cpu/matrix_rank_kernel.cc  |  43 ++
 .../phi/kernels/cpu/matrix_rank_tol_kernel.cc | 178 +++++++
 paddle/phi/kernels/gpu/matrix_rank_kernel.cu  |  52 +++
 .../phi/kernels/gpu/matrix_rank_tol_kernel.cu | 438 ++++++++++++++++++
 .../kernels/impl/matrix_rank_kernel_impl.h}   |  28 +-
 paddle/phi/kernels/matrix_rank_kernel.h       |  29 ++
 paddle/phi/kernels/matrix_rank_tol_kernel.h   |  29 ++
 paddle/phi/ops/compat/matrix_rank_sig.cc      |  38 ++
 10 files changed, 828 insertions(+), 462 deletions(-)
 delete mode 100644 paddle/fluid/operators/matrix_rank_op.cu
 create mode 100644 paddle/phi/kernels/cpu/matrix_rank_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/matrix_rank_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu
 rename paddle/{fluid/operators/matrix_rank_op.h => phi/kernels/impl/matrix_rank_kernel_impl.h} (72%)
 create mode 100644 paddle/phi/kernels/matrix_rank_kernel.h
 create mode 100644 paddle/phi/kernels/matrix_rank_tol_kernel.h
 create mode 100644 paddle/phi/ops/compat/matrix_rank_sig.cc

diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc
index 1f04875c220..e7d08b65973 100644
--- a/paddle/fluid/operators/matrix_rank_op.cc
+++ b/paddle/fluid/operators/matrix_rank_op.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/matrix_rank_op.h"
 #include <memory>
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
@@ -70,9 +69,9 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
       std::vector<int> x_batch_dims_array(max_dim);
       std::vector<int> tol_dims_array(max_dim);
       std::vector<int> out_dims_array(max_dim);
-      GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(),
-                             tol_dims_array.data(), out_dims_array.data(),
-                             max_dim, axis);
+      phi::funcs::GetBroadcastDimsArrays(
+          dim_x_batch, dim_tol, x_batch_dims_array.data(),
+          tol_dims_array.data(), out_dims_array.data(), max_dim, axis);
       ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
     }
   } else {
@@ -115,141 +114,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-template <typename T>
-void BatchEigenvalues(const T* x_data, T* eigenvalues_data, int batches,
-                      int rows, int cols, int k) {
-  // Eigen::Matrix API needs a non-const pointer.
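// (Context for the cast that follows, an illustrative aside rather than part
// of the patch: Eigen::Map requires a mutable Scalar*, e.g.
// Eigen::Map<Eigen::MatrixXd> m(ptr, rows, rows) will not accept a
// const double* unless the mapped type itself is declared const. The solver
// copies its input, so x_data is never written through this pointer.)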
- T* input = const_cast(x_data); - int stride = rows * cols; - for (int i = 0; i < batches; i++) { - auto m = Eigen::Map< - Eigen::Matrix>( - input + i * stride, rows, rows); - Eigen::SelfAdjointEigenSolver< - Eigen::Matrix> - eigen_solver(m); - auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs(); - for (int j = 0; j < k; j++) { - *(eigenvalues_data + i * k + j) = eigenvalues[j]; - } - } -} - -template -void BatchSVD(const T* x_data, T* eigenvalues_data, int batches, int rows, - int cols, int k) { - // Eigen::Matrix API need non-const pointer. - T* input = const_cast(x_data); - int stride = rows * cols; - Eigen::BDCSVD< - Eigen::Matrix> - svd; - for (int i = 0; i < batches; i++) { - auto m = Eigen::Map< - Eigen::Matrix>( - input + i * stride, rows, cols); - svd.compute(m); - auto res_s = svd.singularValues(); - for (int j = 0; j < k; j++) { - eigenvalues_data[i * k + j] = res_s[j]; - } - } -} - -template -class MatrixRankCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* x = context.Input("X"); - auto* x_data = x->data(); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - bool hermitian = context.Attr("hermitian"); - - auto dim_x = x->dims(); - auto dim_out = out->dims(); - int rows = dim_x[dim_x.size() - 2]; - int cols = dim_x[dim_x.size() - 1]; - int k = std::min(rows, cols); - auto numel = x->numel(); - int batches = numel / (rows * cols); - - bool use_default_tol = context.Attr("use_default_tol"); - const Tensor* atol_tensor = nullptr; - Tensor temp_tensor; - T rtol_T = 0; - if (use_default_tol) { - framework::TensorFromVector(std::vector{0}, - context.device_context(), &temp_tensor); - atol_tensor = &temp_tensor; - rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); - } else if (context.HasInput("TolTensor")) { - atol_tensor = context.Input("TolTensor"); - } else { - framework::TensorFromVector(std::vector{context.Attr("tol")}, - context.device_context(), &temp_tensor); - atol_tensor = &temp_tensor; - } - - Tensor eigenvalue_tensor; - auto* eigenvalue_data = eigenvalue_tensor.mutable_data( - detail::GetEigenvalueDim(dim_x, k), context.GetPlace()); - if (hermitian) { - BatchEigenvalues(x_data, eigenvalue_data, batches, rows, cols, k); - } else { - BatchSVD(x_data, eigenvalue_data, batches, rows, cols, k); - } - - auto dito_T = - math::DeviceIndependenceTensorOperations( - context); - std::vector max_eigenvalue_shape = - phi::vectorize(detail::RemoveLastDim(eigenvalue_tensor.dims())); - Tensor max_eigenvalue_tensor = - dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape); - - Tensor temp_rtol_tensor; - framework::TensorFromVector(std::vector{rtol_T}, &temp_rtol_tensor); - Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor); - Tensor tol_tensor; - tol_tensor.mutable_data(dim_out, context.GetPlace()); - ElementwiseComputeEx, platform::CPUDeviceContext, - T, T>(context, atol_tensor, &rtol_tensor, -1, - GreaterElementFunctor(), &tol_tensor); - - tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); - - Tensor compare_result; - compare_result.mutable_data(detail::NewAxisDim(dim_out, k), - context.GetPlace()); - - int axis = -1; - if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { - ElementwiseComputeEx, - platform::CPUDeviceContext, T, int>( - context, &eigenvalue_tensor, &tol_tensor, axis, - phi::funcs::GreaterThanFunctor(), &compare_result); - } else { - ElementwiseComputeEx, - platform::CPUDeviceContext, 
T, int>( - context, &eigenvalue_tensor, &tol_tensor, axis, - phi::funcs::LessThanFunctor(), &compare_result); - } - auto dito_int = - math::DeviceIndependenceTensorOperations(context); - std::vector result_shape = phi::vectorize(dim_out); - Tensor result = dito_int.ReduceSum(compare_result, result_shape); - out->ShareDataWith(result); - } -}; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker); - -REGISTER_OP_CPU_KERNEL(matrix_rank, ops::MatrixRankCPUKernel, - ops::MatrixRankCPUKernel); diff --git a/paddle/fluid/operators/matrix_rank_op.cu b/paddle/fluid/operators/matrix_rank_op.cu deleted file mode 100644 index dccd716022d..00000000000 --- a/paddle/fluid/operators/matrix_rank_op.cu +++ /dev/null @@ -1,316 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver -#include -#include -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/matrix_rank_op.h" -#include "paddle/fluid/operators/svd_helper.h" -#include "paddle/fluid/platform/dynload/cusolver.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/compare_functors.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace detail { -DDim GetUDDim(const DDim& x_dim, int k) { - auto x_vec = phi::vectorize(x_dim); - x_vec[x_vec.size() - 1] = k; - return phi::make_ddim(x_vec); -} - -DDim GetVHDDim(const DDim& x_dim, int k) { - auto x_vec = phi::vectorize(x_dim); - x_vec[x_vec.size() - 2] = k; - return phi::make_ddim(x_vec); -} -} // namespace detail - -template -class MatrixRankGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = - context.template device_context(); - - const Tensor* x = context.Input("X"); - auto* x_data = x->data(); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - bool hermitian = context.Attr("hermitian"); - - auto dim_x = x->dims(); - auto dim_out = out->dims(); - int rows = dim_x[dim_x.size() - 2]; - int cols = dim_x[dim_x.size() - 1]; - int k = std::min(rows, cols); - auto numel = x->numel(); - int batches = numel / (rows * cols); - - bool use_default_tol = context.Attr("use_default_tol"); - const Tensor* atol_tensor = nullptr; - Tensor temp_tensor; - T rtol_T = 0; - if (use_default_tol) { - framework::TensorFromVector(std::vector{0}, - context.device_context(), &temp_tensor); - atol_tensor = &temp_tensor; - rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); - } else if (context.HasInput("TolTensor")) { - atol_tensor = context.Input("TolTensor"); - } else { - framework::TensorFromVector(std::vector{context.Attr("tol")}, - context.device_context(), 
&temp_tensor); - atol_tensor = &temp_tensor; - } - - // Must Copy X once, because the gesvdj will destory the content when exit. - Tensor x_tmp; - paddle::framework::TensorCopy(*x, context.GetPlace(), &x_tmp); - auto info = memory::Alloc(dev_ctx, sizeof(int) * batches); - int* info_ptr = reinterpret_cast(info->ptr()); - - Tensor eigenvalue_tensor; - auto* eigenvalue_data = eigenvalue_tensor.mutable_data( - detail::GetEigenvalueDim(dim_x, k), context.GetPlace()); - if (hermitian) { - SyevjBatched(dev_ctx, batches, rows, x_tmp.data(), eigenvalue_data, - info_ptr); - platform::ForRange for_range( - dev_ctx, eigenvalue_tensor.numel()); - phi::funcs::AbsFunctor functor(eigenvalue_data, eigenvalue_data, - eigenvalue_tensor.numel()); - for_range(functor); - } else { - Tensor U, VH; - auto* u_data = - U.mutable_data(detail::GetUDDim(dim_x, k), context.GetPlace()); - auto* vh_data = - VH.mutable_data(detail::GetVHDDim(dim_x, k), context.GetPlace()); - GesvdjBatched(dev_ctx, batches, cols, rows, k, x_tmp.data(), vh_data, - u_data, eigenvalue_data, info_ptr, 1); - } - - auto dito_T = - math::DeviceIndependenceTensorOperations(context); - std::vector max_eigenvalue_shape = - phi::vectorize(detail::RemoveLastDim(eigenvalue_tensor.dims())); - Tensor max_eigenvalue_tensor = - dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape); - Tensor temp_rtol_tensor; - framework::TensorFromVector(std::vector{rtol_T}, - context.device_context(), &temp_rtol_tensor); - Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor); - Tensor tol_tensor; - tol_tensor.mutable_data(dim_out, context.GetPlace()); - ElementwiseComputeEx, platform::CUDADeviceContext, - T, T>(context, atol_tensor, &rtol_tensor, -1, - GreaterElementFunctor(), &tol_tensor); - - tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); - - Tensor compare_result; - compare_result.mutable_data(detail::NewAxisDim(dim_out, k), - context.GetPlace()); - int axis = -1; - ElementwiseComputeEx, - platform::CUDADeviceContext, T, int64_t>( - context, &eigenvalue_tensor, &tol_tensor, axis, - phi::funcs::GreaterThanFunctor(), &compare_result); - auto dito_int = - math::DeviceIndependenceTensorOperations(context); - std::vector result_shape = phi::vectorize(dim_out); - Tensor result = dito_int.ReduceSum(compare_result, result_shape); - out->ShareDataWith(result); - } - - void GesvdjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize, - int m, int n, int k, T* A, T* U, T* V, T* S, int* info, - int thin_UV = 1) const; - - void SyevjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize, - int n, T* A, T* W, int* info) const; -}; - -template <> -void MatrixRankGPUKernel::GesvdjBatched( - const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, - int k, float* A, float* U, float* V, float* S, int* info, - int thin_UV) const { - // do not compute singular vectors - const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; - gesvdjInfo_t gesvdj_params = NULL; - int lda = m; - int ldu = m; - int ldt = n; - int lwork = 0; - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgesvdj_bufferSize( - handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, - gesvdj_params)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - int stride_A = lda * n; - int stride_U = ldu * (thin_UV ? 
k : m); - int stride_V = ldt * (thin_UV ? k : n); - for (int i = 0; i < batchSize; i++) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgesvdj( - handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, - U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, - info, gesvdj_params)); - int error_info; - memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info, - sizeof(int), dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); -} - -template <> -void MatrixRankGPUKernel::GesvdjBatched( - const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, - int k, double* A, double* U, double* V, double* S, int* info, - int thin_UV) const { - // do not compute singular vectors - const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; - gesvdjInfo_t gesvdj_params = NULL; - int lda = m; - int ldu = m; - int ldt = n; - int lwork = 0; - auto handle = dev_ctx.cusolver_dn_handle(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgesvdj_bufferSize( - handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, - gesvdj_params)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - int stride_A = lda * n; - int stride_U = ldu * (thin_UV ? k : m); - int stride_V = ldt * (thin_UV ? k : n); - for (int i = 0; i < batchSize; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgesvdj( - handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, - U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, - info, gesvdj_params)); - // check the error info - int error_info; - memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info, - sizeof(int), dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); -} - -template <> -void MatrixRankGPUKernel::SyevjBatched( - const platform::CUDADeviceContext& dev_ctx, int batchSize, int n, float* A, - float* W, int* info) const { - auto handle = dev_ctx.cusolver_dn_handle(); - // Compute eigenvalues only - const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; - // matrix is saved as column-major in cusolver. 
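// (Equivalently, as a small illustration: a row-major buffer handed to the
// column-major cusolver API is read as its transpose, e.g. the buffer
// {1, 2, 3, 4} meaning [[1, 2], [3, 4]] row-major is seen as [[1, 3], [2, 4]];
// for a symmetric input, the upper triangle of the transpose holds the same
// entries as the lower triangle of the original, which is why the fill mode
// below looks swapped relative to numpy/torch.)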
- // numpy and torch use lower triangle to compute eigenvalues, so here use - // upper triangle - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; - int lda = n; - int stride_A = lda * n; - int lwork = 0; - syevjInfo_t params = NULL; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnCreateSyevjInfo(¶ms)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj_bufferSize( - handle, jobz, uplo, n, A, lda, W, &lwork, params)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); - float* workspace_ptr = reinterpret_cast(workspace->ptr()); - for (int i = 0; i < batchSize; i++) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj( - handle, jobz, uplo, n, A + stride_A * i, lda, W + n * i, workspace_ptr, - lwork, info, params)); - - int error_info; - memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info, - sizeof(int), dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver eigenvalues is not zero. [%d]", i, - error_info)); - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDestroySyevjInfo(params)); -} - -template <> -void MatrixRankGPUKernel::SyevjBatched( - const platform::CUDADeviceContext& dev_ctx, int batchSize, int n, double* A, - double* W, int* info) const { - auto handle = dev_ctx.cusolver_dn_handle(); - // Compute eigenvalues only - const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; - // upper triangle of A is stored - cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; - int lda = n; - int stride_A = lda * n; - int lwork = 0; - syevjInfo_t params = NULL; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnCreateSyevjInfo(¶ms)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDsyevj_bufferSize( - handle, jobz, uplo, n, A, lda, W, &lwork, params)); - auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); - double* workspace_ptr = reinterpret_cast(workspace->ptr()); - - for (int i = 0; i < batchSize; i++) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDsyevj( - handle, jobz, uplo, n, A + stride_A * i, lda, W + n * i, workspace_ptr, - lwork, info, params)); - int error_info; - memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info, - sizeof(int), dev_ctx.stream()); - PADDLE_ENFORCE_EQ( - error_info, 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: CUSolver eigenvalues is not zero. [%d]", i, - error_info)); - } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cusolverDnDestroySyevjInfo(params)); -} - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(matrix_rank, ops::MatrixRankGPUKernel, - ops::MatrixRankGPUKernel); -#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/cpu/matrix_rank_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_kernel.cc new file mode 100644 index 00000000000..5e13abe8aed --- /dev/null +++ b/paddle/phi/kernels/cpu/matrix_rank_kernel.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/matrix_rank_kernel.h"
+#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatrixRankKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      float tol,
+                      bool use_default_tol,
+                      bool hermitian,
+                      DenseTensor* out) {
+  DenseTensor atol_tensor;
+  if (use_default_tol) {
+    atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0));
+  } else {
+    atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(tol));
+  }
+  MatrixRankTolKernel<T, Context>(
+      dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    matrix_rank, CPU, ALL_LAYOUT, phi::MatrixRankKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc
new file mode 100644
index 00000000000..210750da1e0
--- /dev/null
+++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc
@@ -0,0 +1,178 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
+
+#include <Eigen/Dense>
+#include <Eigen/SVD>
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cpu/reduce.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/compare_functors.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/reduce_functor.h"
+#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
+#include "paddle/phi/kernels/math_kernel.h"
+#include "paddle/phi/kernels/reduce_max_kernel.h"
+
+namespace phi {
+
+template <typename T>
+void BatchEigenvalues(const T* x_data,
+                      T* eigenvalues_data,
+                      int batches,
+                      int rows,
+                      int cols,
+                      int k) {
+  // Eigen::Matrix API needs a non-const pointer.
+  T* input = const_cast<T*>(x_data);
+  int stride = rows * cols;
+  for (int i = 0; i < batches; i++) {
+    auto m = Eigen::Map<
+        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
+        input + i * stride, rows, rows);
+    Eigen::SelfAdjointEigenSolver<
+        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+        eigen_solver(m);
+    auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs();
+    for (int j = 0; j < k; j++) {
+      *(eigenvalues_data + i * k + j) = eigenvalues[j];
+    }
+  }
+}
+
+template <typename T>
+void BatchSVD(const T* x_data,
+              T* eigenvalues_data,
+              int batches,
+              int rows,
+              int cols,
+              int k) {
+  // Eigen::Matrix API needs a non-const pointer.
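// (A note on the decomposition below: Eigen::BDCSVD constructed without the
// ComputeThinU/ComputeThinV flags computes singular values only, which is all
// the rank computation needs, and singularValues() returns them in decreasing
// order, so res_s[0] is the largest value for that batch entry.)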
+ T* input = const_cast(x_data); + int stride = rows * cols; + Eigen::BDCSVD< + Eigen::Matrix> + svd; + for (int i = 0; i < batches; i++) { + auto m = Eigen::Map< + Eigen::Matrix>( + input + i * stride, rows, cols); + svd.compute(m); + auto res_s = svd.singularValues(); + for (int j = 0; j < k; j++) { + eigenvalues_data[i * k + j] = res_s[j]; + } + } +} + +template +void MatrixRankTolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + DenseTensor* out) { + auto* x_data = x.data(); + dev_ctx.template Alloc(out); + auto dim_x = x.dims(); + auto dim_out = out->dims(); + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + int k = std::min(rows, cols); + auto numel = x.numel(); + int batches = numel / (rows * cols); + + T rtol_T = 0; + + if (use_default_tol) { + rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); + } + + DenseTensor eigenvalue_tensor; + eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k)); + auto* eigenvalue_data = dev_ctx.template Alloc(&eigenvalue_tensor); + + if (hermitian) { + BatchEigenvalues(x_data, eigenvalue_data, batches, rows, cols, k); + } else { + BatchSVD(x_data, eigenvalue_data, batches, rows, cols, k); + } + + DenseTensor max_eigenvalue_tensor; + max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims())); + dev_ctx.template Alloc(&max_eigenvalue_tensor); + phi::MaxKernel(dev_ctx, + eigenvalue_tensor, + std::vector{-1}, + false, + &max_eigenvalue_tensor); + + DenseTensor temp_rtol_tensor; + temp_rtol_tensor = + phi::Full(dev_ctx, {1}, static_cast(rtol_T)); + + DenseTensor rtol_tensor = + phi::Multiply(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor); + + DenseTensor tol_tensor; + tol_tensor.Resize(dim_out); + dev_ctx.template Alloc(&tol_tensor); + funcs::ElementwiseCompute, T, T>( + dev_ctx, + atol_tensor, + rtol_tensor, + -1, + GreaterElementFunctor(), + &tol_tensor); + + tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); + + DenseTensor compare_result; + compare_result.Resize(detail::NewAxisDim(dim_out, k)); + dev_ctx.template Alloc(&compare_result); + int axis = -1; + if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { + funcs::ElementwiseCompute, T, int>( + dev_ctx, + eigenvalue_tensor, + tol_tensor, + axis, + funcs::GreaterThanFunctor(), + &compare_result); + } else { + funcs::ElementwiseCompute, T, int>( + dev_ctx, + eigenvalue_tensor, + tol_tensor, + axis, + funcs::LessThanFunctor(), + &compare_result); + } + + phi::SumKernel(dev_ctx, + compare_result, + std::vector{-1}, + compare_result.dtype(), + false, + out); +} +} // namespace phi + +PD_REGISTER_KERNEL( + matrix_rank_tol, CPU, ALL_LAYOUT, phi::MatrixRankTolKernel, float, double) { +} diff --git a/paddle/phi/kernels/gpu/matrix_rank_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_kernel.cu new file mode 100644 index 00000000000..9b889a9b4c0 --- /dev/null +++ b/paddle/phi/kernels/gpu/matrix_rank_kernel.cu @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PADDLE_WITH_HIP
+// HIP does not support cusolver
+
+#include "paddle/phi/kernels/matrix_rank_kernel.h"
+#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MatrixRankKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      float tol,
+                      bool use_default_tol,
+                      bool hermitian,
+                      DenseTensor* out) {
+  DenseTensor atol_tensor;
+  if (use_default_tol) {
+    atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0));
+  } else {
+    atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(tol));
+  }
+  MatrixRankTolKernel<T, Context>(
+      dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(matrix_rank,  // cuda_only
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MatrixRankKernel,
+                   float,
+                   double) {}
+
+#endif  // not PADDLE_WITH_HIP
diff --git a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu
new file mode 100644
index 00000000000..ccd9f714956
--- /dev/null
+++ b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu
@@ -0,0 +1,438 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
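// A minimal host-side sketch (illustrative only, not part of this patch;
// RankFromSingularValues is a hypothetical name) of the tolerance rule that
// the CPU and GPU matrix_rank_tol kernels implement: a singular value counts
// toward the rank iff sigma > max(atol, rtol * sigma_max), where rtol
// defaults to eps * max(rows, cols) when use_default_tol is set, else 0.
//
//   #include <algorithm>
//   #include <cstdint>
//   #include <limits>
//   #include <vector>
//
//   int64_t RankFromSingularValues(const std::vector<double>& sigma,
//                                  double atol, int rows, int cols,
//                                  bool use_default_tol) {
//     double rtol = use_default_tol
//                       ? std::numeric_limits<double>::epsilon() *
//                             std::max(rows, cols)
//                       : 0.0;
//     double sigma_max = *std::max_element(sigma.begin(), sigma.end());
//     double tol = std::max(atol, rtol * sigma_max);
//     return std::count_if(sigma.begin(), sigma.end(),
//                          [tol](double s) { return s > tol; });
//   }
//
// For example, sigma = {5.0, 1e-8, 0.0} for a 3x3 double matrix with the
// default tolerance gives tol = eps * 3 * 5.0 (about 3.3e-15), so rank = 2.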
+ +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include "paddle/phi/kernels/matrix_rank_tol_kernel.h" + +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/abs_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/compare_functors.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h" +#include "paddle/phi/kernels/math_kernel.h" +#include "paddle/phi/kernels/reduce_max_kernel.h" + +namespace phi { + +template +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + T* A, + T* U, + T* V, + T* S, + int* info, + int thin_UV = 1); + +template +void SyevjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int n, + T* A, + T* W, + int* info); + +template <> +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + float* A, + float* U, + float* V, + float* S, + int* info, + int thin_UV) { + // do not compute singular vectors + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnSgesvdj_bufferSize(handle, + jobz, + thin_UV, + m, + n, + A, + lda, + S, + U, + ldu, + V, + ldt, + &lwork, + gesvdj_params)); + auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? k : n); + for (int i = 0; i < batchSize; i++) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgesvdj(handle, + jobz, + thin_UV, + m, + n, + A + stride_A * i, + lda, + S + k * i, + U + stride_U * i, + ldu, + V + stride_V * i, + ldt, + workspace_ptr, + lwork, + info, + gesvdj_params)); + int error_info; + paddle::memory::Copy(phi::CPUPlace(), + &error_info, + dev_ctx.GetPlace(), + info, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. 
[%d]", i, error_info)); + } + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void GesvdjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int m, + int n, + int k, + double* A, + double* U, + double* V, + double* S, + int* info, + int thin_UV) { + // do not compute singular vectors + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnDgesvdj_bufferSize(handle, + jobz, + thin_UV, + m, + n, + A, + lda, + S, + U, + ldu, + V, + ldt, + &lwork, + gesvdj_params)); + auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgesvdj(handle, + jobz, + thin_UV, + m, + n, + A + stride_A * i, + lda, + S + k * i, + U + stride_U * i, + ldu, + V + stride_V * i, + ldt, + workspace_ptr, + lwork, + info, + gesvdj_params)); + // check the error info + int error_info; + paddle::memory::Copy(phi::CPUPlace(), + &error_info, + dev_ctx.GetPlace(), + info, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void SyevjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int n, + float* A, + float* W, + int* info) { + auto handle = dev_ctx.cusolver_dn_handle(); + // Compute eigenvalues only + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + // matrix is saved as column-major in cusolver. + // numpy and torch use lower triangle to compute eigenvalues, so here use + // upper triangle + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + int lda = n; + int stride_A = lda * n; + int lwork = 0; + syevjInfo_t params = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCreateSyevjInfo(¶ms)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize( + handle, jobz, uplo, n, A, lda, W, &lwork, params)); + auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + for (int i = 0; i < batchSize; i++) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj(handle, + jobz, + uplo, + n, + A + stride_A * i, + lda, + W + n * i, + workspace_ptr, + lwork, + info, + params)); + + int error_info; + paddle::memory::Copy(phi::CPUPlace(), + &error_info, + dev_ctx.GetPlace(), + info, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver eigenvalues is not zero. 
[%d]", + i, + error_info)); + } + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDestroySyevjInfo(params)); +} + +template <> +void SyevjBatched(const phi::GPUContext& dev_ctx, + int batchSize, + int n, + double* A, + double* W, + int* info) { + auto handle = dev_ctx.cusolver_dn_handle(); + // Compute eigenvalues only + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + // upper triangle of A is stored + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + int lda = n; + int stride_A = lda * n; + int lwork = 0; + syevjInfo_t params = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCreateSyevjInfo(¶ms)); + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevj_bufferSize( + handle, jobz, uplo, n, A, lda, W, &lwork, params)); + auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + + for (int i = 0; i < batchSize; i++) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevj(handle, + jobz, + uplo, + n, + A + stride_A * i, + lda, + W + n * i, + workspace_ptr, + lwork, + info, + params)); + int error_info; + paddle::memory::Copy(phi::CPUPlace(), + &error_info, + dev_ctx.GetPlace(), + info, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: CUSolver eigenvalues is not zero. [%d]", + i, + error_info)); + } + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDestroySyevjInfo(params)); +} + +template +void MatrixRankTolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + DenseTensor* out) { + auto* x_data = x.data(); + dev_ctx.template Alloc(out); + + auto dim_x = x.dims(); + auto dim_out = out->dims(); + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + int k = std::min(rows, cols); + auto numel = x.numel(); + int batches = numel / (rows * cols); + + T rtol_T = 0; + if (use_default_tol) { + rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); + } + + // Must Copy X once, because the gesvdj will destory the content when exit. 
+ DenseTensor x_tmp; + paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), &x_tmp); + auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batches); + int* info_ptr = reinterpret_cast(info->ptr()); + + DenseTensor eigenvalue_tensor; + eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k)); + auto* eigenvalue_data = dev_ctx.template Alloc(&eigenvalue_tensor); + + if (hermitian) { + SyevjBatched( + dev_ctx, batches, rows, x_tmp.data(), eigenvalue_data, info_ptr); + + phi::AbsKernel(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor); + + } else { + DenseTensor U, VH; + U.Resize(detail::GetUDDim(dim_x, k)); + VH.Resize(detail::GetVHDDim(dim_x, k)); + auto* u_data = dev_ctx.template Alloc(&U); + auto* vh_data = dev_ctx.template Alloc(&VH); + GesvdjBatched(dev_ctx, + batches, + cols, + rows, + k, + x_tmp.data(), + vh_data, + u_data, + eigenvalue_data, + info_ptr, + 1); + } + + DenseTensor max_eigenvalue_tensor; + dev_ctx.template Alloc(&max_eigenvalue_tensor); + max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims())); + + phi::MaxKernel(dev_ctx, + eigenvalue_tensor, + std::vector{-1}, + false, + &max_eigenvalue_tensor); + + DenseTensor temp_rtol_tensor; + temp_rtol_tensor = + phi::Full(dev_ctx, {1}, static_cast(rtol_T)); + + DenseTensor rtol_tensor = + phi::Multiply(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor); + DenseTensor tol_tensor; + tol_tensor.Resize(dim_out); + dev_ctx.template Alloc(&tol_tensor); + + funcs::ElementwiseCompute, T, T>( + dev_ctx, + atol_tensor, + rtol_tensor, + -1, + GreaterElementFunctor(), + &tol_tensor); + + tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); + + DenseTensor compare_result; + compare_result.Resize(detail::NewAxisDim(dim_out, k)); + dev_ctx.template Alloc(&compare_result); + + int axis = -1; + funcs::ElementwiseCompute, T, int64_t>( + dev_ctx, + eigenvalue_tensor, + tol_tensor, + axis, + funcs::GreaterThanFunctor(), + &compare_result); + + phi::SumKernel(dev_ctx, + compare_result, + std::vector{-1}, + compare_result.dtype(), + false, + out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(matrix_rank_tol, // cuda_only + GPU, + ALL_LAYOUT, + phi::MatrixRankTolKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/matrix_rank_op.h b/paddle/phi/kernels/impl/matrix_rank_kernel_impl.h similarity index 72% rename from paddle/fluid/operators/matrix_rank_op.h rename to paddle/phi/kernels/impl/matrix_rank_kernel_impl.h index 93545fd3103..b0dd76a17ee 100644 --- a/paddle/fluid/operators/matrix_rank_op.h +++ b/paddle/phi/kernels/impl/matrix_rank_kernel_impl.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,14 +13,11 @@ // limitations under the License. 
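// (Shape bookkeeping for the helpers defined just below, on an input x of
// shape [..., m, n] with k = min(m, n); shapes are illustrative:
//   GetEigenvalueDim(x.dims(), k) -> [..., k]     eigen/singular values
//   GetUDDim(x.dims(), k)         -> [..., m, k]  thin U from gesvdj
//   GetVHDDim(x.dims(), k)        -> [..., k, n]  thin V^H from gesvdj
//   RemoveLastDim([..., k])       -> [...]        shape of the per-matrix max
//   NewAxisDim([...], k)          -> [..., k]     re-appended axis for the
//                                                 eigenvalue/tol comparison)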
#pragma once -#include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/phi/core/ddim.h" -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; -using DDim = framework::DDim; +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/matrix_rank_kernel.h" + +namespace phi { namespace detail { static DDim GetEigenvalueDim(const DDim& dim, int k) { @@ -44,6 +41,18 @@ static DDim RemoveLastDim(const DDim& dim) { vec.erase(vec.end() - 1, vec.end()); return phi::make_ddim(vec); } + +static DDim GetUDDim(const DDim& x_dim, int k) { + auto x_vec = phi::vectorize(x_dim); + x_vec[x_vec.size() - 1] = k; + return phi::make_ddim(x_vec); +} + +static DDim GetVHDDim(const DDim& x_dim, int k) { + auto x_vec = phi::vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + return phi::make_ddim(x_vec); +} } // namespace detail template @@ -57,5 +66,4 @@ struct GreaterElementFunctor { } }; -} // namespace operators -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/kernels/matrix_rank_kernel.h b/paddle/phi/kernels/matrix_rank_kernel.h new file mode 100644 index 00000000000..6edea2723e5 --- /dev/null +++ b/paddle/phi/kernels/matrix_rank_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MatrixRankKernel(const Context& dev_ctx, + const DenseTensor& x, + float tol, + bool use_default_tol, + bool hermitian, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/matrix_rank_tol_kernel.h b/paddle/phi/kernels/matrix_rank_tol_kernel.h new file mode 100644 index 00000000000..351358dfa04 --- /dev/null +++ b/paddle/phi/kernels/matrix_rank_tol_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MatrixRankTolKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& atol_tensor, + bool use_default_tol, + bool hermitian, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/matrix_rank_sig.cc b/paddle/phi/ops/compat/matrix_rank_sig.cc new file mode 100644 index 00000000000..40dc29579b4 --- /dev/null +++ b/paddle/phi/ops/compat/matrix_rank_sig.cc @@ -0,0 +1,38 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// We have to return every specific KernelSignature for infrt now.
+KernelSignature MatrixRankOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("TolTensor")) {
+    return KernelSignature("matrix_rank_tol",
+                           {"X", "TolTensor"},
+                           {"use_default_tol", "hermitian"},
+                           {"Out"});
+  } else {
+    return KernelSignature("matrix_rank",
+                           {"X"},
+                           {"tol", "use_default_tol", "hermitian"},
+                           {"Out"});
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(matrix_rank, phi::MatrixRankOpArgumentMapping);
--
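(Illustration of the mapping above; the shapes are assumed for the example.
With a TolTensor input the op lowers to matrix_rank_tol, whose atol argument
arrives as a broadcastable tensor: for x of shape [2, 3, 3] the batch shape is
[2], so a TolTensor of shape [2] produces an Out of shape [2]. Without
TolTensor, the scalar "tol" attribute is packed into a one-element tensor by
MatrixRankKernel, and the same matrix_rank_tol path runs.)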