diff --git a/cmake/operators.cmake b/cmake/operators.cmake index cdc39161bde2544609c9395ba04cb6ec4c63567f..2c010a1e6297f0744eb8c579aa048b63e98a6211 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -185,6 +185,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "cholesky_op.cu") list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") + list(REMOVE_ITEM hip_srcs "eigh_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS} diff --git a/paddle/fluid/operators/eigh_op.cc b/paddle/fluid/operators/eigh_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3056bd43ba53dfe1c34bf1fb74ebb9bc71b43da --- /dev/null +++ b/paddle/fluid/operators/eigh_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigh_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class EighOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", + "Eigh"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors", + "Eigh"); + + auto input_dim = ctx->GetInputDim("X"); + auto rank = input_dim.size(); + + PADDLE_ENFORCE_GE(rank, 2, + platform::errors::InvalidArgument( + "The Input(X) should have at least 2 dimensions." + "But received a %d dimension tensor.", + rank)); + PADDLE_ENFORCE_EQ( + input_dim[rank - 2], input_dim[rank - 1], + platform::errors::InvalidArgument( + "Eigh op is designed for square matrix, consequently" + "inner-most 2 dimensions of Input(X) should be symmetric." + "But received X's shape[-2] = %d and shape[-1] = %d.", + input_dim[rank - 2], input_dim[rank - 1])); + + std::vector values_dim; + if (rank > 2) { + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); + } + } else { + values_dim = {input_dim[1]}; + } + + ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim)); + ctx->SetOutputDim("Eigenvectors", input_dim); + } +}; + +class EignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), Hermitian or real symmetric matrices." + "Its shape should be [*, N, N] where * is zero or" + "more batch dimensions. The data type is float32 ," + "float64, complex64, complex128."); + AddOutput("Eigenvalues", + "(Tensor), The eigenvalues in ascending order." + "The data type is float32 or float64."); + AddOutput( + "Eigenvectors", + "(Tensor), The column is the normalized eigenvector " + "corresponding to the eigenvalue. The data type is the same as ``X``."); + AddAttr( + "UPLO", + "(string, default 'L'), 'L' represents the lower triangular matrix," + "'U' represents the upper triangular matrix.") + .SetDefault("L"); + AddComment(R"DOC( +Eigh Operator. + +Computes the eigenvalues and eigenvectors of a complex Hermitian + (conjugate symmetric) or a real symmetric matrix. + +)DOC"); + } +}; + +class EighGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", + "EighGrad"); + OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", + "EighGrad"); + OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")), + "Input", "Eigenvalues@GRAD", "EighGrad"); + OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")), + "Input", "Eigenvectors@GRAD", "EighGrad"); + auto dims = ctx->GetInputDim("Eigenvectors"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Eigenvectors")), + ctx.device_context()); + } +}; + +template +class EighGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("Eigenvalues", this->Output("Eigenvalues")); + op->SetInput("Eigenvectors", this->Output("Eigenvectors")); + op->SetInput(framework::GradVarName("Eigenvalues"), + this->OutputGrad("Eigenvalues")); + op->SetInput(framework::GradVarName("Eigenvectors"), + this->OutputGrad("Eigenvectors")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker, + ops::EighGradOpMaker, + ops::EighGradOpMaker); +REGISTER_OPERATOR(eigh_grad, ops::EighGradOp); + +REGISTER_OP_CPU_KERNEL( + eigh, ops::EighKernel, + ops::EighKernel, + ops::EighKernel>, + ops::EighKernel>); + +REGISTER_OP_CPU_KERNEL( + eigh_grad, + ops::EighGradKernel, + ops::EighGradKernel, + ops::EighGradKernel>, + ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.cu b/paddle/fluid/operators/eigh_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..cfc9eba450959615ea3c9ce20cdeadbaa014bb46 --- /dev/null +++ b/paddle/fluid/operators/eigh_op.cu @@ -0,0 +1,53 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigh_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class EighGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto input_var = ctx.Input("X"); + auto output_w_var = ctx.Output("Eigenvalues"); + auto output_v_var = ctx.Output("Eigenvectors"); + std::string lower = ctx.Attr("UPLO"); + bool is_lower = (lower == "L"); + math::MatrixEighFunctor functor; + functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + eigh, ops::EighGPUKernel, ops::EighGPUKernel, + ops::EighGPUKernel>, + ops::EighGPUKernel>); + +REGISTER_OP_CUDA_KERNEL( + eigh_grad, + ops::EighGradKernel, + ops::EighGradKernel, + ops::EighGradKernel>, + ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0af38d44e54570b8907b8507c37a71e0ffc83a31 --- /dev/null +++ b/paddle/fluid/operators/eigh_op.h @@ -0,0 +1,80 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/eigen_values_vectors.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenTensor = framework::EigenTensor; +template +using EigenVector = framework::EigenVector; + +template +class EighKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto input_var = ctx.Input("X"); + auto output_w_var = ctx.Output("Eigenvalues"); + auto output_v_var = ctx.Output("Eigenvectors"); + std::string lower = ctx.Attr("UPLO"); + bool is_lower = (lower == "L"); + math::MatrixEighFunctorCPU functor; + functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); + } +}; + +template +class EighGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& x_grad = *ctx.Output(framework::GradVarName("X")); + x_grad.mutable_data(ctx.GetPlace()); + auto& output_w_var = *ctx.Input("Eigenvalues"); + auto& output_v_var = *ctx.Input("Eigenvectors"); + auto& output_w_grad = + *ctx.Input(framework::GradVarName("Eigenvalues")); + auto& output_v_grad = + *ctx.Input(framework::GradVarName("Eigenvectors")); + + auto& dims = output_v_var.dims(); + const int m = dims[dims.size() - 1]; + auto dito = + math::DeviceIndependenceTensorOperations( + ctx); + auto tV = dito.Transpose(dito.Conj(output_v_var)); + auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2), + dito.Unsqueeze(output_w_var, -1)); + Tensor result = dito.Matmul(tV, output_v_grad); + result.mutable_data(dims, ctx.GetPlace()); + std::vector out_shape = framework::vectorize(dims); + auto constant = dito.Fill(out_shape, 0.5); + result = dito.Sub(result, dito.Conj(dito.Transpose(result))); + result = dito.Mul(result, constant); + result = dito.Div_(result, W); + result = dito.DiagFill(m, m, m, 0, output_w_grad, result); + x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h new file mode 100644 index 0000000000000000000000000000000000000000..4e2d180e336281ba8708cb782cccfdd34b1a6876 --- /dev/null +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -0,0 +1,314 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "Eigen/Core" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/svd_helper.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/dynload/cusolver.h" +#endif // PADDLE_WITH_CUDA + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenTensor = framework::EigenTensor; + +template +using InputMatrixMap = Eigen::Map< + const Eigen::Matrix>; + +template +using OutputMatrixMap = Eigen::Map< + Eigen::Matrix>; + +template +inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data, + ValueType *eigenvalues_data, + ValueType *eigenvectors_data, + int batches, int rows, int cols, + bool has_vectors) { + int stride = rows * cols; + for (int i = 0; i < batches; i++) { + auto m = InputMatrixMap(x_data + i * stride, rows, cols); + auto eigenvalues = + OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); + auto eigenvectors = + OutputMatrixMap(eigenvectors_data + i * stride, rows, cols); + + Eigen::SelfAdjointEigenSolver> + eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors + : Eigen::EigenvaluesOnly); + PADDLE_ENFORCE_EQ( + eigen_solver.info(), Eigen::Success, + platform::errors::InvalidArgument( + "Self Adjoint Eigen decomposition is not successful. " + "The %d-th input matrice might not be not be positive definite.", + i)); + + eigenvalues = eigen_solver.eigenvalues().transpose(); + if (has_vectors) { + eigenvectors = eigen_solver.eigenvectors().transpose(); + } + } +} + +template +inline void ComputeComplexEigenvaluesAndVectors(T *x_data, + ValueType *eigenvalues_data, + T *eigenvectors_data, + int batches, int rows, int cols, + bool has_vectors) { + using Complex = std::complex; + Complex *input = reinterpret_cast(x_data); + Complex *eigenvectors_data_ = reinterpret_cast(eigenvectors_data); + + int stride = rows * cols; + for (int i = 0; i < batches; i++) { + auto m = InputMatrixMap(input + i * stride, rows, cols); + auto eigenvalues = + OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); + auto eigenvectors = + OutputMatrixMap(eigenvectors_data_ + i * stride, rows, cols); + + Eigen::SelfAdjointEigenSolver< + Eigen::Matrix> + eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors + : Eigen::EigenvaluesOnly); + PADDLE_ENFORCE_EQ( + eigen_solver.info(), Eigen::Success, + platform::errors::InvalidArgument( + "Self Adjoint Eigen decomposition is not successful. " + "The %d-th input matrice might not be not be positive definite.", + i)); + + eigenvalues = eigen_solver.eigenvalues().transpose(); + if (has_vectors) { + eigenvectors = eigen_solver.eigenvectors().transpose(); + } + } +} + +inline int64_t GetBatchSize(framework::DDim dims) { + int64_t batch_size = 1; + auto dim_size = dims.size(); + for (int i = 0; i < dim_size - 2; i++) { + batch_size *= dims[i]; + } + return batch_size; +} + +// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real +// symmetric matrices, and uses the variable has_vectors to +// control whether to return the eigenvectors. +template +struct MatrixEighFunctorCPU { + public: + void operator()(const framework::ExecutionContext &ctx, const Tensor &input, + Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, + bool has_vectors) { + auto dims = input.dims(); + auto output_value_dim = eigen_values->dims(); + + int64_t batch_size = 1; + int dim_size = dims.size(); + for (int64_t i = 0; i < dim_size - 2; i++) { + batch_size *= dims[i]; + } + auto dito = DeviceIndependenceTensorOperations(ctx); + Tensor input_tensor; + TensorCopy(input, ctx.GetPlace(), &input_tensor); + if (!is_lower) { + input_tensor = dito.Transpose(input); + } + int rows = dims[dims.size() - 2]; + + auto *value_data = + eigen_values->mutable_data(output_value_dim, ctx.GetPlace()); + + if (framework::IsComplexType(input_tensor.type())) { + auto *x_data = input_tensor.data(); + auto *vector_data = eigen_vectors->mutable_data(dims, ctx.GetPlace()); + ComputeComplexEigenvaluesAndVectors( + x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); + } else { + auto *x_data = input_tensor.data(); + auto *vector_data = + eigen_vectors->mutable_data(dims, ctx.GetPlace()); + ComputeFloatEigenvaluesAndVectors( + x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); + } + if (has_vectors) { + *eigen_vectors = dito.Transpose(*eigen_vectors); + } + } +}; + +#ifdef PADDLE_WITH_CUDA + +// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real +// symmetric matrices on GPU, and uses the variable has_vectors +// to control whether to return the eigenvectors. +template +struct MatrixEighFunctor { + public: + void operator()(const framework::ExecutionContext &ctx, const Tensor &input, + Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, + bool has_vectors) { + auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); + auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); + + auto &dims = input.dims(); + int dim_size = dims.size(); + int64_t batch_size = GetBatchSize(dims); + + cublasFillMode_t uplo = + is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + cusolverEigMode_t jobz = + has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + + int n = dims[dim_size - 1]; + int lda = std::max(1, n); + auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; + auto values_stride = dims[dim_size - 1]; + + auto &dev_ctx = ctx.template device_context(); + auto dito = + math::DeviceIndependenceTensorOperations(ctx); + Tensor output_v_var_trans = dito.Transpose(input); + TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors); + + int lwork = 0; + auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size); + auto *info_ptr = reinterpret_cast(info->ptr()); + + // When the input type is float32, and the feature value input dimension is + // greater than or equal to [*,32,32] and less than or equal to + // [*,512,512], Syevj has better performance. + bool use_syevj = + (eigen_vectors->type() == framework::proto::VarType::FP32 && + values_stride >= 32 && values_stride <= 512); + + syevjInfo_t syevj_params; + if (use_syevj) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateSyevjInfo(&syevj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnSsyevj_bufferSize( + dev_ctx.cusolver_dn_handle(), jobz, uplo, n, + reinterpret_cast(out_vector), lda, + reinterpret_cast(out_value), &lwork, + syevj_params)); + } else { + EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, n, out_vector, lda, + out_value, &lwork); + } + + auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork); + auto *work_ptr = reinterpret_cast(work->ptr()); + + for (auto i = 0; i < batch_size; i++) { + auto vector_data = out_vector + i * vector_stride; + auto value_data = out_value + i * values_stride; + auto handle = dev_ctx.cusolver_dn_handle(); + if (use_syevj) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSsyevj( + handle, jobz, uplo, n, reinterpret_cast(vector_data), lda, + reinterpret_cast(value_data), + reinterpret_cast(work_ptr), lwork, info_ptr, + syevj_params)); + } else { + Evd(handle, jobz, uplo, n, vector_data, lda, value_data, work_ptr, + lwork, info_ptr); + } + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_ptr, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: the [%d] argument had an illegal value", i, + error_info)); + } + + if (use_syevj) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroySyevjInfo(syevj_params)); + } + + if (has_vectors) { + *eigen_vectors = dito.Transpose(*eigen_vectors); + } + } + + inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const T *A, int lda, + const ValueType *W, int *lwork) const; + + inline void Evd(cusolverDnHandle_t handle, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, + T *work, int lwork, int *devInfo) const; +}; + +#define FUNC_WITH_TYPES(m) \ + m(float, float, Ssy, float) m(double, double, Dsy, double) \ + m(float, paddle::platform::complex, Che, cuComplex) \ + m(double, paddle::platform::complex, Zhe, cuDoubleComplex) + +#define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \ + template <> \ + inline void MatrixEighFunctor::EvdBuffer( \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \ + int *lwork) const { \ + PADDLE_ENFORCE_CUDA_SUCCESS( \ + platform::dynload::cusolverDn##C##evd_bufferSize( \ + handle, jobz, uplo, n, reinterpret_cast(A), lda, \ + W, lwork)); \ + } + +FUNC_WITH_TYPES(EVDBUFFER_INSTANCE); + +#define EVD_INSTANCE(ValueType, T, C, CastType) \ + template <> \ + inline void MatrixEighFunctor::Evd( \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \ + int lwork, int *devInfo) const { \ + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##evd( \ + handle, jobz, uplo, n, reinterpret_cast(A), lda, W, \ + reinterpret_cast(work), lwork, devInfo)); \ + } + +FUNC_WITH_TYPES(EVD_INSTANCE); + +#undef FUNC_WITH_TYPES +#undef EVDBUFFER_INSTANCE +#undef EVD_INSTANCE + +#endif // PADDLE_WITH_CUDA + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index b0c361e86a531730a4d9999682a6ba784b6d415c..71d106c211f71a45c82b6bebb764ba70b9eed79e 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -25,6 +25,8 @@ #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" @@ -36,6 +38,9 @@ using Tensor = framework::Tensor; using InTensors = std::vector; using OutTensors = std::vector; using OpName = std::string; +template +using EigenVector = framework::EigenVector; template void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, @@ -140,7 +145,42 @@ static std::vector GetBroadcastShape(InTensors ins) { break; \ } -template +template +struct DiagAndFillFunctor { + DiagAndFillFunctor(const int m, const int n, const int num_lower_diags, + const int num_upper_diags, const ValueType* scale, + const T* input, T* output) + : m_(m), + n_(n), + num_lower_diags_(num_lower_diags), + num_upper_diags_(num_upper_diags), + scale_(scale), + input_(input), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int col = index % n_; + const int row = (index / n_) % m_; + const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_); + const int band_end = + (num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1); + if (col < band_start || col >= band_end) { + output_[index] = input_[index]; + } else if (col == band_end - 1) { + output_[index] = static_cast(scale_[index % m_]); + } else { + output_[index] = input_[index]; + } + } + + private: + const int m_, n_, num_lower_diags_, num_upper_diags_; + const ValueType* scale_; + const T* input_; + T* output_; +}; + +template struct DeviceIndependenceTensorOperations { // 1. Device indenpendence, for kernel reuse. // 2. Input and output is always tensor type. @@ -398,6 +438,60 @@ struct DeviceIndependenceTensorOperations { return ret; } + Tensor Conj(const Tensor& x) { + Tensor out; + auto* out_data = out.mutable_data(x.dims(), context.GetPlace()); + auto* x_data = x.data(); + auto for_range = GetForRange(x.numel()); + math::ConjFunctor functor(x_data, x.numel(), out_data); + for_range(functor); + return out; + } + + Tensor DiagFill(const int m, const int n, const int num_lower_diags, + const int num_upper_diags, const Tensor& scale, + const Tensor& input) { + Tensor out; + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, input.numel()); + DiagAndFillFunctor diag_and_copy_functor( + m, n, num_lower_diags, num_upper_diags, scale.data(), + input.data(), out.mutable_data(input.dims(), input.place())); + for_range(diag_and_copy_functor); + return out; + } + + // Support x and y are different data types + Tensor Div_(const Tensor& x, const Tensor& y) { + Tensor out; + out.mutable_data(x.dims(), context.GetPlace()); + auto x_vector = EigenVector::Flatten(x); + auto y_vector = EigenVector::Flatten(y); + auto out_vector = EigenVector::Flatten(out); + auto& place = + *context.template device_context().eigen_device(); + out_vector.device(place) = x_vector / y_vector; + return out; + } + + framework::Tensor Sub_(const framework::Tensor& x, + const framework::Tensor& y) { + framework::Tensor ret; + std::vector out_shape = GetBroadcastShape({&x, &y}); + ret.Resize(framework::make_ddim(out_shape)); + if (x.dims().size() >= y.dims().size()) { + ElementwiseComputeEx, DeviceContext, ValueType>( + context, &x, &y, -1, SubFunctor(), &ret); + } else { + ElementwiseComputeEx, DeviceContext, + ValueType>( + // This is copyed from elementwise_sub, which means we + // need reverse will xrank < yrank + context, &x, &y, -1, InverseSubFunctor(), &ret); + } + return ret; + } + private: const framework::ExecutionContext& context; BlasT GetBlas() { diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index 36ba5dd0948159b91fd04362aff52dad0a61416a..a8ce1cc9d3a354e8a1deb90b8905ebe41e86d4c0 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -48,7 +48,15 @@ extern void *cusolver_dso_handle; __macro(cusolverDnSpotrf_bufferSize); \ __macro(cusolverDnDpotrf_bufferSize); \ __macro(cusolverDnSpotrf); \ - __macro(cusolverDnDpotrf); + __macro(cusolverDnDpotrf); \ + __macro(cusolverDnSsyevd_bufferSize); \ + __macro(cusolverDnDsyevd_bufferSize); \ + __macro(cusolverDnCheevd_bufferSize); \ + __macro(cusolverDnZheevd_bufferSize); \ + __macro(cusolverDnSsyevd); \ + __macro(cusolverDnDsyevd); \ + __macro(cusolverDnCheevd); \ + __macro(cusolverDnZheevd); CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 0d7169251a7cda9639ec7a852931aefabed6358b..555f53b16f420bd74f31ea077f710abeeb7bfdee 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -101,6 +101,8 @@ from .tensor.linalg import histogram # noqa: F401 from .tensor.linalg import mv # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 +from .tensor.linalg import svd # noqa: F401 +from .tensor.linalg import eigh # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 from .tensor.logic import greater_than # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_eigh_op.py b/python/paddle/fluid/tests/unittests/test_eigh_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e4343647025255d7daca7862bc7ad55e2a033db5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eigh_op.py @@ -0,0 +1,199 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +from op_test import OpTest +from gradient_checker import grad_check + + +class TestEighOp(OpTest): + def setUp(self): + paddle.enable_static() + self.op_type = "eigh" + self.init_input() + self.init_config() + np.random.seed(123) + out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO) + self.inputs = {"X": self.x_np} + self.attrs = {"UPLO": self.UPLO} + self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v} + + def init_config(self): + self.UPLO = 'L' + + def init_input(self): + self.x_shape = (10, 10) + self.x_type = np.float64 + self.x_np = np.random.random(self.x_shape).astype(self.x_type) + + def test_check_output(self): + self.check_output(no_check_set=['Eigenvectors']) + + def test_grad(self): + self.check_grad(["X"], ["Eigenvalues"]) + + +class TestEighUPLOCase(TestEighOp): + def init_config(self): + self.UPLO = 'U' + + +class TestEighGPUCase(unittest.TestCase): + def setUp(self): + self.x_shape = [32, 32] + self.dtype = "float32" + np.random.seed(123) + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.rtol = 1e-5 + self.atol = 1e-5 + + def test_check_output_gpu(self): + if paddle.is_compiled_with_cuda(): + paddle.disable_static(place=paddle.CUDAPlace(0)) + input_real_data = paddle.to_tensor(self.x_np) + expected_w, expected_v = np.linalg.eigh(self.x_np) + actual_w, actual_v = paddle.linalg.eigh(input_real_data) + np.testing.assert_allclose( + actual_w, expected_w, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose( + abs(actual_v.numpy()), + abs(expected_v), + rtol=self.rtol, + atol=self.atol) + + +class TestEighAPI(unittest.TestCase): + def setUp(self): + self.init_input_shape() + self.dtype = "float32" + self.UPLO = 'L' + self.rtol = 1e-6 + self.atol = 1e-6 + self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ + else paddle.CPUPlace() + np.random.seed(123) + self.real_data = np.random.random(self.x_shape).astype(self.dtype) + self.complex_data = np.random.random(self.x_shape).astype( + self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype) + self.trans_dims = list(range(len(self.x_shape) - 2)) + [ + len(self.x_shape) - 1, len(self.x_shape) - 2 + ] + + def init_input_shape(self): + self.x_shape = [5, 5] + + def compare_result(self, actual_w, actual_v, expected_w, expected_v): + np.testing.assert_allclose( + actual_w, expected_w, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose( + abs(actual_v), abs(expected_v), rtol=self.rtol, atol=self.atol) + + def check_static_float_result(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + input_x = paddle.static.data( + 'input_x', shape=self.x_shape, dtype=self.dtype) + output_w, output_v = paddle.linalg.eigh(input_x) + exe = paddle.static.Executor(self.place) + expected_w, expected_v = exe.run(main_prog, + feed={"input_x": self.real_data}, + fetch_list=[output_w, output_v]) + + actual_w, actual_v = np.linalg.eigh(self.real_data) + self.compare_result(actual_w, actual_v, expected_w, expected_v) + + def check_static_complex_result(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + x_dtype = np.complex64 if self.dtype == "float32" else np.complex128 + input_x = paddle.static.data( + 'input_x', shape=self.x_shape, dtype=x_dtype) + output_w, output_v = paddle.linalg.eigh(input_x) + exe = paddle.static.Executor(self.place) + expected_w, expected_v = exe.run( + main_prog, + feed={"input_x": self.complex_data}, + fetch_list=[output_w, output_v]) + actual_w, actual_v = np.linalg.eigh(self.complex_data) + self.compare_result(actual_w, actual_v, expected_w, expected_v) + + def test_in_static_mode(self): + paddle.enable_static() + self.check_static_float_result() + self.check_static_complex_result() + + def test_in_dynamic_mode(self): + paddle.disable_static(self.place) + input_real_data = paddle.to_tensor(self.real_data) + expected_w, expected_v = np.linalg.eigh(self.real_data) + actual_w, actual_v = paddle.linalg.eigh(input_real_data) + self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) + + input_complex_data = paddle.to_tensor(self.complex_data) + expected_w, expected_v = np.linalg.eigh(self.complex_data) + actual_w, actual_v = paddle.linalg.eigh(input_complex_data) + self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) + + def test_eigh_grad(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.complex_data, stop_gradient=False) + w, v = paddle.linalg.eigh(x) + (w.sum() + paddle.abs(v).sum()).backward() + np.testing.assert_allclose( + abs(x.grad.numpy()), + abs(x.grad.numpy().conj().transpose(self.trans_dims)), + rtol=self.rtol, + atol=self.atol) + + +class TestEighBatchAPI(TestEighAPI): + def init_input_shape(self): + self.x_shape = [2, 5, 5] + + +class TestEighAPIError(unittest.TestCase): + def test_error(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + #input maxtrix must greater than 2 dimensions + input_x = paddle.static.data( + name='x_1', shape=[12], dtype='float32') + self.assertRaises(ValueError, paddle.linalg.eigh, input_x) + + #input matrix must be square matrix + input_x = paddle.static.data( + name='x_2', shape=[12, 32], dtype='float32') + self.assertRaises(ValueError, paddle.linalg.eigh, input_x) + + #uplo must be in 'L' or 'U' + input_x = paddle.static.data( + name='x_3', shape=[4, 4], dtype="float32") + uplo = 'R' + self.assertRaises(ValueError, paddle.linalg.eigh, input_x, uplo) + + #x_data cannot be integer + input_x = paddle.static.data( + name='x_4', shape=[4, 4], dtype="int32") + self.assertRaises(TypeError, paddle.linalg.eigh, input_x) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 584c418675726de90eaef1abd7f795509b468b6b..fd87e7584cea52f3f14918bebb07c93a61533ec8 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -32,5 +32,6 @@ no_check_set_white_list = [ 'fusion_lstm', 'softmax_with_cross_entropy', 'svd', + 'eigh', 'class_center_sample', ] diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 27dc2595bfb29aa517d2713318b34c82a25ce8a3..cbb46ed424e3e1479fb1e8ec30136c17f5e0d295 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -19,6 +19,7 @@ from .tensor import inverse as inv # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd +from .tensor.linalg import eigh # noqa: F401 __all__ = [ 'cholesky', #noqa @@ -27,5 +28,6 @@ __all__ = [ 'multi_dot', 'matrix_rank', 'svd', - 'matrix_power' + 'matrix_power', + 'eigh' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index fce4b764a8c9e3818165adefca94efdcb8a8d67a..0f6d09e27a80437653714bba7c8f1eb9740818ae 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -47,6 +47,7 @@ from .linalg import mv # noqa: F401 from .linalg import matrix_power # noqa: F401 from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 +from .linalg import eigh # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 96a3610b1894f3e291485e82a88d6e91373fb007..62062377fffb1225522faf71f67884268a1dbd27 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1106,7 +1106,7 @@ def svd(x, full_matrices=False, name=None): def matrix_power(x, n, name=None): r""" Computes the n-th power of a square matrix or a batch of square matrices. - + Let :math:`X` be a sqaure matrix or a batch of square matrices, :math:`n` be an exponent, the equation should be: @@ -1251,3 +1251,72 @@ def multi_dot(x, name=None): out = helper.create_variable_for_type_inference(dtype) helper.append_op(type='multi_dot', inputs={"X": x}, outputs={"Out": out}) return out + + +def eigh(x, UPLO='L', name=None): + """ + Compute the eigenvalues and eigenvectors of a + complex Hermitian (conjugate symmetric) or a real symmetric matrix. + + Args: + x (Tensor): A tensor with shape :math:`[*, N, N]` , The data type of the input Tensor x + should be one of float32, float64, complex64, complex128. + UPLO(str, optional): (string, default 'L'), 'L' represents the lower triangular matrix, + "'U' represents the upper triangular matrix.". + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + + out_value(Tensor): A Tensor with shape [*, N] and data type of float32 and float64. The eigenvalues of eigh op. + out_vector(Tensor): A Tensor with shape [*, N, N] and data type of float32,float64,complex64 and complex128. The eigenvectors of eigh op. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + x_data = np.array([[1, -2j], [2j, 5]]) + x = paddle.to_tensor(x_data) + out_value, out_vector = paddle.eigh(x, UPLO='L') + print(out_value) + #[0.17157288, 5.82842712] + print(out_vector) + #[(-0.9238795325112867+0j), (-0.3826834323650898+0j)], + #[ 0.3826834323650898j , -0.9238795325112867j ]] + + """ + if in_dygraph_mode(): + return _C_ops.eigh(x, 'UPLO', UPLO) + + def __check_input(x, UPLO): + x_shape = list(x.shape) + if len(x.shape) < 2: + raise ValueError( + "Input(input) only support >=2 tensor, but received " + "length of Input(input) is %s." % len(x.shape)) + if x_shape[-1] != x_shape[-2]: + raise ValueError( + "The input matrix must be batches of square matrices. But received x's dimention: {}". + format(x_shape)) + if UPLO is not 'L' and UPLO is not 'U': + raise ValueError( + "UPLO must be L or U. But received UPLO is: {}".format(UPLO)) + + __check_input(x, UPLO) + + helper = LayerHelper('eigh', **locals()) + check_variable_and_dtype( + x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'eigh') + + out_value = helper.create_variable_for_type_inference(dtype=x.dtype) + out_vector = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='eigh', + inputs={'X': x}, + outputs={'Eigenvalues': out_value, + 'Eigenvectors': out_vector}, + attrs={'UPLO': UPLO}) + return out_value, out_vector