未验证 提交 bafd8dec 编写于 作者: X xiongkun 提交者: GitHub

change svd_cpu_kernel from Eigen to Lapack, speed up the compile from 120s -> 20s (#43784)

上级 23036031
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
...@@ -44,40 +45,42 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -44,40 +45,42 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T> template <typename T>
void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, void LapackSvd(const T* X, T* U, T* VH, T* S, int rows, int cols,
int full = false) { int full = false) {
auto flag = Eigen::DecompositionOptions::ComputeThinU | char jobz = full ? 'A' : 'S';
Eigen::DecompositionOptions::ComputeThinV; int mx = std::max(rows, cols);
if (full) { int mn = std::min(rows, cols);
flag = Eigen::DecompositionOptions::ComputeFullU | T* a = const_cast<T*>(X);
Eigen::DecompositionOptions::ComputeFullV; int lda = rows;
} int ldu = rows;
Eigen::BDCSVD< int ldvt = full ? cols : mn;
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> int lwork = full ? (4 * mn * mn + 6 * mn + mx) : (4 * mn * mn + 7 * mn);
svd(2, 2, flag); std::vector<T> work(lwork);
/*NOTE(xiongkun03) Eigen::Matrix API need non-const pointer.*/ std::vector<int> iwork(8 * mn);
T* input = const_cast<T*>(X); int info;
auto m = Eigen::Map< phi::funcs::lapackSvd<T>(jobz, rows, cols, a, lda, S, U, ldu, VH, ldvt,
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( work.data(), lwork, iwork.data(), &info);
input, rows, cols); if (info < 0) {
svd.compute(m); PADDLE_THROW(platform::errors::InvalidArgument(
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> V_trans = "This %s-th argument has an illegal value", info));
svd.matrixV().transpose(); }
memcpy(U, svd.matrixU().data(), svd.matrixU().size() * sizeof(T)); if (info > 0) {
memcpy(VH, V_trans.data(), V_trans.size() * sizeof(T)); PADDLE_THROW(platform::errors::InvalidArgument(
memcpy(S, svd.singularValues().data(), "DBDSDC/SBDSDC did not converge, updating process failed. May be you "
svd.singularValues().size() * sizeof(T)); "passes a invalid matrix."));
}
} }
template <typename T> template <typename T>
void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches, void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches,
int full = false) { int full = false) {
// NOTE: this function is row major, because this function called the lapack.
int stride = rows * cols; int stride = rows * cols;
int k = std::min(rows, cols); int k = std::min(rows, cols);
int stride_u = full ? rows * rows : k * rows; int stride_u = full ? rows * rows : k * rows;
int stride_v = full ? cols * cols : k * cols; int stride_v = full ? cols * cols : k * cols;
for (int i = 0; i < batches; ++i) { for (int i = 0; i < batches; ++i) {
EigenSvd<T>(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k, LapackSvd<T>(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k,
rows, cols, full); rows, cols, full);
} }
return; return;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -39,7 +40,12 @@ class SvdCPUKernel : public framework::OpKernel<T> { ...@@ -39,7 +40,12 @@ class SvdCPUKernel : public framework::OpKernel<T> {
/*Create Tensors and output, set the dim ...*/ /*Create Tensors and output, set the dim ...*/
auto numel = x->numel(); auto numel = x->numel();
auto* x_data = x->data<T>(); auto& orig_dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto& dev_ctx = static_cast<const typename framework::ConvertToPhiContext<
platform::CPUDeviceContext>::TYPE&>(orig_dev_ctx);
Tensor trans_x = ::phi::TransposeLast2Dim<T>(dev_ctx, *x);
auto* x_data = trans_x.data<T>();
auto x_dims = x->dims(); auto x_dims = x->dims();
int rows = x_dims[x_dims.size() - 2]; int rows = x_dims[x_dims.size() - 2];
int cols = x_dims[x_dims.size() - 1]; int cols = x_dims[x_dims.size() - 1];
...@@ -57,6 +63,20 @@ class SvdCPUKernel : public framework::OpKernel<T> { ...@@ -57,6 +63,20 @@ class SvdCPUKernel : public framework::OpKernel<T> {
context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real<T>))); context.GetPlace(), size_t(batches * k * sizeof(phi::dtype::Real<T>)));
/*SVD Use the Eigen Library*/ /*SVD Use the Eigen Library*/
math::BatchSvd<T>(x_data, U_out, VH_out, S_out, rows, cols, batches, full); math::BatchSvd<T>(x_data, U_out, VH_out, S_out, rows, cols, batches, full);
/* let C[m, n] as a col major matrix with m rows and n cols.
* let R[m, n] is row major matrix with m rows and n cols.
* then we have: R[m,n] = C[m, n].resize((n,m)).tranpose_last_two()
* */
auto col_major_to_row_major = [&dev_ctx](Tensor* out) {
auto origin_dim = out->dims();
int64_t& x = origin_dim[origin_dim.size() - 1];
int64_t& y = origin_dim[origin_dim.size() - 2];
std::swap(x, y);
out->Resize(origin_dim);
return ::phi::TransposeLast2Dim<T>(dev_ctx, *out);
};
*U = col_major_to_row_major(U);
*VH = col_major_to_row_major(VH);
} }
}; };
......
...@@ -280,6 +280,34 @@ extern "C" void spotrs_(char *uplo, ...@@ -280,6 +280,34 @@ extern "C" void spotrs_(char *uplo,
float *b, float *b,
int *ldb, int *ldb,
int *info); int *info);
extern "C" void dgesdd_(char *,
int *,
int *,
double *,
int *,
double *,
double *,
int *,
double *,
int *,
double *,
int *,
int *,
int *);
extern "C" void sgesdd_(char *,
int *,
int *,
float *,
int *,
float *,
float *,
int *,
float *,
int *,
float *,
int *,
int *,
int *);
namespace phi { namespace phi {
namespace dynload { namespace dynload {
...@@ -328,6 +356,8 @@ extern void *lapack_dso_handle; ...@@ -328,6 +356,8 @@ extern void *lapack_dso_handle;
__macro(sgelsy_); \ __macro(sgelsy_); \
__macro(dgelss_); \ __macro(dgelss_); \
__macro(sgelss_); \ __macro(sgelss_); \
__macro(sgesdd_); \
__macro(dgesdd_); \
__macro(zpotrs_); \ __macro(zpotrs_); \
__macro(cpotrs_); \ __macro(cpotrs_); \
__macro(dpotrs_); \ __macro(dpotrs_); \
......
...@@ -499,5 +499,43 @@ void lapackCholeskySolve<float>(char uplo, ...@@ -499,5 +499,43 @@ void lapackCholeskySolve<float>(char uplo,
dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
} }
template <>
void lapackSvd<double>(char jobz,
int m,
int n,
double *a,
int lda,
double *s,
double *u,
int ldu,
double *vt,
int ldvt,
double *work,
int lwork,
int *iwork,
int *info) {
dynload::dgesdd_(
&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
}
template <>
void lapackSvd<float>(char jobz,
int m,
int n,
float *a,
int lda,
float *s,
float *u,
int ldu,
float *vt,
int ldvt,
float *work,
int lwork,
int *iwork,
int *info) {
dynload::sgesdd_(
&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
}
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -120,6 +120,22 @@ void lapackGelss(int m, ...@@ -120,6 +120,22 @@ void lapackGelss(int m,
T2 *rwork, T2 *rwork,
int *info); int *info);
template <typename T>
void lapackSvd(char jobz,
int m,
int n,
T *a,
int lda,
T *s,
T *u,
int ldu,
T *vt,
int ldvt,
T *work,
int lwork,
int *iwork,
int *info);
template <typename T> template <typename T>
void lapackCholeskySolve( void lapackCholeskySolve(
char uplo, int n, int nrhs, T *a, int lda, T *b, int ldb, int *info); char uplo, int n, int nrhs, T *a, int lda, T *b, int ldb, int *info);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册