提交 ce0f2bfd 编写于 作者: H hjchen2

Optimize gemm data package, it will bring 22% speedup for ocr detection model

上级 c070770c
此差异已折叠。
...@@ -46,37 +46,25 @@ namespace math { ...@@ -46,37 +46,25 @@ namespace math {
class Gemm { class Gemm {
public: public:
typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *); typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *,
const bool);
typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *, typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
int); int);
FnPack procPackA; FnPack procPackA;
FnPack procPackB; FnPack procPackB;
FnAddDot procAddDot; FnAddDot procAddDot;
// 将 A\B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer, const bool parallel);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer, const bool parallel);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer, const bool parallel);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
#if __aarch64__ #if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer, const bool parallel);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer, const bool parallel);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
#endif #endif
// 分块矩阵乘法 // 分块矩阵乘法
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册