From cb5e15b9e0fe3ad446dd3fcc91eb2c376acb9b17 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Tue, 5 Mar 2019 12:02:00 +0800 Subject: [PATCH] Optimize gemm performance for m > n --- src/operators/math/gemm/cblas.cc | 6 +- src/operators/math/gemm/executor.h | 138 +++++++++++++++++++-------- src/operators/math/math_function.cpp | 4 +- 3 files changed, 102 insertions(+), 46 deletions(-) diff --git a/src/operators/math/gemm/cblas.cc b/src/operators/math/gemm/cblas.cc index 5fb9e290b3..ccca4d7681 100644 --- a/src/operators/math/gemm/cblas.cc +++ b/src/operators/math/gemm/cblas.cc @@ -27,9 +27,9 @@ void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, const int K, const float alpha, const float *A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc) { - if (N == 1) { - return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); - } + // if (N == 1) { + // return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); + // } CPUInfo *info = CPUInfo::Info(); GemmExecutor exec(info, transA, transB, M, N, K); diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h index cf7c5687c8..b6ca66eb7e 100644 --- a/src/operators/math/gemm/executor.h +++ b/src/operators/math/gemm/executor.h @@ -29,7 +29,19 @@ namespace paddle_mobile { namespace operators { namespace math { -inline int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; } +int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; } +unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num, + const int N, const int K) { + unsigned int L1 = L1_size; + if (thread_num == 1) { + if (N >= 30000 && K > 100) { + L1 *= 4; + } else if (N >= 10000 && K > 100) { + L1 *= 2; + } + } + return L1; +} class Executor { public: @@ -59,29 +71,34 @@ class GemmExecutor : public Executor { M_(M), N_(N), K_(K) { - unsigned int L1_size = info->L1_cache; - unsigned int L2_size = info->L2_cache; - if (N_ > 30000 && K_ > 100) L1_size *= 2; - if (num_threads_ >= 2) L1_size /= 2; + unsigned int L1_size = 0; + unsigned int L2_size = 0; + if (M_ > N_) { + L2_size = ResetL1Cache(info->L1_cache, num_threads_, M_, K_); + L1_size = info->L2_cache; + } else { + L1_size = ResetL1Cache(info->L1_cache, num_threads_, N_, K_); + L2_size = info->L2_cache; + } - rhs_tile_num_ = L1_size / (K * sizeof(Itype)); + rhs_tile_num_ = L1_size / (K_ * sizeof(Itype)); if (rhs_tile_num_ == 0) { rhs_tile_num_ = Strategy::out_width(); } else { - int n_block = CeilDiv(N, rhs_tile_num_); - rhs_tile_num_ = CeilDiv(N, n_block); + int n_block = CeilDiv(N_, rhs_tile_num_); + rhs_tile_num_ = CeilDiv(N_, n_block); rhs_tile_num_ = CeilDiv(rhs_tile_num_, Strategy::out_width()); rhs_tile_num_ *= Strategy::out_width(); } // lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) * // Strategy::out_height(); - lhs_tile_num_ = L2_size / (K * sizeof(Itype)); + lhs_tile_num_ = L2_size / (K_ * sizeof(Itype)); if (lhs_tile_num_ == 0) { lhs_tile_num_ = Strategy::out_height(); } else { - int m_block = CeilDiv(M, lhs_tile_num_); - lhs_tile_num_ = CeilDiv(M, m_block); + int m_block = CeilDiv(M_, lhs_tile_num_); + lhs_tile_num_ = CeilDiv(M_, m_block); lhs_tile_num_ = CeilDiv(lhs_tile_num_, Strategy::out_height()); lhs_tile_num_ *= Strategy::out_height(); } @@ -92,11 +109,19 @@ class GemmExecutor : public Executor { const int ldc) { // struct timeval tv_begin, tv_end; // gettimeofday(&tv_begin,NULL); - - int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); - lhs_worksize_ = sizeof(Itype) * mblock * K_; - rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_; - out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_; + if (M_ > N_) { + int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width(); + lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_; + rhs_worksize_ = sizeof(Itype) * K_ * nblock * num_threads_; + out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_; + ldc_ = nblock; + } else { + int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); + lhs_worksize_ = sizeof(Itype) * mblock * K_; + rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_; + out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_; + ldc_ = rhs_tile_num_; + } lhs_workspace_ = static_cast(paddle_mobile::memory::Alloc(lhs_worksize_)); @@ -105,41 +130,71 @@ class GemmExecutor : public Executor { out_workspace_ = static_cast(paddle_mobile::memory::Alloc(out_worksize_)); - strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true); + // std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ << std::endl; + // std::cout << "lhs_block: " << CeilDiv(M_, lhs_tile_num_) << ", " + // << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) << std::endl; - // std::cout << "M: " << M_ << ", N: " << N_ - // << ", K: " << K_ << std::endl; - // std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) - // << std::endl; + if (M_ > N_) { + strategy_.pack_rhs(K_, N_, B, ldb, rhs_workspace_, true); - #pragma omp parallel for if (N_ > 128) - for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { - int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); + #pragma omp parallel for if (M_ > 128) + for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { + int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); #ifdef _OPENMP - int thread_id = omp_get_thread_num(); + int thread_id = omp_get_thread_num(); #else - int thread_id = 0; + int thread_id = 0; #endif - float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id; - float *local_C = - out_workspace_ + lhs_tile_num_ * rhs_tile_num_ * thread_id; - // load rhs into rhs_workspace - strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false); - for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { - int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); - float *local_A = lhs_workspace_ + lhs_block * lda; - for (int lhs_tile = 0; lhs_tile < lhs_range; - lhs_tile += Strategy::out_height()) { + float *local_A = lhs_workspace_ + lhs_tile_num_ * K_ * thread_id; + float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id; + // load lhs into lhs_workspace + strategy_.pack_lhs(lhs_range, K_, A + lhs_block * lda, lda, local_A, + false); + for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { + int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); + float *local_B = rhs_workspace_ + K_ * rhs_block; for (int rhs_tile = 0; rhs_tile < rhs_range; rhs_tile += Strategy::out_width()) { - int offset = (lhs_block + lhs_tile) * rhs_tile_num_ + rhs_tile; - strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_, - K_, local_C + offset, rhs_tile_num_); + for (int lhs_tile = 0; lhs_tile < lhs_range; + lhs_tile += Strategy::out_height()) { + int offset = lhs_tile * ldc_ + rhs_block + rhs_tile; + strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_, + K_, local_C + offset, ldc_); + } + } + } + strategy_.write(lhs_range, N_, local_C, ldc_, C + lhs_block * ldc, ldc); + } + } else { + strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true); + + #pragma omp parallel for if (N_ > 128) + for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { + int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); +#ifdef _OPENMP + int thread_id = omp_get_thread_num(); +#else + int thread_id = 0; +#endif + float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id; + float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id; + // load rhs into rhs_workspace + strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false); + for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { + int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); + float *local_A = lhs_workspace_ + lhs_block * K_; + for (int lhs_tile = 0; lhs_tile < lhs_range; + lhs_tile += Strategy::out_height()) { + for (int rhs_tile = 0; rhs_tile < rhs_range; + rhs_tile += Strategy::out_width()) { + int offset = (lhs_block + lhs_tile) * ldc_ + rhs_tile; + strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_, + K_, local_C + offset, ldc_); + } } } + strategy_.write(M_, rhs_range, local_C, ldc_, C + rhs_block, ldc); } - strategy_.write(M_, rhs_range, local_C, rhs_tile_num_, C + rhs_block, - ldc); } paddle_mobile::memory::Free(lhs_workspace_); @@ -172,6 +227,7 @@ class GemmExecutor : public Executor { unsigned int lhs_worksize_ = 0; unsigned int rhs_worksize_ = 0; unsigned int out_worksize_ = 0; + unsigned int ldc_ = 0; Itype *lhs_workspace_ = nullptr; Itype *rhs_workspace_ = nullptr; diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index b576963cc4..6cd0514832 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -71,7 +71,7 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a, a[index++] = tmp[i * n + j]; } } - if (M > N || M == 1) { + if (M == 1) { #ifdef _OPENMP gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, matrix_out->data(), N, relu, bias); @@ -84,7 +84,7 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a, beta, matrix_out->data(), N); } } else { - if (M > N || M == 1) { + if (M == 1) { #ifdef _OPENMP gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, matrix_b.data(), N, beta, matrix_out->data(), -- GitLab