未验证 提交 bafd8dec 编写于 作者: X xiongkun 提交者: GitHub

change svd_cpu_kernel from Eigen to Lapack, speed up the compile from 120s -> 20s (#43784)

上级 23036031
......@@ -30,6 +30,7 @@
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......@@ -44,40 +45,42 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// NOTE(review): this span appears to be a diff rendered without +/- markers.
// Two signatures follow (EigenSvd, then LapackSvd) and the body holds both an
// Eigen::BDCSVD implementation and a LAPACK-based one. Going by the commit
// title, the Eigen lines are presumably the removed side and the lapackSvd
// lines the added side -- confirm against the real file before changing code.
//
// Computes the SVD of one rows x cols matrix X, writing the factors to U, S
// and VH. `full` selects full factors instead of the reduced ones.
template <typename T>
void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols,
void LapackSvd(const T* X, T* U, T* VH, T* S, int rows, int cols,
int full = false) {
// ---- Eigen-based path (presumably the removed lines) ----
// Request thin U/V by default, full U/V when `full` is set.
auto flag = Eigen::DecompositionOptions::ComputeThinU |
Eigen::DecompositionOptions::ComputeThinV;
if (full) {
flag = Eigen::DecompositionOptions::ComputeFullU |
Eigen::DecompositionOptions::ComputeFullV;
}
Eigen::BDCSVD<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
svd(2, 2, flag);
/*NOTE(xiongkun03) Eigen::Matrix API need non-const pointer.*/
T* input = const_cast<T*>(X);
// View X as a row-major rows x cols Eigen matrix without copying.
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
input, rows, cols);
svd.compute(m);
// V is transposed into a row-major temporary before being copied out as VH.
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> V_trans =
svd.matrixV().transpose();
memcpy(U, svd.matrixU().data(), svd.matrixU().size() * sizeof(T));
memcpy(VH, V_trans.data(), V_trans.size() * sizeof(T));
memcpy(S, svd.singularValues().data(),
svd.singularValues().size() * sizeof(T));
// ---- LAPACK-based path (presumably the added lines) ----
// Job option forwarded to lapackSvd: 'A' when `full` is set, 'S' otherwise.
char jobz = full ? 'A' : 'S';
int mx = std::max(rows, cols);
int mn = std::min(rows, cols);
// lapackSvd takes a non-const T*, hence the cast.
// NOTE(review): LAPACK gesdd is documented to overwrite its input matrix, so
// X is presumably expected to be a scratch copy -- the CPU kernel in this
// change passes a freshly transposed tensor. Verify for any other caller.
T* a = const_cast<T*>(X);
// Leading dimensions; lda/ldu equal to `rows` is consistent with
// column-major storage of a rows x cols input -- TODO confirm.
int lda = rows;
int ldu = rows;
int ldvt = full ? cols : mn;
// Workspace length; NOTE(review): presumably the minimum LWORK from the
// LAPACK gesdd documentation for jobz 'A' / 'S' -- verify against the docs.
int lwork = full ? (4 * mn * mn + 6 * mn + mx) : (4 * mn * mn + 7 * mn);
std::vector<T> work(lwork);
std::vector<int> iwork(8 * mn);
int info;
phi::funcs::lapackSvd<T>(jobz, rows, cols, a, lda, S, U, ldu, VH, ldvt,
work.data(), lwork, iwork.data(), &info);
// Negative status: reported as an illegal-argument error.
// NOTE(review): `info` is negative in this branch, so the message prints a
// negative index; `-info` may have been intended -- confirm.
if (info < 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"This %s-th argument has an illegal value", info));
}
// Positive status: reported as a convergence failure.
if (info > 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"DBDSDC/SBDSDC did not converge, updating process failed. May be you "
"passes a invalid matrix."));
}
}
template <typename T>
void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches,
int full = false) {
// NOTE: this function works on column-major data, because it calls LAPACK.
int stride = rows * cols;
int k = std::min(rows, cols);
int stride_u = full ? rows * rows : k * rows;
int stride_v = full ? cols * cols : k * cols;
for (int i = 0; i < batches; ++i) {
EigenSvd<T>(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k,
LapackSvd<T>(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k,
rows, cols, full);
}
return;
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace paddle {
namespace operators {
......@@ -39,7 +40,12 @@ class SvdCPUKernel : public framework::OpKernel<T> {
/*Create Tensors and output, set the dim ...*/
auto numel = x->numel();
auto* x_data = x->data<T>();
auto& orig_dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto& dev_ctx = static_cast<const typename framework::ConvertToPhiContext<
platform::CPUDeviceContext>::TYPE&>(orig_dev_ctx);
Tensor trans_x = ::phi::TransposeLast2Dim<T>(dev_ctx, *x);
auto* x_data = trans_x.data<T>();
auto x_dims = x->dims();
int rows = x_dims[x_dims.size() - 2];
int cols = x_dims[x_dims.size() - 1];
......@@ -57,6 +63,20 @@ class SvdCPUKernel : public framework::OpKernel<T> {
context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real<T>)));
/*SVD is computed with LAPACK (through math::BatchSvd)*/
math::BatchSvd<T>(x_data, U_out, VH_out, S_out, rows, cols, batches, full);
/* Let C[m, n] be a column-major matrix with m rows and n cols, and
* let R[m, n] be a row-major matrix with m rows and n cols.
* Then we have: R[m, n] = C[m, n].resize((n, m)).transpose_last_two()
* */
auto col_major_to_row_major = [&dev_ctx](Tensor* out) {
auto origin_dim = out->dims();
int64_t& x = origin_dim[origin_dim.size() - 1];
int64_t& y = origin_dim[origin_dim.size() - 2];
std::swap(x, y);
out->Resize(origin_dim);
return ::phi::TransposeLast2Dim<T>(dev_ctx, *out);
};
*U = col_major_to_row_major(U);
*VH = col_major_to_row_major(VH);
}
};
......
......@@ -280,6 +280,34 @@ extern "C" void spotrs_(char *uplo,
float *b,
int *ldb,
int *info);
// Double-precision gesdd entry point, resolved at run time through the
// dynload macro list. Every argument is passed by address. The parameter
// names mirror the order in which lapackSvd<double> forwards its arguments.
extern "C" void dgesdd_(char *jobz,
                        int *m,
                        int *n,
                        double *a,
                        int *lda,
                        double *s,
                        double *u,
                        int *ldu,
                        double *vt,
                        int *ldvt,
                        double *work,
                        int *lwork,
                        int *iwork,
                        int *info);
// Single-precision gesdd entry point, resolved at run time through the
// dynload macro list. Every argument is passed by address. The parameter
// names mirror the order in which lapackSvd<float> forwards its arguments.
extern "C" void sgesdd_(char *jobz,
                        int *m,
                        int *n,
                        float *a,
                        int *lda,
                        float *s,
                        float *u,
                        int *ldu,
                        float *vt,
                        int *ldvt,
                        float *work,
                        int *lwork,
                        int *iwork,
                        int *info);
namespace phi {
namespace dynload {
......@@ -328,6 +356,8 @@ extern void *lapack_dso_handle;
__macro(sgelsy_); \
__macro(dgelss_); \
__macro(sgelss_); \
__macro(sgesdd_); \
__macro(dgesdd_); \
__macro(zpotrs_); \
__macro(cpotrs_); \
__macro(dpotrs_); \
......
......@@ -499,5 +499,43 @@ void lapackCholeskySolve<float>(char uplo,
dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
}
// Double-precision specialization of lapackSvd.
// Forwards every argument, in the same order, to dynload::dgesdd_. The
// by-value scalars (jobz, m, n, lda, ldu, ldvt, lwork) are passed by address
// because the dgesdd_ declaration takes pointers for all parameters.
// The caller reads the status back through `info` after the call.
template <>
void lapackSvd<double>(char jobz,
int m,
int n,
double *a,
int lda,
double *s,
double *u,
int ldu,
double *vt,
int ldvt,
double *work,
int lwork,
int *iwork,
int *info) {
dynload::dgesdd_(
&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
}
// Single-precision specialization of lapackSvd.
// Forwards every argument, in the same order, to dynload::sgesdd_. The
// by-value scalars (jobz, m, n, lda, ldu, ldvt, lwork) are passed by address
// because the sgesdd_ declaration takes pointers for all parameters.
// The caller reads the status back through `info` after the call.
template <>
void lapackSvd<float>(char jobz,
int m,
int n,
float *a,
int lda,
float *s,
float *u,
int ldu,
float *vt,
int ldvt,
float *work,
int lwork,
int *iwork,
int *info) {
dynload::sgesdd_(
&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
}
} // namespace funcs
} // namespace phi
......@@ -120,6 +120,22 @@ void lapackGelss(int m,
T2 *rwork,
int *info);
// SVD wrapper, specialized per element type (float and double
// specializations forward to sgesdd_ / dgesdd_ respectively).
//   jobz        job option character; LapackSvd passes 'A' for full factors
//               and 'S' otherwise
//   m, n        matrix dimensions; LapackSvd passes (rows, cols)
//   a, lda      matrix buffer and its leading dimension
//   s, u, vt    output buffers; LapackSvd passes its S, U and VH here
//   ldu, ldvt   leading dimensions of u and vt
//   work, lwork scratch buffer of element type T and its length
//   iwork       integer scratch buffer
//   info        status written by the call; LapackSvd throws when it is
//               non-zero
template <typename T>
void lapackSvd(char jobz,
int m,
int n,
T *a,
int lda,
T *s,
T *u,
int ldu,
T *vt,
int ldvt,
T *work,
int lwork,
int *iwork,
int *info);
// Cholesky-solve wrapper, specialized per element type. The float
// specialization forwards (uplo, n, nrhs, a, lda, b, ldb, info), in that
// order, to dynload::spotrs_.
template <typename T>
void lapackCholeskySolve(
char uplo, int n, int nrhs, T *a, int lda, T *b, int ldb, int *info);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册