提交 cb5e15b9 编写于 作者: H hjchen2

Optimize gemm performance for m > n

上级 3dfb6c06
......@@ -27,9 +27,9 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N,
const int K, const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta, float *C,
const int ldc) {
if (N == 1) {
return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
}
// if (N == 1) {
// return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C);
// }
CPUInfo *info = CPUInfo::Info();
GemmExecutor<SgemmStrategy> exec(info, transA, transB, M, N, K);
......
......@@ -29,7 +29,19 @@ namespace paddle_mobile {
namespace operators {
namespace math {
inline int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; }
unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num,
const int N, const int K) {
unsigned int L1 = L1_size;
if (thread_num == 1) {
if (N >= 30000 && K > 100) {
L1 *= 4;
} else if (N >= 10000 && K > 100) {
L1 *= 2;
}
}
return L1;
}
class Executor {
public:
......@@ -59,29 +71,34 @@ class GemmExecutor : public Executor {
M_(M),
N_(N),
K_(K) {
unsigned int L1_size = info->L1_cache;
unsigned int L2_size = info->L2_cache;
if (N_ > 30000 && K_ > 100) L1_size *= 2;
if (num_threads_ >= 2) L1_size /= 2;
unsigned int L1_size = 0;
unsigned int L2_size = 0;
if (M_ > N_) {
L2_size = ResetL1Cache(info->L1_cache, num_threads_, M_, K_);
L1_size = info->L2_cache;
} else {
L1_size = ResetL1Cache(info->L1_cache, num_threads_, N_, K_);
L2_size = info->L2_cache;
}
rhs_tile_num_ = L1_size / (K * sizeof(Itype));
rhs_tile_num_ = L1_size / (K_ * sizeof(Itype));
if (rhs_tile_num_ == 0) {
rhs_tile_num_ = Strategy::out_width();
} else {
int n_block = CeilDiv(N, rhs_tile_num_);
rhs_tile_num_ = CeilDiv(N, n_block);
int n_block = CeilDiv(N_, rhs_tile_num_);
rhs_tile_num_ = CeilDiv(N_, n_block);
rhs_tile_num_ = CeilDiv(rhs_tile_num_, Strategy::out_width());
rhs_tile_num_ *= Strategy::out_width();
}
// lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) *
// Strategy::out_height();
lhs_tile_num_ = L2_size / (K * sizeof(Itype));
lhs_tile_num_ = L2_size / (K_ * sizeof(Itype));
if (lhs_tile_num_ == 0) {
lhs_tile_num_ = Strategy::out_height();
} else {
int m_block = CeilDiv(M, lhs_tile_num_);
lhs_tile_num_ = CeilDiv(M, m_block);
int m_block = CeilDiv(M_, lhs_tile_num_);
lhs_tile_num_ = CeilDiv(M_, m_block);
lhs_tile_num_ = CeilDiv(lhs_tile_num_, Strategy::out_height());
lhs_tile_num_ *= Strategy::out_height();
}
......@@ -92,11 +109,19 @@ class GemmExecutor : public Executor {
const int ldc) {
// struct timeval tv_begin, tv_end;
// gettimeofday(&tv_begin,NULL);
int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height();
lhs_worksize_ = sizeof(Itype) * mblock * K_;
rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_;
out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_;
if (M_ > N_) {
int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width();
lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_;
rhs_worksize_ = sizeof(Itype) * K_ * nblock * num_threads_;
out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_;
ldc_ = nblock;
} else {
int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height();
lhs_worksize_ = sizeof(Itype) * mblock * K_;
rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_;
out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_;
ldc_ = rhs_tile_num_;
}
lhs_workspace_ =
static_cast<Itype *>(paddle_mobile::memory::Alloc(lhs_worksize_));
......@@ -105,41 +130,71 @@ class GemmExecutor : public Executor {
out_workspace_ =
static_cast<Otype *>(paddle_mobile::memory::Alloc(out_worksize_));
strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
// std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ << std::endl;
// std::cout << "lhs_block: " << CeilDiv(M_, lhs_tile_num_) << ", "
// << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) << std::endl;
// std::cout << "M: " << M_ << ", N: " << N_
// << ", K: " << K_ << std::endl;
// std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_)
// << std::endl;
if (M_ > N_) {
strategy_.pack_rhs(K_, N_, B, ldb, rhs_workspace_, true);
#pragma omp parallel for if (N_ > 128)
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_);
#pragma omp parallel for if (M_ > 128)
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_);
#ifdef _OPENMP
int thread_id = omp_get_thread_num();
int thread_id = omp_get_thread_num();
#else
int thread_id = 0;
int thread_id = 0;
#endif
float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id;
float *local_C =
out_workspace_ + lhs_tile_num_ * rhs_tile_num_ * thread_id;
// load rhs into rhs_workspace
strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false);
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_);
float *local_A = lhs_workspace_ + lhs_block * lda;
for (int lhs_tile = 0; lhs_tile < lhs_range;
lhs_tile += Strategy::out_height()) {
float *local_A = lhs_workspace_ + lhs_tile_num_ * K_ * thread_id;
float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id;
// load lhs into lhs_workspace
strategy_.pack_lhs(lhs_range, K_, A + lhs_block * lda, lda, local_A,
false);
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_);
float *local_B = rhs_workspace_ + K_ * rhs_block;
for (int rhs_tile = 0; rhs_tile < rhs_range;
rhs_tile += Strategy::out_width()) {
int offset = (lhs_block + lhs_tile) * rhs_tile_num_ + rhs_tile;
strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_,
K_, local_C + offset, rhs_tile_num_);
for (int lhs_tile = 0; lhs_tile < lhs_range;
lhs_tile += Strategy::out_height()) {
int offset = lhs_tile * ldc_ + rhs_block + rhs_tile;
strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_,
K_, local_C + offset, ldc_);
}
}
}
strategy_.write(lhs_range, N_, local_C, ldc_, C + lhs_block * ldc, ldc);
}
} else {
strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
#pragma omp parallel for if (N_ > 128)
for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) {
int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_);
#ifdef _OPENMP
int thread_id = omp_get_thread_num();
#else
int thread_id = 0;
#endif
float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id;
float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id;
// load rhs into rhs_workspace
strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false);
for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) {
int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_);
float *local_A = lhs_workspace_ + lhs_block * K_;
for (int lhs_tile = 0; lhs_tile < lhs_range;
lhs_tile += Strategy::out_height()) {
for (int rhs_tile = 0; rhs_tile < rhs_range;
rhs_tile += Strategy::out_width()) {
int offset = (lhs_block + lhs_tile) * ldc_ + rhs_tile;
strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_,
K_, local_C + offset, ldc_);
}
}
}
strategy_.write(M_, rhs_range, local_C, ldc_, C + rhs_block, ldc);
}
strategy_.write(M_, rhs_range, local_C, rhs_tile_num_, C + rhs_block,
ldc);
}
paddle_mobile::memory::Free(lhs_workspace_);
......@@ -172,6 +227,7 @@ class GemmExecutor : public Executor {
unsigned int lhs_worksize_ = 0;
unsigned int rhs_worksize_ = 0;
unsigned int out_worksize_ = 0;
unsigned int ldc_ = 0;
Itype *lhs_workspace_ = nullptr;
Itype *rhs_workspace_ = nullptr;
......
......@@ -71,7 +71,7 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
a[index++] = tmp[i * n + j];
}
}
if (M > N || M == 1) {
if (M == 1) {
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
......@@ -84,7 +84,7 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
beta, matrix_out->data<float>(), N);
}
} else {
if (M > N || M == 1) {
if (M == 1) {
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册