From e9c04149aeab020784ca3fd72184dc397e79440a Mon Sep 17 00:00:00 2001
From: From00
Date: Fri, 24 Sep 2021 15:17:50 +0800
Subject: [PATCH] [cherry-pick] Replace Eigen with Lapack library for eigvals
 OP kernel (#35909) (#36038)

This PR implements the kernel of the "eigvals" OP with the Lapack library,
which has better performance than the previously used Eigen library.
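The kernel follows the standard two-phase LAPACK workspace convention: call
?geev once with lwork = -1 so it only reports the optimal workspace size, then
allocate a buffer of that size and call it again to do the actual computation.
A minimal standalone sketch of that convention (not Paddle code; assumes a
system LAPACK is available at link time, e.g. -llapack):

#include <cstdio>
#include <vector>

extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda,
                       double *wr, double *wi, double *vl, int *ldvl,
                       double *vr, int *ldvr, double *work, int *lwork,
                       int *info);

int main() {
  int n = 2, lda = 2, ldv = 1, info = 0;
  char jobvl = 'N', jobvr = 'N';         // eigenvalues only, no eigenvectors
  double a[4] = {0.0, -2.0, 1.0, -3.0};  // column-major [[0, 1], [-2, -3]]
  double wr[2], wi[2];                   // real/imag parts of the eigenvalues

  // Phase 1: workspace query. With lwork == -1, dgeev_ only writes the
  // optimal workspace size into work[0] and computes nothing else.
  double qwork;
  int lwork = -1;
  dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, nullptr, &ldv, nullptr, &ldv,
         &qwork, &lwork, &info);

  // Phase 2: allocate the reported size and compute for real.
  lwork = static_cast<int>(qwork);
  std::vector<double> work(lwork);
  dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, nullptr, &ldv, nullptr, &ldv,
         work.data(), &lwork, &info);

  if (info == 0)  // expected eigenvalues: -1 and -2
    std::printf("%g%+gi, %g%+gi\n", wr[0], wi[0], wr[1], wi[1]);
  return info;
}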
---
 paddle/fluid/operators/eigvals_op.h           | 212 +++++++++++++-----
 .../fluid/operators/math/complex_functors.h   |  23 ++
 .../fluid/operators/math/lapack_function.cc   |  52 +++++
 paddle/fluid/operators/math/lapack_function.h |   5 +
 paddle/fluid/platform/dynload/lapack.h        |  28 ++-
 5 files changed, 258 insertions(+), 62 deletions(-)

diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h
index 998dcd9f1e..6fdf849ac7 100644
--- a/paddle/fluid/operators/eigvals_op.h
+++ b/paddle/fluid/operators/eigvals_op.h
@@ -14,60 +14,44 @@
 #pragma once
 
-#include <complex>
+#include <string>
 #include <vector>
-#include "Eigen/Dense"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/allocation/allocator.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/operators/math/lapack_function.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 using DDim = framework::DDim;
 
-template <typename T>
-struct PaddleComplex {
-  using Type = paddle::platform::complex<T>;
-};
-template <>
-struct PaddleComplex<paddle::platform::complex<float>> {
-  using Type = paddle::platform::complex<float>;
-};
-template <>
-struct PaddleComplex<paddle::platform::complex<double>> {
-  using Type = paddle::platform::complex<double>;
-};
+template <typename T, typename Enable = void>
+struct PaddleComplex;
 
-template <typename T>
-struct StdComplex {
-  using Type = std::complex<T>;
-};
-template <>
-struct StdComplex<paddle::platform::complex<float>> {
-  using Type = std::complex<float>;
-};
-template <>
-struct StdComplex<paddle::platform::complex<double>> {
-  using Type = std::complex<double>;
+template <typename T>
+struct PaddleComplex<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using type = paddle::platform::complex<T>;
+};
+template <typename T>
+struct PaddleComplex<
+    T, typename std::enable_if<
+           std::is_same<T, platform::complex<float>>::value ||
+           std::is_same<T, platform::complex<double>>::value>::type> {
+  using type = T;
 };
 
 template <typename T>
-using PaddleCType = typename PaddleComplex<T>::Type;
-template <typename T>
-using StdCType = typename StdComplex<T>::Type;
+using PaddleCType = typename PaddleComplex<T>::type;
 template <typename T>
-using EigenMatrixPaddle = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
-template <typename T>
-using EigenVectorPaddle = Eigen::Matrix<PaddleCType<T>, Eigen::Dynamic, 1>;
-template <typename T>
-using EigenMatrixStd =
-    Eigen::Matrix<StdCType<T>, Eigen::Dynamic, Eigen::Dynamic>;
-template <typename T>
-using EigenVectorStd = Eigen::Matrix<StdCType<T>, Eigen::Dynamic, 1>;
+using Real = typename math::Real<T>;
 
-static void SpiltBatchSquareMatrix(const Tensor &input,
-                                   std::vector<Tensor> *output) {
+static void SpiltBatchSquareMatrix(const Tensor& input,
+                                   std::vector<Tensor>* output) {
   DDim input_dims = input.dims();
   int last_dim = input_dims.size() - 1;
   int n_dim = input_dims[last_dim];
@@ -85,42 +69,148 @@ static void SpiltBatchSquareMatrix(const Tensor &input,
   (*output) = flattened_input.Split(1, 0);
 }
 
+static void CheckLapackEigResult(const int info, const std::string& name) {
+  PADDLE_ENFORCE_LE(info, 0, platform::errors::PreconditionNotMet(
+                                 "The QR algorithm failed to compute all the "
+                                 "eigenvalues in function %s.",
+                                 name.c_str()));
+  PADDLE_ENFORCE_GE(
+      info, 0, platform::errors::InvalidArgument(
+                   "The %d-th argument has an illegal value in function %s.",
+                   -info, name.c_str()));
+}
+
+template <typename DeviceContext, typename T>
+static typename std::enable_if<std::is_floating_point<T>::value>::type
+LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
+              Tensor* output, Tensor* work, Tensor* rwork /*unused*/) {
+  Tensor a;  // will be overwritten when lapackEig exits
+  framework::TensorCopy(input, input.place(), &a);
+
+  Tensor w;
+  int64_t n_dim = input.dims()[1];
+  auto* w_data =
+      w.mutable_data<T>(framework::make_ddim({n_dim << 1}), ctx.GetPlace());
+
+  int64_t work_mem = work->memory_size();
+  int64_t required_work_mem = 3 * n_dim * sizeof(T);
+  PADDLE_ENFORCE_GE(
+      work_mem, required_work_mem,
+      platform::errors::InvalidArgument(
+          "The memory size of the work tensor in LapackEigvals function "
+          "should be at least %" PRId64 " bytes, "
+          "but received work\'s memory size = %" PRId64 " bytes.",
+          required_work_mem, work_mem));
+
+  int info = 0;
+  math::lapackEig<T>('N', 'N', static_cast<int>(n_dim), a.template data<T>(),
+                     static_cast<int>(n_dim), w_data, NULL, 1, NULL, 1,
+                     work->template data<T>(),
+                     static_cast<int>(work_mem / sizeof(T)),
+                     static_cast<T*>(NULL), &info);
+
+  std::string name = "framework::platform::dynload::sgeev_";
+  if (input.type() == framework::proto::VarType::FP64) {
+    name = "framework::platform::dynload::dgeev_";
+  }
+  CheckLapackEigResult(info, name);
+
+  platform::ForRange<DeviceContext> for_range(
+      ctx.template device_context<DeviceContext>(), n_dim);
+  math::RealImagToComplexFunctor<PaddleCType<T>> functor(
+      w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
+  for_range(functor);
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<T, platform::complex<float>>::value ||
+    std::is_same<T, platform::complex<double>>::value>::type
+LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
+              Tensor* output, Tensor* work, Tensor* rwork) {
+  Tensor a;  // will be overwritten when lapackEig exits
+  framework::TensorCopy(input, input.place(), &a);
+
+  int64_t work_mem = work->memory_size();
+  int64_t n_dim = input.dims()[1];
+  int64_t required_work_mem = 3 * n_dim * sizeof(T);
+  PADDLE_ENFORCE_GE(
+      work_mem, required_work_mem,
+      platform::errors::InvalidArgument(
+          "The memory size of the work tensor in LapackEigvals function "
+          "should be at least %" PRId64 " bytes, "
+          "but received work\'s memory size = %" PRId64 " bytes.",
+          required_work_mem, work_mem));
+
+  int64_t rwork_mem = rwork->memory_size();
+  int64_t required_rwork_mem = (n_dim << 1) * sizeof(Real<T>);
+  PADDLE_ENFORCE_GE(
+      rwork_mem, required_rwork_mem,
+      platform::errors::InvalidArgument(
+          "The memory size of the rwork tensor in LapackEigvals function "
+          "should be at least %" PRId64 " bytes, "
+          "but received rwork\'s memory size = %" PRId64 " bytes.",
+          required_rwork_mem, rwork_mem));
+
+  int info = 0;
+  math::lapackEig<T, Real<T>>(
+      'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
+      static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
+      work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
+      rwork->template data<Real<T>>(), &info);
+
+  std::string name = "framework::platform::dynload::cgeev_";
+  if (input.type() == framework::proto::VarType::COMPLEX128) {
+    name = "framework::platform::dynload::zgeev_";
+  }
+  CheckLapackEigResult(info, name);
+}
+
 template <typename DeviceContext, typename T>
 class EigvalsKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const Tensor *input = ctx.Input<Tensor>("X");
-    Tensor *output = ctx.Output<Tensor>("Out");
-
-    auto input_type = input->type();
-    auto output_type = framework::IsComplexType(input_type)
-                           ? input_type
-                           : framework::ToComplexType(input_type);
-    output->mutable_data(ctx.GetPlace(), output_type);
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("X");
+    Tensor* output = ctx.Output<Tensor>("Out");
+    output->mutable_data<PaddleCType<T>>(ctx.GetPlace());
 
     std::vector<Tensor> input_matrices;
     SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);
 
-    int n_dim = input_matrices[0].dims()[1];
-    int n_batch = input_matrices.size();
-
+    int64_t n_dim = input_matrices[0].dims()[1];
+    int64_t n_batch = input_matrices.size();
     DDim output_dims = output->dims();
     output->Resize(framework::make_ddim({n_batch, n_dim}));
     std::vector<Tensor> output_vectors = output->Split(1, 0);
 
-    Eigen::Map<EigenMatrixPaddle<T>> input_emp(NULL, n_dim, n_dim);
-    Eigen::Map<EigenVectorPaddle<T>> output_evp(NULL, n_dim);
-    EigenMatrixStd<T> input_ems;
-    EigenVectorStd<T> output_evs;
-
-    for (int i = 0; i < n_batch; ++i) {
-      new (&input_emp) Eigen::Map<EigenMatrixPaddle<T>>(
-          input_matrices[i].data<T>(), n_dim, n_dim);
-      new (&output_evp) Eigen::Map<EigenVectorPaddle<T>>(
-          output_vectors[i].data<PaddleCType<T>>(), n_dim);
-      input_ems = input_emp.template cast<StdCType<T>>();
-      output_evs = input_ems.eigenvalues();
-      output_evp = output_evs.template cast<PaddleCType<T>>();
+    // query the optimal workspace size with lwork = -1
+    T qwork;
+    int info;
+    math::lapackEig<T, Real<T>>('N', 'N', static_cast<int>(n_dim),
+                                input_matrices[0].template data<T>(),
+                                static_cast<int>(n_dim), NULL, NULL, 1, NULL,
+                                1, &qwork, -1, static_cast<Real<T>*>(NULL),
+                                &info);
+    int64_t lwork = static_cast<int64_t>(qwork);
+
+    Tensor work, rwork;
+    try {
+      work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
+    } catch (memory::allocation::BadAlloc&) {
+      LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
+                   << "memory size = " << lwork * sizeof(T) << " bytes, "
+                   << "try reallocating a smaller workspace with the minimum "
+                   << "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
+                   << "this may lead to bad performance.";
+      lwork = 3 * n_dim;
+      work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
+    }
+    if (framework::IsComplexType(input->type())) {
+      rwork.mutable_data<Real<T>>(framework::make_ddim({n_dim << 1}),
+                                  ctx.GetPlace());
+    }
+
+    for (int64_t i = 0; i < n_batch; ++i) {
+      LapackEigvals<DeviceContext, T>(ctx, input_matrices[i],
+                                      &output_vectors[i], &work, &rwork);
     }
     output->Resize(output_dims);
   }
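The two LapackEigvals overloads above are selected at compile time with
enable_if: floating-point element types take the real path (?geev returns
wr/wi separately, packed into complex output afterwards), while complex
element types take the path that needs the extra rwork buffer. A
stripped-down sketch of that dispatch pattern, with illustrative names
rather than Paddle's:

#include <complex>
#include <cstdio>
#include <type_traits>

template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type Eig(const T*) {
  std::puts("real path: ?geev fills wr/wi, pack into complex afterwards");
}

template <typename T>
typename std::enable_if<std::is_same<T, std::complex<float>>::value ||
                        std::is_same<T, std::complex<double>>::value>::type
Eig(const T*) {
  std::puts("complex path: ?geev fills w directly, needs a real rwork array");
}

int main() {
  double d = 0.0;
  std::complex<float> c;
  Eig(&d);  // resolves to the floating-point overload
  Eig(&c);  // resolves to the complex overload
}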
diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h
index c4bd6ec4f1..3214adb095 100644
--- a/paddle/fluid/operators/math/complex_functors.h
+++ b/paddle/fluid/operators/math/complex_functors.h
@@ -313,6 +313,29 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
   int64_t numel_;
 };
 
+template <typename T, typename Enable = void>
+struct RealImagToComplexFunctor;
+
+template <typename T>
+struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
+  RealImagToComplexFunctor(const Real<T>* input_real,
+                           const Real<T>* input_imag, T* output, int64_t numel)
+      : input_real_(input_real),
+        input_imag_(input_imag),
+        output_(output),
+        numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx].real = input_real_[idx];
+    output_[idx].imag = input_imag_[idx];
+  }
+
+  const Real<T>* input_real_;
+  const Real<T>* input_imag_;
+  T* output_;
+  int64_t numel_;
+};
+
 template <typename T>
 struct ConjFunctor;
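RealImagToComplexFunctor zips two separate real-valued arrays (the wr/wi
output of dgeev_/sgeev_) into one complex array; platform::ForRange simply
invokes operator() for every index in [0, numel). The same logic without the
Paddle plumbing, as a sketch (ZipToComplex is a hypothetical stand-in):

#include <complex>
#include <cstdint>
#include <cstdio>

struct ZipToComplex {
  const double* re;
  const double* im;
  std::complex<double>* out;
  // One invocation per element, exactly how ForRange drives the functor.
  void operator()(int64_t idx) const { out[idx] = {re[idx], im[idx]}; }
};

int main() {
  double re[3] = {1, 2, 3}, im[3] = {0, -1, 4};
  std::complex<double> out[3];
  ZipToComplex zip{re, im, out};
  for (int64_t i = 0; i < 3; ++i) zip(i);  // stand-in for ForRange
  for (const auto& z : out) std::printf("%g%+gi\n", z.real(), z.imag());
}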
#include "paddle/fluid/operators/math/lapack_function.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/dynload/lapack.h" namespace paddle { @@ -30,6 +31,57 @@ void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); } +// Eig +template <> +void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, + double *w, double *vl, int ldvl, double *vr, int ldvr, + double *work, int lwork, double *rwork, int *info) { + double *wr = w; + double *wi = w + n; + (void)rwork; // unused + platform::dynload::dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, + &ldvr, work, &lwork, info); +} + +template <> +void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, + float *w, float *vl, int ldvl, float *vr, int ldvr, + float *work, int lwork, float *rwork, int *info) { + float *wr = w; + float *wi = w + n; + (void)rwork; // unused + platform::dynload::sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, + &ldvr, work, &lwork, info); +} + +template <> +void lapackEig, double>( + char jobvl, char jobvr, int n, platform::complex *a, int lda, + platform::complex *w, platform::complex *vl, int ldvl, + platform::complex *vr, int ldvr, platform::complex *work, + int lwork, double *rwork, int *info) { + platform::dynload::zgeev_( + &jobvl, &jobvr, &n, reinterpret_cast *>(a), &lda, + reinterpret_cast *>(w), + reinterpret_cast *>(vl), &ldvl, + reinterpret_cast *>(vr), &ldvr, + reinterpret_cast *>(work), &lwork, rwork, info); +} + +template <> +void lapackEig, float>( + char jobvl, char jobvr, int n, platform::complex *a, int lda, + platform::complex *w, platform::complex *vl, int ldvl, + platform::complex *vr, int ldvr, platform::complex *work, + int lwork, float *rwork, int *info) { + platform::dynload::cgeev_( + &jobvl, &jobvr, &n, reinterpret_cast *>(a), &lda, + reinterpret_cast *>(w), + reinterpret_cast *>(vl), &ldvl, + reinterpret_cast *>(vr), &ldvr, + reinterpret_cast *>(work), &lwork, rwork, info); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/lapack_function.h b/paddle/fluid/operators/math/lapack_function.h index 694da4603b..a9cc2d2c00 100644 --- a/paddle/fluid/operators/math/lapack_function.h +++ b/paddle/fluid/operators/math/lapack_function.h @@ -22,6 +22,11 @@ namespace math { template void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info); +template +void lapackEig(char jobvl, char jobvr, int n, T1 *a, int lda, T1 *w, T1 *vl, + int ldvl, T1 *vr, int ldvr, T1 *work, int lwork, T2 *rwork, + int *info); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h index ffb3d3e0f6..db95e557eb 100644 --- a/paddle/fluid/platform/dynload/lapack.h +++ b/paddle/fluid/platform/dynload/lapack.h @@ -14,6 +14,7 @@ limitations under the License. 
diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h
index ffb3d3e0f6..db95e557eb 100644
--- a/paddle/fluid/platform/dynload/lapack.h
+++ b/paddle/fluid/platform/dynload/lapack.h
@@ -14,6 +14,7 @@ limitations under the License.
 */
 
 #pragma once
 
+#include <complex>
 #include <mutex>
 #include "paddle/fluid/platform/dynload/dynamic_loader.h"
 #include "paddle/fluid/platform/port.h"
 
@@ -27,6 +28,27 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv,
 extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv,
                         int *info);
 
+// geev
+extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda,
+                       double *wr, double *wi, double *vl, int *ldvl,
+                       double *vr, int *ldvr, double *work, int *lwork,
+                       int *info);
+extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda,
+                       float *wr, float *wi, float *vl, int *ldvl, float *vr,
+                       int *ldvr, float *work, int *lwork, int *info);
+extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
+                       std::complex<double> *a, int *lda,
+                       std::complex<double> *w, std::complex<double> *vl,
+                       int *ldvl, std::complex<double> *vr, int *ldvr,
+                       std::complex<double> *work, int *lwork, double *rwork,
+                       int *info);
+extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
+                       std::complex<float> *a, int *lda,
+                       std::complex<float> *w, std::complex<float> *vl,
+                       int *ldvl, std::complex<float> *vr, int *ldvr,
+                       std::complex<float> *work, int *lwork, float *rwork,
+                       int *info);
+
 namespace paddle {
 namespace platform {
 namespace dynload {
@@ -58,7 +80,11 @@ extern void *lapack_dso_handle;
 
 #define LAPACK_ROUTINE_EACH(__macro) \
   __macro(dgetrf_);                  \
-  __macro(sgetrf_);
+  __macro(sgetrf_);                  \
+  __macro(dgeev_);                   \
+  __macro(sgeev_);                   \
+  __macro(zgeev_);                   \
+  __macro(cgeev_);
 
 LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP);
-- 
GitLab
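The DECLARE_DYNAMIC_LOAD_LAPACK_WRAP entries added above bind each geev symbol
lazily out of the LAPACK shared library on first use. The underlying mechanism
boils down to dlopen/dlsym, roughly as in this POSIX-only sketch (a
simplification, not Paddle's actual macro body; link with -ldl):

#include <dlfcn.h>
#include <cstdio>

// Function-pointer type matching the dgeev_ declaration above.
using dgeev_t = void (*)(char *, char *, int *, double *, int *, double *,
                         double *, double *, int *, double *, int *, double *,
                         int *, int *);

int main() {
  // Open the LAPACK shared object and resolve the Fortran symbol once;
  // a real wrapper would cache the handle and pointer behind a call_once.
  void *handle = dlopen("liblapack.so.3", RTLD_LAZY);
  if (!handle) { std::puts("liblapack not found"); return 1; }
  auto dgeev = reinterpret_cast<dgeev_t>(dlsym(handle, "dgeev_"));
  std::printf("dgeev_ resolved: %s\n", dgeev ? "yes" : "no");
  return dgeev ? 0 : 1;
}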