未验证 提交 9b8aafe5 编写于 作者: F From00 提交者: GitHub

Replace Eigen with Lapack library for eigvals OP kernel (#35909)

上级 88ea8e6f
......@@ -14,60 +14,44 @@
#pragma once
#include <complex>
#include <string>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T>
struct PaddleComplex {
using Type = paddle::platform::complex<T>;
};
template <>
struct PaddleComplex<paddle::platform::complex<float>> {
using Type = paddle::platform::complex<float>;
};
template <>
struct PaddleComplex<paddle::platform::complex<double>> {
using Type = paddle::platform::complex<double>;
};
template <typename T, typename enable = void>
struct PaddleComplex;
template <typename T>
struct StdComplex {
using Type = std::complex<T>;
struct PaddleComplex<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using type = paddle::platform::complex<T>;
};
template <>
struct StdComplex<paddle::platform::complex<float>> {
using Type = std::complex<float>;
};
template <>
struct StdComplex<paddle::platform::complex<double>> {
using Type = std::complex<double>;
template <typename T>
struct PaddleComplex<
T, typename std::enable_if<
std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type> {
using type = T;
};
template <typename T>
using PaddleCType = typename PaddleComplex<T>::Type;
template <typename T>
using StdCType = typename StdComplex<T>::Type;
using PaddleCType = typename PaddleComplex<T>::type;
template <typename T>
using EigenMatrixPaddle = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorPaddle = Eigen::Matrix<PaddleCType<T>, Eigen::Dynamic, 1>;
template <typename T>
using EigenMatrixStd =
Eigen::Matrix<StdCType<T>, Eigen::Dynamic, Eigen::Dynamic>;
template <typename T>
using EigenVectorStd = Eigen::Matrix<StdCType<T>, Eigen::Dynamic, 1>;
using Real = typename math::Real<T>;
static void SpiltBatchSquareMatrix(const Tensor &input,
std::vector<Tensor> *output) {
static void SpiltBatchSquareMatrix(const Tensor& input,
std::vector<Tensor>* output) {
DDim input_dims = input.dims();
int last_dim = input_dims.size() - 1;
int n_dim = input_dims[last_dim];
......@@ -85,42 +69,148 @@ static void SpiltBatchSquareMatrix(const Tensor &input,
(*output) = flattened_input.Split(1, 0);
}
static void CheckLapackEigResult(const int info, const std::string& name) {
PADDLE_ENFORCE_LE(info, 0, platform::errors::PreconditionNotMet(
"The QR algorithm failed to compute all the "
"eigenvalues in function %s.",
name.c_str()));
PADDLE_ENFORCE_GE(
info, 0, platform::errors::InvalidArgument(
"The %d-th argument has an illegal value in function %s.",
-info, name.c_str()));
}
template <typename DeviceContext, typename T>
static typename std::enable_if<std::is_floating_point<T>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* output, Tensor* work, Tensor* rwork /*unused*/) {
Tensor a; // will be overwritten when lapackEig exit
framework::TensorCopy(input, input.place(), &a);
Tensor w;
int64_t n_dim = input.dims()[1];
auto* w_data =
w.mutable_data<T>(framework::make_ddim({n_dim << 1}), ctx.GetPlace());
int64_t work_mem = work->memory_size();
int64_t required_work_mem = 3 * n_dim * sizeof(T);
PADDLE_ENFORCE_GE(
work_mem, 3 * n_dim * sizeof(T),
platform::errors::InvalidArgument(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received work\'s memory size = %" PRId64 " bytes.",
required_work_mem, work_mem));
int info = 0;
math::lapackEig<T>('N', 'N', static_cast<int>(n_dim), a.template data<T>(),
static_cast<int>(n_dim), w_data, NULL, 1, NULL, 1,
work->template data<T>(),
static_cast<int>(work_mem / sizeof(T)),
static_cast<T*>(NULL), &info);
std::string name = "framework::platform::dynload::dgeev_";
if (input.type() == framework::proto::VarType::FP64) {
name = "framework::platform::dynload::sgeev_";
}
CheckLapackEigResult(info, name);
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), n_dim);
math::RealImagToComplexFunctor<PaddleCType<T>> functor(
w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
for_range(functor);
}
template <typename DeviceContext, typename T>
typename std::enable_if<std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
Tensor* output, Tensor* work, Tensor* rwork) {
Tensor a; // will be overwritten when lapackEig exit
framework::TensorCopy(input, input.place(), &a);
int64_t work_mem = work->memory_size();
int64_t n_dim = input.dims()[1];
int64_t required_work_mem = 3 * n_dim * sizeof(T);
PADDLE_ENFORCE_GE(
work_mem, 3 * n_dim * sizeof(T),
platform::errors::InvalidArgument(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received work\'s memory size = %" PRId64 " bytes.",
required_work_mem, work_mem));
int64_t rwork_mem = rwork->memory_size();
int64_t required_rwork_mem = (n_dim << 1) * sizeof(Real<T>);
PADDLE_ENFORCE_GE(
rwork_mem, required_rwork_mem,
platform::errors::InvalidArgument(
"The memory size of the rwork tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received rwork\'s memory size = %" PRId64 " bytes.",
required_rwork_mem, rwork_mem));
int info = 0;
math::lapackEig<T, Real<T>>(
'N', 'N', static_cast<int>(n_dim), a.template data<T>(),
static_cast<int>(n_dim), output->template data<T>(), NULL, 1, NULL, 1,
work->template data<T>(), static_cast<int>(work_mem / sizeof(T)),
rwork->template data<Real<T>>(), &info);
std::string name = "framework::platform::dynload::cgeev_";
if (input.type() == framework::proto::VarType::COMPLEX64) {
name = "framework::platform::dynload::zgeev_";
}
CheckLapackEigResult(info, name);
}
template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");
auto input_type = input->type();
auto output_type = framework::IsComplexType(input_type)
? input_type
: framework::ToComplexType(input_type);
output->mutable_data(ctx.GetPlace(), output_type);
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
output->mutable_data<PaddleCType<T>>(ctx.GetPlace());
std::vector<Tensor> input_matrices;
SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);
int n_dim = input_matrices[0].dims()[1];
int n_batch = input_matrices.size();
int64_t n_dim = input_matrices[0].dims()[1];
int64_t n_batch = input_matrices.size();
DDim output_dims = output->dims();
output->Resize(framework::make_ddim({n_batch, n_dim}));
std::vector<Tensor> output_vectors = output->Split(1, 0);
Eigen::Map<EigenMatrixPaddle<T>> input_emp(NULL, n_dim, n_dim);
Eigen::Map<EigenVectorPaddle<T>> output_evp(NULL, n_dim);
EigenMatrixStd<T> input_ems;
EigenVectorStd<T> output_evs;
for (int i = 0; i < n_batch; ++i) {
new (&input_emp) Eigen::Map<EigenMatrixPaddle<T>>(
input_matrices[i].data<T>(), n_dim, n_dim);
new (&output_evp) Eigen::Map<EigenVectorPaddle<T>>(
output_vectors[i].data<PaddleCType<T>>(), n_dim);
input_ems = input_emp.template cast<StdCType<T>>();
output_evs = input_ems.eigenvalues();
output_evp = output_evs.template cast<PaddleCType<T>>();
// query workspace size
T qwork;
int info;
math::lapackEig<T, Real<T>>('N', 'N', static_cast<int>(n_dim),
input_matrices[0].template data<T>(),
static_cast<int>(n_dim), NULL, NULL, 1, NULL, 1,
&qwork, -1, static_cast<Real<T>*>(NULL), &info);
int64_t lwork = static_cast<int64_t>(qwork);
Tensor work, rwork;
try {
work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
} catch (memory::allocation::BadAlloc&) {
LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
<< "memory size = " << lwork * sizeof(T) << " bytes, "
<< "try reallocating a smaller workspace with the minimum "
<< "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
<< "this may lead to bad performance.";
lwork = 3 * n_dim;
work.mutable_data<T>(framework::make_ddim({lwork}), ctx.GetPlace());
}
if (framework::IsComplexType(input->type())) {
rwork.mutable_data<Real<T>>(framework::make_ddim({n_dim << 1}),
ctx.GetPlace());
}
for (int64_t i = 0; i < n_batch; ++i) {
LapackEigvals<DeviceContext, T>(ctx, input_matrices[i],
&output_vectors[i], &work, &rwork);
}
output->Resize(output_dims);
}
......
......@@ -313,6 +313,29 @@ struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
int64_t numel_;
};
template <typename T, typename Enable = void>
struct RealImagToComplexFunctor;
template <typename T>
struct RealImagToComplexFunctor<T, Complex<T, Real<T>>> {
RealImagToComplexFunctor(const Real<T>* input_real, const Real<T>* input_imag,
T* output, int64_t numel)
: input_real_(input_real),
input_imag_(input_imag),
output_(output),
numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
output_[idx].real = input_real_[idx];
output_[idx].imag = input_imag_[idx];
}
const Real<T>* input_real_;
const Real<T>* input_imag_;
T* output_;
int64_t numel_;
};
template <typename T, typename Enable = void>
struct ConjFunctor;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/dynload/lapack.h"
namespace paddle {
......@@ -30,6 +31,57 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
}
// Eig
template <>
void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda,
double *w, double *vl, int ldvl, double *vr, int ldvr,
double *work, int lwork, double *rwork, int *info) {
double *wr = w;
double *wi = w + n;
(void)rwork; // unused
platform::dynload::dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr,
&ldvr, work, &lwork, info);
}
template <>
void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda,
float *w, float *vl, int ldvl, float *vr, int ldvr,
float *work, int lwork, float *rwork, int *info) {
float *wr = w;
float *wi = w + n;
(void)rwork; // unused
platform::dynload::sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr,
&ldvr, work, &lwork, info);
}
template <>
void lapackEig<platform::complex<double>, double>(
char jobvl, char jobvr, int n, platform::complex<double> *a, int lda,
platform::complex<double> *w, platform::complex<double> *vl, int ldvl,
platform::complex<double> *vr, int ldvr, platform::complex<double> *work,
int lwork, double *rwork, int *info) {
platform::dynload::zgeev_(
&jobvl, &jobvr, &n, reinterpret_cast<std::complex<double> *>(a), &lda,
reinterpret_cast<std::complex<double> *>(w),
reinterpret_cast<std::complex<double> *>(vl), &ldvl,
reinterpret_cast<std::complex<double> *>(vr), &ldvr,
reinterpret_cast<std::complex<double> *>(work), &lwork, rwork, info);
}
template <>
void lapackEig<platform::complex<float>, float>(
char jobvl, char jobvr, int n, platform::complex<float> *a, int lda,
platform::complex<float> *w, platform::complex<float> *vl, int ldvl,
platform::complex<float> *vr, int ldvr, platform::complex<float> *work,
int lwork, float *rwork, int *info) {
platform::dynload::cgeev_(
&jobvl, &jobvr, &n, reinterpret_cast<std::complex<float> *>(a), &lda,
reinterpret_cast<std::complex<float> *>(w),
reinterpret_cast<std::complex<float> *>(vl), &ldvl,
reinterpret_cast<std::complex<float> *>(vr), &ldvr,
reinterpret_cast<std::complex<float> *>(work), &lwork, rwork, info);
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -22,6 +22,11 @@ namespace math {
template <typename T>
void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);
template <typename T1, typename T2 = T1>
void lapackEig(char jobvl, char jobvr, int n, T1 *a, int lda, T1 *w, T1 *vl,
int ldvl, T1 *vr, int ldvr, T1 *work, int lwork, T2 *rwork,
int *info);
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <complex>
#include <mutex>
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
......@@ -27,6 +28,27 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv,
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv,
int *info);
// geev
extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda,
double *wr, double *wi, double *vl, int *ldvl,
double *vr, int *ldvr, double *work, int *lwork,
int *info);
extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda,
float *wr, float *wi, float *vl, int *ldvl, float *vr,
int *ldvr, float *work, int *lwork, int *info);
extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
std::complex<double> *a, int *lda,
std::complex<double> *w, std::complex<double> *vl,
int *ldvl, std::complex<double> *vr, int *ldvr,
std::complex<double> *work, int *lwork, double *rwork,
int *info);
extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, std::complex<float> *a,
int *lda, std::complex<float> *w,
std::complex<float> *vl, int *ldvl,
std::complex<float> *vr, int *ldvr,
std::complex<float> *work, int *lwork, float *rwork,
int *info);
namespace paddle {
namespace platform {
namespace dynload {
......@@ -58,7 +80,11 @@ extern void *lapack_dso_handle;
#define LAPACK_ROUTINE_EACH(__macro) \
__macro(dgetrf_); \
__macro(sgetrf_);
__macro(sgetrf_); \
__macro(dgeev_); \
__macro(sgeev_); \
__macro(zgeev_); \
__macro(cgeev_);
LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册