未验证 提交 82cd8d21 编写于 作者: W WangZhen 提交者: GitHub

Speed up matrix_rank_tol_kernel.cc compile time (#43856)

上级 6d436f6e
...@@ -14,67 +14,66 @@ ...@@ -14,67 +14,66 @@
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h" #include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include <Eigen/Dense>
#include <Eigen/SVD>
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h" #include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/compare_functors.h" #include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h" #include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
#include "paddle/phi/kernels/reduce_max_kernel.h" #include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi { namespace phi {
template <typename T> template <typename T>
void BatchEigenvalues(const T* x_data, void LapackSVD(const T* x_data, T* eigenvalues_data, int rows, int cols) {
T* eigenvalues_data, char jobz = 'N';
int batches, int mx = std::max(rows, cols);
int rows, int mn = std::min(rows, cols);
int cols, T* a = const_cast<T*>(x_data);
int k) { int lda = rows;
// Eigen::Matrix API need non-const pointer. int lwork = 3 * mn + std::max(mx, 7 * mn);
T* input = const_cast<T*>(x_data); std::vector<T> work(lwork);
int stride = rows * cols; std::vector<int> iwork(8 * mn);
for (int i = 0; i < batches; i++) { int info;
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( phi::funcs::lapackSvd<T>(jobz,
input + i * stride, rows, rows); rows,
Eigen::SelfAdjointEigenSolver< cols,
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> a,
eigen_solver(m); lda,
auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs(); eigenvalues_data,
for (int j = 0; j < k; j++) { nullptr,
*(eigenvalues_data + i * k + j) = eigenvalues[j]; 1,
nullptr,
1,
work.data(),
lwork,
iwork.data(),
&info);
if (info < 0) {
PADDLE_THROW(phi::errors::InvalidArgument(
"This %s-th argument has an illegal value", info));
} }
if (info > 0) {
PADDLE_THROW(phi::errors::InvalidArgument(
"DBDSDC/SBDSDC did not converge, updating process failed. May be you "
"passes a invalid matrix."));
} }
} }
template <typename T> template <typename T>
void BatchSVD(const T* x_data, void BatchSVD(
T* eigenvalues_data, const T* x_data, T* eigenvalues_data, int batches, int rows, int cols) {
int batches,
int rows,
int cols,
int k) {
// Eigen::Matrix API need non-const pointer.
T* input = const_cast<T*>(x_data);
int stride = rows * cols; int stride = rows * cols;
Eigen::BDCSVD< int k = std::min(rows, cols);
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> for (int i = 0; i < batches; ++i) {
svd; LapackSVD<T>(x_data + i * stride, eigenvalues_data + i * k, rows, cols);
for (int i = 0; i < batches; i++) {
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
input + i * stride, rows, cols);
svd.compute(m);
auto res_s = svd.singularValues();
for (int j = 0; j < k; j++) {
eigenvalues_data[i * k + j] = res_s[j];
}
} }
} }
...@@ -85,7 +84,6 @@ void MatrixRankTolKernel(const Context& dev_ctx, ...@@ -85,7 +84,6 @@ void MatrixRankTolKernel(const Context& dev_ctx,
bool use_default_tol, bool use_default_tol,
bool hermitian, bool hermitian,
DenseTensor* out) { DenseTensor* out) {
auto* x_data = x.data<T>();
dev_ctx.template Alloc<int64_t>(out); dev_ctx.template Alloc<int64_t>(out);
auto dim_x = x.dims(); auto dim_x = x.dims();
auto dim_out = out->dims(); auto dim_out = out->dims();
...@@ -106,9 +104,13 @@ void MatrixRankTolKernel(const Context& dev_ctx, ...@@ -106,9 +104,13 @@ void MatrixRankTolKernel(const Context& dev_ctx,
auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor); auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
if (hermitian) { if (hermitian) {
BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k); phi::funcs::MatrixEighFunctor<Context, T> functor;
functor(dev_ctx, x, &eigenvalue_tensor, nullptr, true, false);
phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
} else { } else {
BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k); DenseTensor trans_x = phi::TransposeLast2Dim<T>(dev_ctx, x);
auto* x_data = trans_x.data<T>();
BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols);
} }
DenseTensor max_eigenvalue_tensor; DenseTensor max_eigenvalue_tensor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册