未验证 提交 9b8aafe5 编写于 作者: F From00 提交者: GitHub

Replace Eigen with Lapack library for eigvals OP kernel (#35909)

上级 88ea8e6f
...@@ -14,60 +14,44 @@ ...@@ -14,60 +14,44 @@
#pragma once #pragma once
#include <complex> #include <string>
#include <vector> #include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using DDim = framework::DDim; using DDim = framework::DDim;
template <typename T> template <typename T, typename enable = void>
struct PaddleComplex { struct PaddleComplex;
using Type = paddle::platform::complex<T>;
};
template <>
struct PaddleComplex<paddle::platform::complex<float>> {
using Type = paddle::platform::complex<float>;
};
template <>
struct PaddleComplex<paddle::platform::complex<double>> {
using Type = paddle::platform::complex<double>;
};
template <typename T> template <typename T>
struct StdComplex { struct PaddleComplex<
using Type = std::complex<T>; T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using type = paddle::platform::complex<T>;
}; };
template <> template <typename T>
struct StdComplex<paddle::platform::complex<float>> { struct PaddleComplex<
using Type = std::complex<float>; T, typename std::enable_if<
}; std::is_same<T, platform::complex<float>>::value ||
template <> std::is_same<T, platform::complex<double>>::value>::type> {
struct StdComplex<paddle::platform::complex<double>> { using type = T;
using Type = std::complex<double>;
}; };
template <typename T> template <typename T>
using PaddleCType = typename PaddleComplex<T>::Type; using PaddleCType = typename PaddleComplex<T>::type;
template <typename T>
using StdCType = typename StdComplex<T>::Type;
template <typename T> template <typename T>
using EigenMatrixPaddle = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>; using Real = typename math::Real<T>;
template <typename T>
using EigenVectorPaddle = Eigen::Matrix<PaddleCType<T>, Eigen::Dynamic, 1>;
template <typename T>
using EigenMatrixStd =
Eigen::Matrix<StdCType<T>, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorStd = Eigen::Matrix<StdCType<T>, Eigen::Dynamic, 1>;
static void SpiltBatchSquareMatrix(const Tensor &input, static void SpiltBatchSquareMatrix(const Tensor& input,
std::vector<Tensor> *output) { std::vector<Tensor>* output) {
DDim input_dims = input.dims(); DDim input_dims = input.dims();
int last_dim = input_dims.size() - 1; int last_dim = input_dims.size() - 1;
int n_dim = input_dims[last_dim]; int n_dim = input_dims[last_dim];
...@@ -85,42 +69,148 @@ static void SpiltBatchSquareMatrix(const Tensor &input, ...@@ -85,42 +69,148 @@ static void SpiltBatchSquareMatrix(const Tensor &input,
(*output) = flattened_input.Split(1, 0); (*output) = flattened_input.Split(1, 0);
} }
static void CheckLapackEigResult(const int info, const std::string& name) {
PADDLE_ENFORCE_LE(info, 0, platform::errors::PreconditionNotMet(
"The QR algorithm failed to compute all the "
"eigenvalues in function %s.",
name.c_str()));
PADDLE_ENFORCE_GE(
info, 0, platform::errors::InvalidArgument(
"The %d-th argument has an illegal value in function %s.",
-info, name.c_str()));
}
template <typename DeviceContext, typename T>
static typename std::enable_if<std::is_floating_point<T>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* output, Tensor* work, Tensor* rwork /*unused*/) {
Tensor a; // will be overwritten when lapackEig exit
framework::TensorCopy(input, input.place(), &a);
Tensor w;
int64_t n_dim = input.dims()[1];
auto* w_data =
w.mutable_data<T>(framework::make_ddim({n_dim << 1}), ctx.GetPlace());
int64_t work_mem = work->memory_size();
int64_t required_work_mem = 3 * n_dim * sizeof(T);
PADDLE_ENFORCE_GE(
work_mem, 3 * n_dim * sizeof(T),
platform::errors::InvalidArgument(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received work\'s memory size = %" PRId64 " bytes.",
required_work_mem, work_mem));
int info = 0;
math::lapackEig<T>('N', 'N', static_cast<int>(n_dim), a.template data<T>(),
static_cast<int>(n_dim), w_data, NULL, 1, NULL, 1,
work->template data<T>(),
static_cast<int>(work_mem / sizeof(T)),
static_cast<T*>(NULL), &info);
std::string name = "framework::platform::dynload::dgeev_";
if (input.type() == framework::proto::VarType::FP64) {
name = "framework::platform::dynload::sgeev_";
}
CheckLapackEigResult(info, name);
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), n_dim);
math::RealImagToComplexFunctor<PaddleCType<T>> functor(
w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
for_range(functor);
}
template <typename DeviceContext, typename T>
typename std::enable_if<std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* output, Tensor* work, Tensor* rwork) {
Tensor a; // will be overwritten when lapackEig exit
framework::TensorCopy(input, input.place(), &a);
int64_t work_mem = work->memory_size();
int64_t n_dim = input.dims()[1];
int64_t required_work_mem = 3 * n_dim * sizeof(T);
PADDLE_ENFORCE_GE(
work_mem, 3 * n_dim * sizeof(T),
platform::errors::InvalidArgument(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received work\'s memory size = %" PRId64 " bytes.",
required_work_mem, work_mem));
int64_t rwork_mem = rwork->memory_size();
int64_t required_rwork_mem = (n_dim << 1) * sizeof(Real<T>);
PADDLE_ENFORCE_GE(
rwork_mem, required_rwork_mem,
platform::errors::InvalidArgument(
"The memory size of the rwork tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received rwork\'s memory size = %" PRId64 " bytes.",
required_rwork_mem, rwork_mem));
int info = 0;
math::lapackEig<T, Real<T>>(
'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
rwork->template data<Real<T>>(), &info);
std::string name = "framework::platform::dynload::cgeev_";
if (input.type() == framework::proto::VarType::COMPLEX64) {
name = "framework::platform::dynload::zgeev_";
}
CheckLapackEigResult(info, name);
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> { class EigvalsKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor *input = ctx.Input<Tensor>("X"); const Tensor* input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
output->mutable_data<PaddleCType<T>>(ctx.GetPlace());
auto input_type = input->type();
auto output_type = framework::IsComplexType(input_type)
? input_type
: framework::ToComplexType(input_type);
output->mutable_data(ctx.GetPlace(), output_type);
std::vector<Tensor> input_matrices; std::vector<Tensor> input_matrices;
SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices); SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);
int n_dim = input_matrices[0].dims()[1]; int64_t n_dim = input_matrices[0].dims()[1];
int n_batch = input_matrices.size(); int64_t n_batch = input_matrices.size();
DDim output_dims = output->dims(); DDim output_dims = output->dims();
output->Resize(framework::make_ddim({n_batch, n_dim})); output->Resize(framework::make_ddim({n_batch, n_dim}));
std::vector<Tensor> output_vectors = output->Split(1, 0); std::vector<Tensor> output_vectors = output->Split(1, 0);
Eigen::Map<EigenMatrixPaddle<T>> input_emp(NULL, n_dim, n_dim); // query workspace size
Eigen::Map<EigenVectorPaddle<T>> output_evp(NULL, n_dim); T qwork;
EigenMatrixStd<T> input_ems; int info;
EigenVectorStd<T> output_evs; math::lapackEig<T, Real<T>>('N', 'N', static_cast<int>(n_dim),
input_matrices[0].template data<T>(),
for (int i = 0; i < n_batch; ++i) { static_cast<int>(n_dim), NULL, NULL, 1, NULL, 1,
new (&input_emp) Eigen::Map<EigenMatrixPaddle<T>>( &qwork, -1, static_cast<Real<T>*>(NULL), &info);
input_matrices[i].data<T>(), n_dim, n_dim); int64_t lwork = static_cast<int64_t>(qwork);
new (&output_evp) Eigen::Map<EigenVectorPaddle<T>>(
output_vectors[i].data<PaddleCType<T>>(), n_dim); Tensor work, rwork;
input_ems = input_emp.template cast<StdCType<T>>(); try {
output_evs = input_ems.eigenvalues(); work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
output_evp = output_evs.template cast<PaddleCType<T>>(); } catch (memory::allocation::BadAlloc&) {
LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
<< "memory size = " << lwork * sizeof(T) << " bytes, "
<< "try reallocating a smaller workspace with the minimum "
<< "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
<< "this may lead to bad performance.";
lwork = 3 * n_dim;
work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
}
if (framework::IsComplexType(input->type())) {
rwork.mutable_data<Real<T>>(framework::make_ddim({n_dim << 1}),
ctx.GetPlace());
}
for (int64_t i = 0; i < n_batch; ++i) {
LapackEigvals<DeviceContext, T>(ctx, input_matrices[i],
&output_vectors[i], &work, &rwork);
} }
output->Resize(output_dims); output->Resize(output_dims);
} }
......
...@@ -313,6 +313,29 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> { ...@@ -313,6 +313,29 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_; int64_t numel_;
}; };
template <typename T, typename Enable = void>
struct RealImagToComplexFunctor;
template <typename T>
struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
RealImagToComplexFunctor(const Real<T>* input_real, const Real<T>* input_imag,
T* output, int64_t numel)
: input_real_(input_real),
input_imag_(input_imag),
output_(output),
numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx].real = input_real_[idx];
output_[idx].imag = input_imag_[idx];
}
const Real<T>* input_real_;
const Real<T>* input_imag_;
T* output_;
int64_t numel_;
};
template <typename T, typename Enable = void> template <typename T, typename Enable = void>
struct ConjFunctor; struct ConjFunctor;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/math/lapack_function.h" #include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/dynload/lapack.h" #include "paddle/fluid/platform/dynload/lapack.h"
namespace paddle { namespace paddle {
...@@ -30,6 +31,57 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) { ...@@ -30,6 +31,57 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
} }
// Eig
template <>
void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda,
double *w, double *vl, int ldvl, double *vr, int ldvr,
double *work, int lwork, double *rwork, int *info) {
double *wr = w;
double *wi = w + n;
(void)rwork; // unused
platform::dynload::dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr,
&ldvr, work, &lwork, info);
}
template <>
void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda,
float *w, float *vl, int ldvl, float *vr, int ldvr,
float *work, int lwork, float *rwork, int *info) {
float *wr = w;
float *wi = w + n;
(void)rwork; // unused
platform::dynload::sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr,
&ldvr, work, &lwork, info);
}
template <>
void lapackEig<platform::complex<double>, double>(
char jobvl, char jobvr, int n, platform::complex<double> *a, int lda,
platform::complex<double> *w, platform::complex<double> *vl, int ldvl,
platform::complex<double> *vr, int ldvr, platform::complex<double> *work,
int lwork, double *rwork, int *info) {
platform::dynload::zgeev_(
&jobvl, &jobvr, &n, reinterpret_cast<std::complex<double> *>(a), &lda,
reinterpret_cast<std::complex<double> *>(w),
reinterpret_cast<std::complex<double> *>(vl), &ldvl,
reinterpret_cast<std::complex<double> *>(vr), &ldvr,
reinterpret_cast<std::complex<double> *>(work), &lwork, rwork, info);
}
template <>
void lapackEig<platform::complex<float>, float>(
char jobvl, char jobvr, int n, platform::complex<float> *a, int lda,
platform::complex<float> *w, platform::complex<float> *vl, int ldvl,
platform::complex<float> *vr, int ldvr, platform::complex<float> *work,
int lwork, float *rwork, int *info) {
platform::dynload::cgeev_(
&jobvl, &jobvr, &n, reinterpret_cast<std::complex<float> *>(a), &lda,
reinterpret_cast<std::complex<float> *>(w),
reinterpret_cast<std::complex<float> *>(vl), &ldvl,
reinterpret_cast<std::complex<float> *>(vr), &ldvr,
reinterpret_cast<std::complex<float> *>(work), &lwork, rwork, info);
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,11 @@ namespace math { ...@@ -22,6 +22,11 @@ namespace math {
template <typename T> template <typename T>
void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info); void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);
template <typename T1, typename T2 = T1>
void lapackEig(char jobvl, char jobvr, int n, T1 *a, int lda, T1 *w, T1 *vl,
int ldvl, T1 *vr, int ldvr, T1 *work, int lwork, T2 *rwork,
int *info);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <complex>
#include <mutex> #include <mutex>
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
...@@ -27,6 +28,27 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, ...@@ -27,6 +28,27 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv,
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv,
int *info); int *info);
// geev
extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda,
double *wr, double *wi, double *vl, int *ldvl,
double *vr, int *ldvr, double *work, int *lwork,
int *info);
extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda,
float *wr, float *wi, float *vl, int *ldvl, float *vr,
int *ldvr, float *work, int *lwork, int *info);
extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
std::complex<double> *a, int *lda,
std::complex<double> *w, std::complex<double> *vl,
int *ldvl, std::complex<double> *vr, int *ldvr,
std::complex<double> *work, int *lwork, double *rwork,
int *info);
extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, std::complex<float> *a,
int *lda, std::complex<float> *w,
std::complex<float> *vl, int *ldvl,
std::complex<float> *vr, int *ldvr,
std::complex<float> *work, int *lwork, float *rwork,
int *info);
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
...@@ -58,7 +80,11 @@ extern void *lapack_dso_handle; ...@@ -58,7 +80,11 @@ extern void *lapack_dso_handle;
#define LAPACK_ROUTINE_EACH(__macro) \ #define LAPACK_ROUTINE_EACH(__macro) \
__macro(dgetrf_); \ __macro(dgetrf_); \
__macro(sgetrf_); __macro(sgetrf_); \
__macro(dgeev_); \
__macro(sgeev_); \
__macro(zgeev_); \
__macro(cgeev_);
LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP); LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册