diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 7730550e061f140d23a6fd2138834d19eae572a1..cdc39161bde2544609c9395ba04cb6ec4c63567f 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -183,6 +183,7 @@ function(op_library TARGET) list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc") list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc") list(REMOVE_ITEM hip_srcs "cholesky_op.cu") + list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5481e1434486454a8db1263153aa118e0ce3dad4 --- /dev/null +++ b/paddle/fluid/operators/matrix_rank_op.cc @@ -0,0 +1,256 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/matrix_rank_op.h" +#include +#include +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/svd_helper.h" + +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { +using DDim = framework::DDim; + +namespace detail { +static DDim GetInputBatchDim(const DDim& dim_x) { + auto x_vec = framework::vectorize(dim_x); + if (x_vec.size() == 2) { + return framework::make_ddim({1}); + } + x_vec.erase(x_vec.end() - 2, x_vec.end()); + return framework::make_ddim(x_vec); +} +} // namespace detail + +class MatrixRankeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "MatrixRank"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixRank"); + auto dim_x = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(dim_x.size(), 2, + platform::errors::InvalidArgument( + "The dims of input must be greater than 2")); + + bool hermitian = ctx->Attrs().Get("hermitian"); + if (hermitian) { + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + PADDLE_ENFORCE_EQ(rows, cols, + platform::errors::InvalidArgument( + "if hermitian == true, matrix should be n*n")); + } + + DDim dim_x_batch = detail::GetInputBatchDim(dim_x); + if (ctx->Attrs().Get( + "use_default_tol")) { // user not input TolTensor and tol + ctx->SetOutputDim("Out", dim_x_batch); + } else if (ctx->HasInput("TolTensor")) { + auto dim_tol = ctx->GetInputDim("TolTensor"); + if (dim_x_batch == dim_tol) { + ctx->SetOutputDim("Out", dim_x_batch); + } else { + int max_dim = std::max(dim_x_batch.size(), dim_tol.size()); + int axis = std::abs(dim_x_batch.size() - dim_tol.size()); + std::vector x_batch_dims_array(max_dim); + std::vector tol_dims_array(max_dim); + std::vector 
out_dims_array(max_dim); + GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(), + tol_dims_array.data(), out_dims_array.data(), + max_dim, axis); + for (auto& it : out_dims_array) { + VLOG(3) << "out dims: " << it; + } + ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array)); + } + } else { + ctx->SetOutputDim("Out", dim_x_batch); + } + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library{framework::LibraryType::kPlain}; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + } +}; + +class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of matrix_rank op."); + AddInput("TolTensor", "(optional) Tol tensor, shape is same as X batch.") + .AsDispensable(); + AddOutput("Out", "(Tensor), The output tensor of matrix_rank op."); + AddAttr("tol", "(float, optional). tol").SetDefault(0.0f); + AddAttr("use_default_tol", + "represent whether user input TolTensor/tol, if input " + "TolTensor/tol use_default_tol=true, otherwise " + "use_default_tol=false") + .SetDefault(true); + AddAttr("hermitian", "(bool, optional). whether is hermitian matrix") + .SetDefault(false); + AddComment(R"DOC(MatrixRank Operator. + This operator is used to perform MatrixRank operation for batched matrics. + $$out = matrix_rank(X, tol, hermitian)$$ + )DOC"); + } +}; + +template +void BatchEigenvalues(const T* x_data, T* eigenvalues_data, int batches, + int rows, int cols, int k) { + // Eigen::Matrix API need non-const pointer. + T* input = const_cast(x_data); + int stride = rows * cols; + for (int i = 0; i < batches; i++) { + auto m = Eigen::Map< + Eigen::Matrix>( + input + i * stride, rows, rows); + Eigen::SelfAdjointEigenSolver< + Eigen::Matrix> + eigen_solver(m); + auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs(); + for (int j = 0; j < k; j++) { + *(eigenvalues_data + i * k + j) = eigenvalues[j]; + } + } +} + +template +void BatchSVD(const T* x_data, T* eigenvalues_data, int batches, int rows, + int cols, int k) { + // Eigen::Matrix API need non-const pointer. 
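+  // For each matrix in the batch, map the raw buffer into an Eigen matrix view,
+  // run BDCSVD (singular values only, no singular vectors), and write the
+  // k = min(rows, cols) values of batch i to eigenvalues_data + i * k.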
+ T* input = const_cast(x_data); + int stride = rows * cols; + Eigen::BDCSVD< + Eigen::Matrix> + svd; + for (int i = 0; i < batches; i++) { + auto m = Eigen::Map< + Eigen::Matrix>( + input + i * stride, rows, cols); + svd.compute(m); + auto res_s = svd.singularValues(); + for (int j = 0; j < k; j++) { + eigenvalues_data[i * k + j] = res_s[j]; + } + } +} + +template +class MatrixRankCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + auto* x_data = x->data(); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + bool hermitian = context.Attr("hermitian"); + + auto dim_x = x->dims(); + auto dim_out = out->dims(); + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + int k = std::min(rows, cols); + auto numel = x->numel(); + int batches = numel / (rows * cols); + + bool use_default_tol = context.Attr("use_default_tol"); + const Tensor* atol_tensor = nullptr; + Tensor temp_tensor; + T rtol_T = 0; + if (use_default_tol) { + framework::TensorFromVector(std::vector{0}, + context.device_context(), &temp_tensor); + atol_tensor = &temp_tensor; + rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); + } else if (context.HasInput("TolTensor")) { + atol_tensor = context.Input("TolTensor"); + } else { + framework::TensorFromVector(std::vector{context.Attr("tol")}, + context.device_context(), &temp_tensor); + atol_tensor = &temp_tensor; + } + + Tensor eigenvalue_tensor; + auto* eigenvalue_data = eigenvalue_tensor.mutable_data( + detail::GetEigenvalueDim(dim_x, k), context.GetPlace()); + if (hermitian) { + BatchEigenvalues(x_data, eigenvalue_data, batches, rows, cols, k); + } else { + BatchSVD(x_data, eigenvalue_data, batches, rows, cols, k); + } + + auto dito_T = + math::DeviceIndependenceTensorOperations( + context); + std::vector max_eigenvalue_shape = framework::vectorize( + detail::RemoveLastDim(eigenvalue_tensor.dims())); + Tensor max_eigenvalue_tensor = + dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape); + + Tensor temp_rtol_tensor; + framework::TensorFromVector(std::vector{rtol_T}, &temp_rtol_tensor); + Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor); + Tensor tol_tensor; + tol_tensor.mutable_data(dim_out, context.GetPlace()); + ElementwiseComputeEx, platform::CPUDeviceContext, + T, T>(context, atol_tensor, &rtol_tensor, -1, + GreaterElementFunctor(), &tol_tensor); + + tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); + + Tensor compare_result; + compare_result.mutable_data(detail::NewAxisDim(dim_out, k), + context.GetPlace()); + + int axis = -1; + if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { + ElementwiseComputeEx, platform::CPUDeviceContext, T, + int>(context, &eigenvalue_tensor, &tol_tensor, axis, + GreaterThanFunctor(), &compare_result); + } else { + ElementwiseComputeEx, platform::CPUDeviceContext, T, + int>(context, &eigenvalue_tensor, &tol_tensor, axis, + LessThanFunctor(), &compare_result); + } + auto dito_int = + math::DeviceIndependenceTensorOperations(context); + std::vector result_shape = framework::vectorize(dim_out); + Tensor result = dito_int.ReduceSum(compare_result, result_shape); + out->ShareDataWith(result); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker); + +REGISTER_OP_CPU_KERNEL(matrix_rank, 
ops::MatrixRankCPUKernel, + ops::MatrixRankCPUKernel); diff --git a/paddle/fluid/operators/matrix_rank_op.cu b/paddle/fluid/operators/matrix_rank_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..c6f85abac97d6fd5336d1a72cce4aab7baa8608d --- /dev/null +++ b/paddle/fluid/operators/matrix_rank_op.cu @@ -0,0 +1,326 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/matrix_rank_op.h" +#include "paddle/fluid/operators/svd_helper.h" +#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +namespace detail { +DDim GetUDDim(const DDim& x_dim, int k) { + auto x_vec = framework::vectorize(x_dim); + x_vec[x_vec.size() - 1] = k; + return framework::make_ddim(x_vec); +} + +DDim GetVHDDim(const DDim& x_dim, int k) { + auto x_vec = framework::vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + return framework::make_ddim(x_vec); +} +} // namespace detail + +template +class MatrixRankGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = + context.template device_context(); + + const Tensor* x = context.Input("X"); + auto* x_data = x->data(); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + bool hermitian = context.Attr("hermitian"); + + auto dim_x = x->dims(); + auto dim_out = out->dims(); + int rows = dim_x[dim_x.size() - 2]; + int cols = dim_x[dim_x.size() - 1]; + int k = std::min(rows, cols); + auto numel = x->numel(); + int batches = numel / (rows * cols); + + bool use_default_tol = context.Attr("use_default_tol"); + const Tensor* atol_tensor = nullptr; + Tensor temp_tensor; + T rtol_T = 0; + if (use_default_tol) { + framework::TensorFromVector(std::vector{0}, + context.device_context(), &temp_tensor); + atol_tensor = &temp_tensor; + rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); + } else if (context.HasInput("TolTensor")) { + atol_tensor = context.Input("TolTensor"); + } else { + framework::TensorFromVector(std::vector{context.Attr("tol")}, + context.device_context(), &temp_tensor); + atol_tensor = &temp_tensor; + } + + // Must Copy X once, because the gesvdj will destory the content when exit. 
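+    // (The same holds for syevj in the hermitian branch. 'info' is a device-side
+    // cusolver status buffer; its value is copied back to the host and checked
+    // after every per-batch solver call.)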
+ Tensor x_tmp; + TensorCopy(*x, context.GetPlace(), &x_tmp); + auto info = memory::Alloc(dev_ctx, sizeof(int) * batches); + int* info_ptr = reinterpret_cast(info->ptr()); + + Tensor eigenvalue_tensor; + auto* eigenvalue_data = eigenvalue_tensor.mutable_data( + detail::GetEigenvalueDim(dim_x, k), context.GetPlace()); + if (hermitian) { + SyevjBatched(dev_ctx, batches, rows, x_tmp.data(), eigenvalue_data, + info_ptr); + platform::ForRange for_range( + dev_ctx, eigenvalue_tensor.numel()); + math::AbsFunctor functor(eigenvalue_data, eigenvalue_data, + eigenvalue_tensor.numel()); + for_range(functor); + } else { + Tensor U, VH; + auto* u_data = + U.mutable_data(detail::GetUDDim(dim_x, k), context.GetPlace()); + auto* vh_data = + VH.mutable_data(detail::GetVHDDim(dim_x, k), context.GetPlace()); + GesvdjBatched(dev_ctx, batches, cols, rows, k, x_tmp.data(), vh_data, + u_data, eigenvalue_data, info_ptr, 1); + } + + auto dito_T = + math::DeviceIndependenceTensorOperations(context); + std::vector max_eigenvalue_shape = framework::vectorize( + detail::RemoveLastDim(eigenvalue_tensor.dims())); + Tensor max_eigenvalue_tensor = + dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape); + Tensor temp_rtol_tensor; + framework::TensorFromVector(std::vector{rtol_T}, + context.device_context(), &temp_rtol_tensor); + Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor); + Tensor tol_tensor; + tol_tensor.mutable_data(dim_out, context.GetPlace()); + ElementwiseComputeEx, platform::CUDADeviceContext, + T, T>(context, atol_tensor, &rtol_tensor, -1, + GreaterElementFunctor(), &tol_tensor); + + tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); + + Tensor compare_result; + compare_result.mutable_data(detail::NewAxisDim(dim_out, k), + context.GetPlace()); + int axis = -1; + if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { + ElementwiseComputeEx, platform::CUDADeviceContext, + T, int64_t>(context, &eigenvalue_tensor, &tol_tensor, + axis, GreaterThanFunctor(), + &compare_result); + } else { + ElementwiseComputeEx, platform::CUDADeviceContext, T, + int64_t>(context, &eigenvalue_tensor, &tol_tensor, + axis, LessThanFunctor(), + &compare_result); + } + auto dito_int = + math::DeviceIndependenceTensorOperations(context); + std::vector result_shape = framework::vectorize(dim_out); + Tensor result = dito_int.ReduceSum(compare_result, result_shape); + out->ShareDataWith(result); + } + + void GesvdjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize, + int m, int n, int k, T* A, T* U, T* V, T* S, int* info, + int thin_UV = 1) const; + + void SyevjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize, + int n, T* A, T* W, int* info) const; +}; + +template <> +void MatrixRankGPUKernel::GesvdjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, + int k, float* A, float* U, float* V, float* S, int* info, + int thin_UV) const { + // do not compute singular vectors + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgesvdj_bufferSize( + handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, + gesvdj_params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = 
reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? k : n); + for (int i = 0; i < batchSize; i++) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgesvdj( + handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, + U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, + info, gesvdj_params)); + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void MatrixRankGPUKernel::GesvdjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, + int k, double* A, double* U, double* V, double* S, int* info, + int thin_UV) const { + // do not compute singular vectors + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgesvdj_bufferSize( + handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, + gesvdj_params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgesvdj( + handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, + U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, + info, gesvdj_params)); + // check the error info + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void MatrixRankGPUKernel::SyevjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int n, float* A, + float* W, int* info) const { + auto handle = dev_ctx.cusolver_dn_handle(); + // Compute eigenvalues only + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + // matrix is saved as column-major in cusolver. 
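+  // A row-major tensor therefore appears transposed from cusolver's point of view.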
+ // numpy and torch use lower triangle to compute eigenvalues, so here use + // upper triangle + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + int lda = n; + int stride_A = lda * n; + int lwork = 0; + syevjInfo_t params = NULL; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateSyevjInfo(¶ms)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSsyevj_bufferSize( + handle, jobz, uplo, n, A, lda, W, &lwork, params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + for (int i = 0; i < batchSize; i++) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSsyevj( + handle, jobz, uplo, n, A + stride_A * i, lda, W + n * i, workspace_ptr, + lwork, info, params)); + + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver eigenvalues is not zero. [%d]", i, + error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroySyevjInfo(params)); +} + +template <> +void MatrixRankGPUKernel::SyevjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int n, double* A, + double* W, int* info) const { + auto handle = dev_ctx.cusolver_dn_handle(); + // Compute eigenvalues only + const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR; + // upper triangle of A is stored + cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER; + int lda = n; + int stride_A = lda * n; + int lwork = 0; + syevjInfo_t params = NULL; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateSyevjInfo(¶ms)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDsyevj_bufferSize( + handle, jobz, uplo, n, A, lda, W, &lwork, params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + + for (int i = 0; i < batchSize; i++) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDsyevj( + handle, jobz, uplo, n, A + stride_A * i, lda, W + n * i, workspace_ptr, + lwork, info, params)); + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver eigenvalues is not zero. [%d]", i, + error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroySyevjInfo(params)); +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(matrix_rank, ops::MatrixRankGPUKernel, + ops::MatrixRankGPUKernel); +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/matrix_rank_op.h b/paddle/fluid/operators/matrix_rank_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7fa74368332d0ab582919238ed6b0db6a32252ec --- /dev/null +++ b/paddle/fluid/operators/matrix_rank_op.h @@ -0,0 +1,71 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +namespace detail { +static DDim GetEigenvalueDim(const DDim& dim, int k) { + auto vec = framework::vectorize(dim); + vec.erase(vec.end() - 2, vec.end()); + vec.push_back(k); + return framework::make_ddim(vec); +} + +static DDim NewAxisDim(const DDim& dim, int k) { + auto vec = framework::vectorize(dim); + vec.push_back(k); + return framework::make_ddim(vec); +} + +static DDim RemoveLastDim(const DDim& dim) { + auto vec = framework::vectorize(dim); + if (vec.size() <= 1) { + return framework::make_ddim({1}); + } + vec.erase(vec.end() - 1, vec.end()); + return framework::make_ddim(vec); +} +} // namespace detail + +template +struct GreaterThanFunctor { + HOSTDEVICE int operator()(const T& a, const T& b) const { return a > b; } +}; + +template +struct LessThanFunctor { + HOSTDEVICE int operator()(const T& a, const T& b) const { return a < b; } +}; + +template +struct GreaterElementFunctor { + HOSTDEVICE T operator()(const T& a, const T& b) const { + if (a > b) { + return a; + } else { + return b; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index aa6a3697288398e0aa981fe53c630a8984bf40c7..055c0bc57c51d7c091601a4f98eb58677148a18f 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -313,6 +313,22 @@ struct DeviceIndependenceTensorOperations { return CreateOpRunAndReturnTensor("slice", inputs, attrs, out_shape); } + framework::Tensor ReduceSum(const framework::Tensor& x, + std::vector out_dim) { + framework::AttributeMap attrs; + attrs["dim"] = std::vector{-1}; + NameInTensorMap inputs({{"X", {&x}}}); + return CreateOpRunAndReturnTensor("reduce_sum", inputs, attrs, out_dim); + } + + framework::Tensor ReduceMax(const framework::Tensor& x, + std::vector out_dim) { + framework::AttributeMap attrs; + attrs["dim"] = std::vector{-1}; + NameInTensorMap inputs({{"X", {&x}}}); + return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim); + } + private: const framework::ExecutionContext& context; BlasT GetBlas() { diff --git a/paddle/fluid/platform/dynload/cusolver.cc b/paddle/fluid/platform/dynload/cusolver.cc index 5d841c55073b1fb2bf01c3545c44cffcb2b92fc5..d4163e9a7431b086cee4e99dd4c07e42d7d8c0b7 100644 --- a/paddle/fluid/platform/dynload/cusolver.cc +++ b/paddle/fluid/platform/dynload/cusolver.cc @@ -28,6 +28,11 @@ CUSOLVER_ROUTINE_EACH(DEFINE_WRAP); #ifdef CUSOLVER_ROUTINE_EACH_R1 CUSOLVER_ROUTINE_EACH_R1(DEFINE_WRAP); #endif + +#ifdef CUSOLVER_ROUTINE_EACH_R2 +CUSOLVER_ROUTINE_EACH_R2(DEFINE_WRAP); +#endif + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index 42583b60680b951f535cff34eeeb10acc54c75ce..36ba5dd0948159b91fd04362aff52dad0a61416a 100644 --- 
a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -66,6 +66,18 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) #endif +#if CUDA_VERSION >= 9020 +#define CUSOLVER_ROUTINE_EACH_R2(__macro) \ + __macro(cusolverDnCreateSyevjInfo); \ + __macro(cusolverDnSsyevj_bufferSize); \ + __macro(cusolverDnDsyevj_bufferSize); \ + __macro(cusolverDnSsyevj); \ + __macro(cusolverDnDsyevj); \ + __macro(cusolverDnDestroySyevjInfo); + +CUSOLVER_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) +#endif + #undef DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP } // namespace dynload } // namespace platform diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index fb6ec0b32d9916c938f8236ed101fe3afb83000d..3da4a4b8e82abac2824d963acc93c6fd8f698e36 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -67,7 +67,7 @@ std::map> op_ins_map = { {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}}, {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, {"run_program", {"X", "Params"}}, -}; + {"matrix_rank", {"X", "TolTensor"}}}; // NOTE(zhiqiu): Like op_ins_map. // Commonly, the outputs in auto-generated OP function are determined by the diff --git a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d0bd7e6017e207c0bc3bdda53178ea78e80ccc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py @@ -0,0 +1,185 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
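+
+# The tests below check the matrix_rank op against numpy.linalg.matrix_rank for
+# the default tolerance, the float 'tol' attribute and the 'TolTensor' input,
+# in both dynamic and static graph mode.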
+ +from __future__ import print_function +import unittest + +import numpy as np + +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import compiler, Program, program_guard + +paddle.enable_static() +SEED = 2049 +np.random.seed(SEED) + + +class TestMatrixRankOP(OpTest): + def setUp(self): + self.op_type = "matrix_rank" + self.init_data() + self.inputs = {'X': self.x} + self.attrs = {'hermitian': self.hermitian} + if self.tolTensor is not None: + self.inputs["TolTensor"] = self.tolTensor + if self.tol is not None: + self.attrs["tol"] = self.tol + self.attrs["use_default_tol"] = self.use_default_tol + self.outputs = {'Out': self.out} + + def test_check_output(self): + self.check_output() + + def init_data(self): + self.x = np.eye(3, dtype=np.float32) + self.tolTensor = None + self.tol = 0.1 + self.use_default_tol = False + self.hermitian = True + self.out = np.linalg.matrix_rank(self.x, self.tol, self.hermitian) + + +class TestMatrixRankOP1(TestMatrixRankOP): + def init_data(self): + self.x = np.eye(3, k=1, dtype=np.float64) + self.tolTensor = None + self.tol = None + self.use_default_tol = True + self.hermitian = False + self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) + + +class TestMatrixRankOP2(TestMatrixRankOP): + def init_data(self): + self.x = np.random.rand(3, 4, 5, 6).astype(np.float32) + self.tolTensor = np.random.random([3, 4]).astype(self.x.dtype) + self.tol = None + self.use_default_tol = False + self.hermitian = False + self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) + + +class TestMatrixRankOP3(TestMatrixRankOP): + def init_data(self): + self.x = np.eye(200, dtype=np.float64) + self.tolTensor = None + self.tol = None + self.use_default_tol = True + self.hermitian = True + self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) + + +class TestMatrixRankOP4(TestMatrixRankOP): + def init_data(self): + self.x = np.random.rand(1, 10).astype(np.float32) + self.tolTensor = None + self.tol = None + self.use_default_tol = True + self.hermitian = False + self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) + + +class TestMatrixRankOP5(TestMatrixRankOP): + def init_data(self): + self.x = np.random.rand(5, 1).astype(np.float64) + self.tolTensor = np.random.random([1, 4]).astype(self.x.dtype) + self.tol = None + self.use_default_tol = False + self.hermitian = False + self.out = np.linalg.matrix_rank(self.x, self.tolTensor, self.hermitian) + + +class TestMatrixRankAPI(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + + x_np = np.eye(10, dtype=np.float32) + x_pd = paddle.to_tensor(x_np) + rank_np = np.linalg.matrix_rank(x_np, hermitian=True) + rank_pd = paddle.linalg.matrix_rank(x_pd, hermitian=True) + self.assertTrue(np.allclose(rank_np, rank_pd)) + + x_np = np.random.rand(3, 4, 7, 8).astype(np.float64) + tol_np = np.random.random([3, 4]).astype(np.float32) + x_pd = paddle.to_tensor(x_np) + tol_pd = paddle.to_tensor(tol_np) + rank_np = np.linalg.matrix_rank(x_np, tol_np, hermitian=False) + rank_pd = paddle.linalg.matrix_rank(x_pd, tol_pd, hermitian=False) + self.assertTrue(np.allclose(rank_np, rank_pd)) + + x_np = np.random.rand(3, 4, 7, 8).astype(np.float64) + x_pd = paddle.to_tensor(x_np) + tol = 0.1 + rank_np = np.linalg.matrix_rank(x_np, tol, hermitian=False) + rank_pd = 
paddle.linalg.matrix_rank(x_pd, tol, hermitian=False) + self.assertTrue(np.allclose(rank_np, rank_pd)) + + def test_static(self): + paddle.enable_static() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + x_np = np.random.rand(3, 4, 7, 7).astype(np.float64) + tol_np = np.random.random([3, 4]).astype(np.float32) + x_pd = paddle.fluid.data( + name="X", shape=[3, 4, 7, 7], dtype='float64') + tol_pd = paddle.fluid.data( + name="TolTensor", shape=[3, 4], dtype='float32') + rank_np = np.linalg.matrix_rank(x_np, tol_np, hermitian=False) + rank_pd = paddle.linalg.matrix_rank( + x_pd, tol_pd, hermitian=False) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"X": x_np, + "TolTensor": tol_np}, + fetch_list=[rank_pd]) + self.assertTrue(np.allclose(fetches[0], rank_np)) + + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + x_np = np.random.rand(3, 4, 7, 7).astype(np.float64) + x_pd = paddle.fluid.data( + name="X", shape=[3, 4, 7, 7], dtype='float64') + rank_np = np.linalg.matrix_rank(x_np, hermitian=True) + rank_pd = paddle.linalg.matrix_rank(x_pd, hermitian=True) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"X": x_np}, + fetch_list=[rank_pd]) + self.assertTrue(np.allclose(fetches[0], rank_np)) + + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + x_np = np.random.rand(3, 4, 7, 7).astype(np.float64) + x_pd = paddle.fluid.data( + name="X", shape=[3, 4, 7, 7], dtype='float64') + rank_np = np.linalg.matrix_rank(x_np, 0.1, hermitian=False) + rank_pd = paddle.linalg.matrix_rank(x_pd, 0.1, hermitian=False) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"X": x_np}, + fetch_list=[rank_pd]) + self.assertTrue(np.allclose(fetches[0], rank_np)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 236150eef94795e73813fb251a7e676ac2f70e55..eabb017a0f62de99d2702547b5ee970881afb326 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -16,12 +16,14 @@ from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 +from .tensor.linalg import matrix_rank from .tensor.linalg import svd __all__ = [ 'cholesky', #noqa 'norm', 'inv', + 'matrix_rank', 'svd', 'matrix_power' ] diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 40dfd32b50a05f64c5f0aabe2eac9e91c2b91b5a..a565f7bfe2e8dca5e37bda23a877ae6c164534d9 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -15,9 +15,9 @@ import numpy as np from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type -from ..fluid.framework import in_dygraph_mode, _varbase_creator +from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable -from ..fluid.layers import transpose # noqa: F401 +from ..fluid.layers import transpose, cast # noqa: F401 from paddle.common_ops_import import core from paddle.common_ops_import import VarDesc from paddle import _C_ops @@ -785,6 +785,94 @@ def cholesky(x, upper=False, name=None): return out +def matrix_rank(x, tol=None, hermitian=False, name=None): + r""" + Computes the rank of a matrix. 
+ + The rank of a matrix is the number of singular values that are greater than the specified tol threshold when hermitian=False, + or the number of eigenvalues in absolute value that are greater than the specified tol threshold when hermitian=True. + + Args: + x (Tensor): The input tensor. + Its shape should be [..., m, n], where ... is zero or more batch dimensions. If x is a batch of matrices then the output + has the same batch dimensions. The data type of x should be float32 or float64. + tol (float,Tensor,optional): the tolerance value. Default: None. + If tol is not specified, and sigma is the largest singular value (or eigenvalue in absolute value), and eps is the + epsilon value for the dtype of x, then tol is computed with formula tol=sigma * max(m,n) * eps. Note that if x is + a batch of matrices, tol is computed this way for every batch. + hermitian (bool,optional): indicates whether x is Hermitian. Default: False. + When hermitian=True, x is assumed to be Hermitian, but x is not checked inside the function. Instead, We just use the + lower triangular of the matrix to compute. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Rank of tensor x. + + Examples: + .. code-block:: python + + import paddle + + a = paddle.eye(10) + b = paddle.linalg.matrix_rank(a) + print(b) + # b = [10] + + c = paddle.ones(shape=[3, 4, 5, 5]) + d = paddle.linalg.matrix_rank(c, tol=0.01, hermitian=True) + print(d) + # d = [[1, 1, 1, 1], + # [1, 1, 1, 1], + # [1, 1, 1, 1]] + + """ + + if in_dygraph_mode(): + if tol is None: + tol_tensor = None + tol_attr = 0.0 + use_default_tol = True + elif isinstance(tol, Variable): + if tol.dtype != x.dtype: + tol_tensor = cast(tol, x.dtype) + else: + tol_tensor = tol + tol_attr = 0.0 + use_default_tol = False + else: + tol_tensor = None + tol_attr = float(tol) + use_default_tol = False + return _C_ops.matrix_rank(x, tol_tensor, "tol", tol_attr, 'hermitian', + hermitian, 'use_default_tol', use_default_tol) + + inputs = {} + attrs = {} + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'matrix_rank') + inputs['X'] = x + if tol is None: + attrs['use_default_tol'] = True + elif isinstance(tol, Variable): + check_variable_and_dtype(tol, 'tol', ['float32'], 'matrix_rank') + attrs['use_default_tol'] = False + if tol.dtype != x.dtype: + inputs['TolTensor'] = cast(tol, x.dtype) + else: + inputs['TolTensor'] = tol + else: + check_type(tol, 'tol', float, 'matrix_rank') + attrs['use_default_tol'] = False + attrs['tol'] = tol + check_type(hermitian, 'hermitian', bool, 'matrix_rank') + attrs['hermitian'] = hermitian + + helper = LayerHelper('matrix_rank', **locals()) + out = helper.create_variable_for_type_inference(dtype='int32') + helper.append_op( + type='matrix_rank', inputs=inputs, outputs={'Out': out}, attrs=attrs) + return out + + def bmm(x, y, name=None): """ Applies batched matrix multiplication to two tensors.
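
For reference, both the CPU and GPU kernels above reduce matrix_rank to the same three steps: take the singular values of each matrix in the batch (or the absolute eigenvalues when hermitian=True), form the tolerance as the larger of the user-supplied atol and rtol * largest value (rtol is eps * max(m, n) only when no tol is given), and count how many values exceed that tolerance. A minimal NumPy sketch of this logic, for illustration only; the helper name matrix_rank_reference is not part of the diff:

import numpy as np

def matrix_rank_reference(x, tol=None, hermitian=False):
    # Reference sketch (not part of the patch): mirrors the kernels' rank computation.
    if hermitian:
        # Absolute eigenvalues of the (assumed) Hermitian/symmetric matrices.
        sv = np.abs(np.linalg.eigvalsh(x))
    else:
        # Singular values of each matrix in the batch.
        sv = np.linalg.svd(x, compute_uv=False)
    if tol is None:
        # Default tolerance: eps * max(m, n) * largest singular value, per batch.
        rows, cols = x.shape[-2], x.shape[-1]
        eps = np.finfo(x.dtype).eps
        tol = eps * max(rows, cols) * sv.max(axis=-1, keepdims=True)
    else:
        # A scalar tol or a batch-shaped tol array broadcasts against sv;
        # it is cast to x's dtype as the operator does with TolTensor.
        tol = np.asarray(tol, dtype=x.dtype)[..., np.newaxis]
    return (sv > tol).sum(axis=-1)

# Example: a rank-deficient 4x4 matrix.
a = np.diag([3.0, 2.0, 0.0, 0.0])
print(matrix_rank_reference(a))            # 2
print(matrix_rank_reference(a, tol=2.5))   # 1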