提交 52b8d77f 编写于 作者: Y yongqiang

change sgemm.cc

上级 992f29cf
......@@ -323,7 +323,7 @@ void sgemm_prepack(bool is_transB,
(has_act == true && act_type == lite_api::ActivationType::kRelu);
bool has_beta = fabsf(beta) > 1e-8f ? true : false;
bool a53_sgemm = act_flag && !has_beta;
if (a53_sgemm) {//无act 无 beta
if (a53_sgemm) {
sgemm_prepacked_6x8_a53(is_transB,
M,
N,
......@@ -2368,19 +2368,16 @@ void sgemm_prepacked_8x12(bool is_transB,
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK));
x_block /= NBLOCK;
x_block *= NBLOCK;//一次可以放多少个B的列 为NBLOCK的整数倍
int x_num = (N + (x_block - 1)) / x_block; //可以分x_num 进行计算, 一次放x_block列,可以分x_num计算完成。
LOG(INFO) << "x_block:"<<x_block<<" "<<"x_num"<<x_num;
x_block = (N + x_num - 1) / x_num; //分x_num次计算完成的话,每次需要计算多少个列 (N方向),算一个平均值,因为最后一次可能会非常少,
//计算出x_num次之后,再求一下x_num次读取情况下,每次读取次数的平均值
//这样如果MP的时候,如果包含x_block个数比较少的情况下,可以使各线程耗时更加平均
x_block = (x_block + NBLOCK - 1) / NBLOCK;//算出每次NBLOCK的个数
x_block *= NBLOCK;//计算 一次做loadb 总的列数
x_block = x_block < NBLOCK ? NBLOCK : x_block;//如果不够NBLOCK,按NBLOCK来计算。
LOG(INFO) << "x_block:"<<x_block;
x_block *= NBLOCK;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + NBLOCK - 1) / NBLOCK;
x_block *= NBLOCK;
x_block = x_block < NBLOCK ? NBLOCK : x_block;
// unroll 2 loop
int tail_pre = (K & (KBLOCK - 1));//K方向 KBLOCK的余数
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;//K方向 整数倍KBLOCK的个数
int tail_pre = (K & (KBLOCK - 1));
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
bool flag_p_remain = false;
int remain = 0;
......@@ -2393,8 +2390,8 @@ void sgemm_prepacked_8x12(bool is_transB,
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;//B 有多少个NBLOCK
remain = xmax - x0 - (bblocks - 1) * NBLOCK;//不够NBLOCK,的余数
int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
remain = xmax - x0 - (bblocks - 1) * NBLOCK;
if (remain > 0) {
flag_p_remain = true;
}
......@@ -2405,7 +2402,7 @@ void sgemm_prepacked_8x12(bool is_transB,
} else {
loadb(b_pannel, B, ldb, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)//在A的M方向,按照MBLOCK进行MP
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += MBLOCK) {
unsigned int ymax = y + MBLOCK;
if (ymax > M) {
......@@ -2424,7 +2421,7 @@ void sgemm_prepacked_8x12(bool is_transB,
bias_local[7] = bias[y + 7];
}
float cout0[NBLOCK];//C 输出 8*12
float cout0[NBLOCK];
float cout1[NBLOCK];
float cout2[NBLOCK];
float cout3[NBLOCK];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册