Unverified commit 82cd8d21, authored by W WangZhen, committed by GitHub

Speed up matrix_rank_tol_kernel.cc compile time (#43856)

Parent 6d436f6e
@@ -14,67 +14,66 @@
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include <Eigen/Dense>
#include <Eigen/SVD>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
 namespace phi {
 
 template <typename T>
-void BatchEigenvalues(const T* x_data,
-                      T* eigenvalues_data,
-                      int batches,
-                      int rows,
-                      int cols,
-                      int k) {
-  // Eigen::Matrix API need non-const pointer.
-  T* input = const_cast<T*>(x_data);
-  int stride = rows * cols;
-  for (int i = 0; i < batches; i++) {
-    auto m = Eigen::Map<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
-        input + i * stride, rows, rows);
-    Eigen::SelfAdjointEigenSolver<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-        eigen_solver(m);
-    auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs();
-    for (int j = 0; j < k; j++) {
-      *(eigenvalues_data + i * k + j) = eigenvalues[j];
-    }
-  }
-}
+void LapackSVD(const T* x_data, T* eigenvalues_data, int rows, int cols) {
+  char jobz = 'N';
+  int mx = std::max(rows, cols);
+  int mn = std::min(rows, cols);
+  T* a = const_cast<T*>(x_data);
+  int lda = rows;
+  int lwork = 3 * mn + std::max(mx, 7 * mn);
+  std::vector<T> work(lwork);
+  std::vector<int> iwork(8 * mn);
+  int info;
+  phi::funcs::lapackSvd<T>(jobz,
+                           rows,
+                           cols,
+                           a,
+                           lda,
+                           eigenvalues_data,
+                           nullptr,
+                           1,
+                           nullptr,
+                           1,
+                           work.data(),
+                           lwork,
+                           iwork.data(),
+                           &info);
+  if (info < 0) {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "This %s-th argument has an illegal value", info));
+  }
+  if (info > 0) {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "DBDSDC/SBDSDC did not converge, updating process failed. May be you "
+        "passes a invalid matrix."));
+  }
+}
 
 template <typename T>
-void BatchSVD(const T* x_data,
-              T* eigenvalues_data,
-              int batches,
-              int rows,
-              int cols,
-              int k) {
-  // Eigen::Matrix API need non-const pointer.
-  T* input = const_cast<T*>(x_data);
+void BatchSVD(
+    const T* x_data, T* eigenvalues_data, int batches, int rows, int cols) {
   int stride = rows * cols;
-  Eigen::BDCSVD<
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-      svd;
-  for (int i = 0; i < batches; i++) {
-    auto m = Eigen::Map<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
-        input + i * stride, rows, cols);
-    svd.compute(m);
-    auto res_s = svd.singularValues();
-    for (int j = 0; j < k; j++) {
-      eigenvalues_data[i * k + j] = res_s[j];
-    }
+  int k = std::min(rows, cols);
+  for (int i = 0; i < batches; ++i) {
+    LapackSVD<T>(x_data + i * stride, eigenvalues_data + i * k, rows, cols);
   }
 }
@@ -85,7 +84,6 @@ void MatrixRankTolKernel(const Context& dev_ctx,
                          bool use_default_tol,
                          bool hermitian,
                          DenseTensor* out) {
-  auto* x_data = x.data<T>();
   dev_ctx.template Alloc<int64_t>(out);
   auto dim_x = x.dims();
   auto dim_out = out->dims();
@@ -106,9 +104,13 @@ void MatrixRankTolKernel(const Context& dev_ctx,
   auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
 
   if (hermitian) {
-    BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
+    phi::funcs::MatrixEighFunctor<Context, T> functor;
+    functor(dev_ctx, x, &eigenvalue_tensor, nullptr, true, false);
+    phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
   } else {
-    BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
+    DenseTensor trans_x = phi::TransposeLast2Dim<T>(dev_ctx, x);
+    auto* x_data = trans_x.data<T>();
+    BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols);
   }
 
   DenseTensor max_eigenvalue_tensor;
......
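For readers unfamiliar with the LAPACK entry point the new code wraps: LapackSVD calls phi::funcs::lapackSvd with jobz = 'N', i.e. the ?gesdd divide-and-conquer SVD driver computing singular values only, and the non-hermitian branch transposes the last two dimensions first, presumably so the row-major tensor lines up with LAPACK's column-major layout. The standalone sketch below shows the same kind of call through the plain LAPACKE C interface; it is illustrative only, assumes a LAPACKE installation, and is not part of this patch.

// Minimal sketch (not part of the patch): singular values only via ?gesdd,
// mirroring LapackSVD's jobz = 'N' configuration.
// Build (assumption): g++ svd_values_only.cc -llapacke -llapack -lblas
#include <lapacke.h>

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // A 2x3 matrix stored column-major, the layout LAPACK expects (the patch
  // reaches this layout by transposing the last two dims before BatchSVD).
  int m = 2, n = 3;
  std::vector<double> a = {3.0, 0.0,   // column 0
                           0.0, 2.0,   // column 1
                           0.0, 0.0};  // column 2
  std::vector<double> s(std::min(m, n));
  std::vector<double> u(1), vt(1);  // not referenced when jobz == 'N'

  // jobz = 'N': compute singular values only, no U or VT.
  int info = LAPACKE_dgesdd(LAPACK_COL_MAJOR, 'N', m, n, a.data(), /*lda=*/m,
                            s.data(), u.data(), 1, vt.data(), 1);
  if (info != 0) {
    std::printf("dgesdd failed, info = %d\n", info);
    return 1;
  }
  for (double v : s) {
    std::printf("%g\n", v);  // descending order: prints 3 then 2
  }
  return 0;
}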