Unverified commit 3b2d3189 authored by yiicy, committed by GitHub

[arm] improve sgemm performance on A53, test=develop (#3439)

improve sgemm performance on A53
Parent afefe9cf
@@ -72,6 +72,7 @@ void pack_trans_m4(float *out,
int mmax,
int k0,
int kmax);
void sgemm_prepacked_4x4(bool is_transB,
int M,
int N,
@@ -154,6 +155,20 @@ void sgemm_prepacked_4x8(bool is_transB,
bool has_bias,
const operators::ActivationParam act_param,
ARMContext *ctx);
// for kA53
void sgemm_prepacked_6x8_a53(bool is_transB,
int M,
int N,
int K,
const float *A_packed,
const float *B,
int ldb,
float *C,
int ldc,
const float *bias,
bool has_bias,
int is_relu,
ARMContext *ctx);
#endif  // __aarch64__
/**
@@ -300,6 +315,44 @@ void sgemm_prepack(bool is_transB,
has_bias,
act_param,
ctx);
} else if (ctx->arch() == kA53) {
auto act_type = act_param.active_type;
bool has_act = act_param.has_active;
bool act_flag =
!has_act || act_type == lite_api::ActivationType::kRelu;
bool has_beta = fabsf(beta) > 1e-8f;
bool a53_sgemm = act_flag && !has_beta;
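// the A53 kernel fuses at most a relu and does not handle beta, so any
// other activation, or a nonzero beta, falls back to the generic 6x8 kernel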
if (a53_sgemm) {
sgemm_prepacked_6x8_a53(is_transB,
M,
N,
K,
A_packed,
B,
ldb,
C,
ldc,
bias,
has_bias,
static_cast<int>(has_act),
ctx);
} else {
sgemm_prepacked_6x8(is_transB,
M,
N,
K,
A_packed,
B,
ldb,
beta,
C,
ldc,
bias,
has_bias,
act_param,
ctx);
}
} else {
sgemm_prepacked_6x8(is_transB,
M, M,
@@ -3983,6 +4036,472 @@ void sgemm_prepacked_6x8(bool is_transB,
}
}
/**
 * \brief gemm with ablock = 6, bblock = 8, output tile 6x8, optimized for the
 * Cortex-A53 architecture
 * @param is_transB whether B is transposed
 * @param M rows of C
 * @param N cols of C
 * @param K reduction depth
 * @param A_packed pre-packed panels of A
 * @param B matrix B, with leading dimension ldb
 * @param C output matrix, with leading dimension ldc
 * @param bias per-row bias, added when has_bias is true
 * @param is_relu fuse a relu on the output when nonzero
 * @param ctx ARM context providing threads and workspace
 */
void sgemm_prepacked_6x8_a53(bool is_transB,
int M,
int N,
int K,
const float* A_packed,
const float* B,
int ldb,
float* C,
int ldc,
const float* bias,
bool has_bias,
int is_relu,
ARMContext* ctx) {
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto* workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
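//! solve the line above for x, round x to a multiple of NBLOCK, and split
//! the N columns into x_num nearly equal blocks of at most x_block columns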
int x_block =
(l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH));
x_block /= NBLOCK;
x_block *= NBLOCK;
int x_num = (N + (x_block - 1)) / x_block;
x_block = (N + x_num - 1) / x_num;
x_block = (x_block + NBLOCK - 1) / NBLOCK;
x_block *= NBLOCK;
x_block = x_block < NBLOCK ? NBLOCK : x_block;
int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
int tail_pre = (K & (KBLOCK - 1));
if (tail_pre == 0) {
tail_pre = KBLOCK;
}
//! merge tail_pre and flag_act
tail_pre = (tail_pre << 2 | is_relu);
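//! bits [1:0] of tails carry the relu flag, the upper bits carry the k-tail
//! count in [1, KBLOCK]; the asm subtracts 4 per tail step and at the end
//! compares the residue with 1 to test relu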
bool flag_p_remain = false;
int remain = 0;
//! the A panel is pre-packed outside this gemm
for (unsigned int x0 = 0; x0 < N; x0 += x_block) {
unsigned int xmax = x0 + x_block;
if (xmax > N) {
xmax = N;
}
int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK;
remain = xmax - x0 - (bblocks - 1) * NBLOCK;
if (remain > 0) {
flag_p_remain = true;
}
//! load bpanel
auto b_pannel = static_cast<float*>(workspace);
if (is_transB) {
loadb_trans(b_pannel, B, ldb, 0, K, x0, xmax);
} else {
loadb(b_pannel, B, ldb, 0, K, x0, xmax);
}
#pragma omp parallel for num_threads(threads)
for (unsigned int y = 0; y < M; y += MBLOCK_OTH) {
unsigned int ymax = y + MBLOCK_OTH;
if (ymax > M) {
ymax = M;
}
float* c_ptr0 = C + y * ldc + x0;
float* c_ptr1 = c_ptr0 + ldc;
float* c_ptr2 = c_ptr1 + ldc;
float* c_ptr3 = c_ptr2 + ldc;
float* c_ptr4 = c_ptr3 + ldc;
float* c_ptr5 = c_ptr4 + ldc;
float* pout0 = c_ptr0;
float* pout1 = c_ptr1;
float* pout2 = c_ptr2;
float* pout3 = c_ptr3;
float* pout4 = c_ptr4;
float* pout5 = c_ptr5;
float bias_local[6] = {0};
if (has_bias) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
bias_local[4] = bias[y + 4];
bias_local[5] = bias[y + 5];
}
float cout0[NBLOCK];
float cout1[NBLOCK];
float cout2[NBLOCK];
float cout3[NBLOCK];
float cout4[NBLOCK];
float cout5[NBLOCK];
const float* a_ptr_l = A_packed + y * K;
const float* b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
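// when fewer than 6 rows remain, redirect the overflowing row pointers
// to the scratch tiles; the switch below falls through on purpose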
if ((y + 5) >= ymax) {
switch ((y + 5) - ymax) {
case 4:
c_ptr1 = cout1;
case 3:
c_ptr2 = cout2;
case 2:
c_ptr3 = cout3;
case 1:
c_ptr4 = cout4;
case 0:
c_ptr5 = cout5;
default:
break;
}
}
if (flag_p_remain && (xb == bblocks - 1)) {
pout0 = c_ptr0;
pout1 = c_ptr1;
pout2 = c_ptr2;
pout3 = c_ptr3;
pout4 = c_ptr4;
pout5 = c_ptr5;
c_ptr0 = cout0;
c_ptr1 = cout1;
c_ptr2 = cout2;
c_ptr3 = cout3;
c_ptr4 = cout4;
c_ptr5 = cout5;
}
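// a partial last block is computed at full 8-column width into the
// scratch tiles; only the first `remain` columns are copied back after
// the asm block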
const float* a_ptr = a_ptr_l;
int tails = tail_pre;
int k = k_pre;
// clang-format off
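// scheduling note: Cortex-A53 is a dual-issue in-order core, and wide
// vld1 loads tend to stall the NEON pipe; the kernel therefore splits
// panel loads into 64-bit vldr plus scalar ldr/vmov pairs interleaved
// with the vmla chain so loads can issue alongside multiply-accumulates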
asm volatile(
// sgemm 6x8 for a53
"vld1.32 {d2-d3}, [%[bias_ptr]] \n" /* load bias0-3 to d2,d3 */
"vdup.i32 q4, d2[0] \n" /* set out00 to bias0 */
"vld1.32 {d0-d1}, [%[a_ptr] :64] \n" /* load a00-a30 to d0,d1 */
"vdup.i32 q5, d2[0] \n" /* set out01 to bias0 */
"vld1.32 {d4-d5}, [%[b_ptr] :128] \n" /* load b00-b03 to d4,d5 */
"vdup.i32 q6, d2[1] \n" /* set out10 to bias1 */
"ldr r0, [%[a_ptr], #0x10] \n" /* load a40 to r0 */
"vdup.i32 q7, d2[1] \n" /* set out11 to bias1 */
"ldr r1, [%[a_ptr], #0x14] \n" /* load a50 to r1 */
"vdup.i32 q8, d3[0] \n" /* set out20 to bias2 */
"vldr d6, [%[bias_ptr], #0x10] \n" /* load bias 4,5 to d6 */
"pld [%[a_ptr], #0x40] \n" /* pre load apanel */
"vdup.i32 q9, d3[0] \n" /* set out21 to bias2 */
"pld [%[b_ptr], #0x40] \n" /* pre load bpanel */
"vdup.i32 q10, d3[1] \n" /* set out30 to bias3 */
"pld [%[a_ptr], #0x80] \n" /* pre load apanel */
"vdup.i32 q11, d3[1] \n" /* set out31 to bias3 */
"pld [%[b_ptr], #0x80] \n" /* pre load bpanel */
"vdup.i32 q12, d6[0] \n" /* set out40 to bias4 */
"vdup.i32 q13, d6[0] \n" /* set out41 to bias4 */
"pld [%[a_ptr], #0xC0] \n" /* pre load apanel */
"vdup.i32 q14, d6[1] \n" /* set out50 to bias5 */
"pld [%[b_ptr], #0XC0] \n" /* pre load bpanel */
"vdup.i32 q15, d6[1] \n" /* set out51 to bias5 */
"cmp %[k], #0 \n" /* check k loop */
"beq 6f \n" /* k==0, branch to 6 */
"1:\n"
/* Unroll 0 */
"vldr d6, [%[b_ptr], #0x10] \n" /* load b04, b05 to d6 */
"vmov d2, r0, r1 \n" /* mov a40, a50 to d2 */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"ldr r0, [%[b_ptr], #0x18] \n" /* load b06 to r0 */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"ldr r1, [%[b_ptr], #0x1C] \n" /* load b07 to r1 */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vldr d3, [%[a_ptr], #0x18] \n" /* load a01, a11 to d3 */
"vmov d7, r0, r1 \n" /* mov b06, b07 to d7 */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"pld [%[a_ptr], #0x100] \n" /* pre load apanel */
"vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */
"vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */
"vldr d4, [%[b_ptr], #0x20] \n" /* load b10, b11 to d4 */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"ldr r0, [%[b_ptr], #0x28] \n" /* load b12 to r0 */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"ldr r1, [%[b_ptr], #0x2C] \n" /* load b13 to r1 */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vldr d0, [%[a_ptr], #0x20] \n" /* load a21, a31 to d0 */
"vmov d5, r0, r1 \n" /* mov b12, b13 to d5 */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"ldr r0, [%[a_ptr], #0x28] \n" /* load a41 to r0 */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"ldr r1, [%[a_ptr], #0x2C] \n" /* load a51 to r1 */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
/* Unroll 1 */
"vldr d6, [%[b_ptr], #0x30] \n" /* load b14, b15 to d6 */
"vmov d1, r0, r1 \n" /* mov a41, a51 to d1 */
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"ldr r0, [%[b_ptr], #0x38] \n" /* load b16 to r0 */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"ldr r1, [%[b_ptr], #0x3C] \n" /* load b17 to r1 */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vldr d2, [%[a_ptr], #0x30] \n" /* load a02, a12 to d0 */
"vmov d7, r0, r1 \n" /* mov b16, b17 to d7 */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"pld [%[b_ptr], #0x100] \n" /* pre load apanel */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vldr d4, [%[b_ptr], #0x40] \n" /* load b20, b21 to d4 */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"ldr r0, [%[b_ptr], #0x48] \n" /* load b22 to r0 */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"ldr r1, [%[b_ptr], #0x4C] \n" /* load b23 to r1 */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vldr d3, [%[a_ptr], #0x38] \n" /* load a22, a32 to d3 */
"vmov d5, r0, r1 \n" /* mov b22, b23 to d5 */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"ldr r0, [%[a_ptr], #0x40] \n" /* load a42 to r0 */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"ldr r1, [%[a_ptr], #0x44] \n" /* load a52 to r1 */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
/* Unroll 2 */
"vldr d6, [%[b_ptr], #0x50] \n" /* load b24, b25 to d6 */
"vmov d0, r0, r1 \n" /* mov a42, a52 to d0 */
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"ldr r0, [%[b_ptr], #0x58] \n" /* load b26 to r0 */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"ldr r1, [%[b_ptr], #0x5C] \n" /* load b27 to r1 */
"vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */
"vldr d1, [%[a_ptr], #0x48] \n" /* load a03, a13 to d1 */
"vmov d7, r0, r1 \n" /* mov b26, b27 to d7 */
"vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */
"pld [%[a_ptr], #0x140] \n" /* pre load apanel */
"vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */
"vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */
"vldr d4, [%[b_ptr], #0x60] \n" /* load b30, b31 to d4 */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"ldr r0, [%[b_ptr], #0x68] \n" /* load b32 to r0 */
"vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */
"ldr r1, [%[b_ptr], #0x6C] \n" /* load b33 to r1 */
"vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */
"vldr d2, [%[a_ptr], #0x50] \n" /* load a23, a33 to d2 */
"vmov d5, r0, r1 \n" /* mov b32, b33 to d5 */
"vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */
"ldr r0, [%[a_ptr], #0x58] \n" /* load a43 to r0 */
"vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */
"ldr r1, [%[a_ptr], #0x5C] \n" /* load a53 to r1 */
"vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */
"add %[a_ptr], %[a_ptr], #0x60 \n" /* aptr += 96 */
/* Unroll 3 */
"vldr d6, [%[b_ptr], #0x70] \n" /* load b34, b35 to d6 */
"vmov d3, r0, r1 \n" /* mov a43, a53 to d3 */
"vmla.f32 q4, q2, d1[0] \n" /* out00 += a03 * b3l */
"ldr r0, [%[b_ptr], #0x78] \n" /* load b36 to r0 */
"vmla.f32 q6, q2, d1[1] \n" /* out10 += a13 * b3l */
"ldr r1, [%[b_ptr], #0x7C] \n" /* load b37 to r1 */
"vmla.f32 q8, q2, d2[0] \n" /* out20 += a23 * b3l */
"add %[b_ptr], %[b_ptr], #0x80 \n" /* bptr += 108 */
"vldr d0, [%[a_ptr], #0x00] \n" /* load a00, a10 to d0 */
"vmov d7, r0, r1 \n" /* mov b36, b37 to d7 */
"vmla.f32 q10, q2, d2[1] \n" /* out30 += a33 * b3l */
"pld [%[b_ptr], #0xC0] \n" /* pre load bpanel */
"vmla.f32 q12, q2, d3[0] \n" /* out40 += a43 * b3l */
"vmla.f32 q14, q2, d3[1] \n" /* out50 += a53 * b3l */
"vldr d4, [%[b_ptr], #0x00] \n" /* load b00, b01 to d4 */
"vmla.f32 q5, q3, d1[0] \n" /* out01 += a03 * b3h */
"ldr r0, [%[b_ptr], #0x08] \n" /* load b02 to r0 */
"vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */
"ldr r1, [%[b_ptr], #0x0C] \n" /* load b03 to r1 */
"vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */
"subs %[k], %[k], #1 \n" /* loop k -= 1 */
"vldr d1, [%[a_ptr], #0x08] \n" /* load a20, a30 to d1 */
"vmov d5, r0, r1 \n" /* mov b02, b03 to d5 */
"vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */
"ldr r0, [%[a_ptr], #0x10] \n" /* load a40 to r0 */
"vmla.f32 q13, q3, d3[0] \n" /* out41 += a43 * b3h */
"ldr r1, [%[a_ptr], #0x14] \n" /* load a50 to r1 */
"vmla.f32 q15, q3, d3[1] \n" /* out51 += a53 * b3h */
"bne 1b \n" /* branch to k loop */
"6:\n"
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"blt 3f \n" /* branch to tail == 1 */
/* Tail Unroll 0 */
"vmov d2, r0, r1 \n" /* mov b02, b03 to d2 */
"add %[a_ptr], %[a_ptr], #0x18 \n" /* aptr += 24 */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"vld1.32 {d3}, [%[a_ptr] :64]! \n" /* load a01, a11 to d3 */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"add %[b_ptr], %[b_ptr], #0x10 \n" /* bptr += 16 */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b04-b07 to d6,d7 */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b10-b13 to d4,d5 */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"vld1.32 {d0-d1}, [%[a_ptr] :64]! \n" /* load a21-a51 to d0,d1 */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b14-b17 to d6,d7 */
"blt 4f \n" /* branch to tail == 2 */
/* Tail Unroll 1 */
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b20-b23 to d4,d5 */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"cmp %[tails], #4 \n" /* cmp tail with 4 */
"vld1.32 {d2-d3}, [%[a_ptr] :64]! \n" /* load a02-a32 to d2,d3 */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b24-b27 to d6,d7 */
"blt 5f \n" /* branch to tail == 3 */
/* Tail Unroll 2 */
"sub %[tails], %[tails], #4 \n" /* tail -= 4 */
"vld1.32 {d0-d1}, [%[a_ptr] :64]! \n" /* a42a52a03a13 to d0,d1 */
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */
"vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */
"vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */
"vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */
"vld1.32 {d4-d5}, [%[b_ptr] :128]! \n" /* load b30-b33 to d4,d5 */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */
"vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */
"vld1.32 {d2-d3}, [%[a_ptr] :64]! \n" /* load a23-a53 to d2,d3 */
"vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */
"vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b34-b37 to d6,d7 */
/* Tail Unroll 3 */
"vmla.f32 q4, q2, d1[0] \n" /* out00 += a03 * b3l */
"vmla.f32 q5, q3, d1[0] \n" /* out01 += a03 * b3h */
"vmla.f32 q6, q2, d1[1] \n" /* out10 += a13 * b3l */
"vmla.f32 q7, q3, d1[1] \n" /* out11 += a13 * b3h */
"vmla.f32 q8, q2, d2[0] \n" /* out20 += a23 * b3l */
"vmla.f32 q9, q3, d2[0] \n" /* out21 += a23 * b3h */
"vmla.f32 q10, q2, d2[1] \n" /* out30 += a33 * b3l */
"vmla.f32 q11, q3, d2[1] \n" /* out31 += a33 * b3h */
"vmla.f32 q12, q2, d3[0] \n" /* out40 += a43 * b3l */
"vmla.f32 q13, q3, d3[0] \n" /* out41 += a43 * b3h */
"vmla.f32 q14, q2, d3[1] \n" /* out50 += a53 * b3l */
"vmla.f32 q15, q3, d3[1] \n" /* out51 += a53 * b3h */
"b 2f \n" /* branch to check relu */
/* tails==1 final tail */
"3:\n"
"vmov d2, r0, r1 \n" /* mov b02, b03 to d2 */
"add %[b_ptr], %[b_ptr], #0x10 \n" /* bptr += 16 */
"vmla.f32 q4, q2, d0[0] \n" /* out00 += a00 * b0l */
"add %[a_ptr], %[a_ptr], #0x18 \n" /* aptr += 24 */
"vmla.f32 q6, q2, d0[1] \n" /* out10 += a10 * b0l */
"vld1.32 {d6-d7}, [%[b_ptr] :128]! \n" /* load b04-b07 to d6,d7 */
"vmla.f32 q8, q2, d1[0] \n" /* out20 += a20 * b0l */
"vmla.f32 q10, q2, d1[1] \n" /* out30 += a30 * b0l */
"vmla.f32 q12, q2, d2[0] \n" /* out40 += a40 * b0l */
"vmla.f32 q14, q2, d2[1] \n" /* out50 += a50 * b0l */
"vmla.f32 q5, q3, d0[0] \n" /* out01 += a00 * b0h */
"vmla.f32 q7, q3, d0[1] \n" /* out11 += a10 * b0h */
"vmla.f32 q9, q3, d1[0] \n" /* out21 += a20 * b0h */
"vmla.f32 q11, q3, d1[1] \n" /* out31 += a30 * b0h */
"vmla.f32 q13, q3, d2[0] \n" /* out41 += a40 * b0h */
"vmla.f32 q15, q3, d2[1] \n" /* out51 += a50 * b0h */
"b 2f \n" /* branch to check relu */
/* tails==2 final tail */
"4:\n"
"vmla.f32 q4, q2, d3[0] \n" /* out00 += a01 * b1l */
"vmla.f32 q5, q3, d3[0] \n" /* out01 += a01 * b1h */
"vmla.f32 q6, q2, d3[1] \n" /* out10 += a11 * b1l */
"vmla.f32 q7, q3, d3[1] \n" /* out11 += a11 * b1h */
"vmla.f32 q8, q2, d0[0] \n" /* out20 += a21 * b1l */
"vmla.f32 q9, q3, d0[0] \n" /* out21 += a21 * b1h */
"vmla.f32 q10, q2, d0[1] \n" /* out30 += a31 * b1l */
"vmla.f32 q11, q3, d0[1] \n" /* out31 += a31 * b1h */
"vmla.f32 q12, q2, d1[0] \n" /* out40 += a41 * b1l */
"vmla.f32 q13, q3, d1[0] \n" /* out41 += a41 * b1h */
"vmla.f32 q14, q2, d1[1] \n" /* out50 += a51 * b1l */
"vmla.f32 q15, q3, d1[1] \n" /* out51 += a51 * b1h */
"b 2f \n" /* branch to check relu */
/* tails==3 final tail */
"5:\n"
"vmla.f32 q4, q2, d2[0] \n" /* out00 += a02 * b2l */
"vld1.32 {d0}, [%[a_ptr] :64]! \n" /* load a42, a52 to d0 */
"vmla.f32 q6, q2, d2[1] \n" /* out10 += a12 * b2l */
"vmla.f32 q8, q2, d3[0] \n" /* out20 += a22 * b2l */
"vmla.f32 q5, q3, d2[0] \n" /* out01 += a02 * b2h */
"vmla.f32 q7, q3, d2[1] \n" /* out11 += a12 * b2h */
"vmla.f32 q9, q3, d3[0] \n" /* out21 += a22 * b2h */
"vmla.f32 q10, q2, d3[1] \n" /* out30 += a32 * b2l */
"vmla.f32 q11, q3, d3[1] \n" /* out31 += a32 * b2h */
"vmla.f32 q12, q2, d0[0] \n" /* out40 += a42 * b2l */
"vmla.f32 q13, q3, d0[0] \n" /* out41 += a42 * b2h */
"vmla.f32 q14, q2, d0[1] \n" /* out50 += a52 * b2l */
"vmla.f32 q15, q3, d0[1] \n" /* out51 += a52 * b2h */
/* relu */
"2:\n"
"cmp %[tails], #1 \n" /* cmp tail is relu */
"bne 0f \n" /* no relu branch to end */
"vmov.i32 q0, #0 \n" /* mov 0.f to q0 */
"vmax.f32 q4, q4, q0 \n" /* out00 relu */
"vmax.f32 q5, q5, q0 \n" /* out01 relu */
"vmax.f32 q6, q6, q0 \n" /* out10 relu */
"vmax.f32 q7, q7, q0 \n" /* out11 relu */
"vmax.f32 q8, q8, q0 \n" /* out20 relu */
"vmax.f32 q9, q9, q0 \n" /* out21 relu */
"vmax.f32 q10, q10, q0 \n" /* out30 relu */
"vmax.f32 q11, q11, q0 \n" /* out31 relu */
"vmax.f32 q12, q12, q0 \n" /* out40 relu */
"vmax.f32 q13, q13, q0 \n" /* out41 relu */
"vmax.f32 q14, q14, q0 \n" /* out50 relu */
"vmax.f32 q15, q15, q0 \n" /* out51 relu */
"0:\n"
"vst1.32 {d8-d11}, [%[c_ptr0]]! \n" /* store out0 to cptr0 */
"vst1.32 {d12-d15}, [%[c_ptr1]]! \n" /* store out1 to cptr1 */
"vst1.32 {d16-d19}, [%[c_ptr2]]! \n" /* store out2 to cptr2 */
"vst1.32 {d20-d23}, [%[c_ptr3]]! \n" /* store out3 to cptr3 */
"vst1.32 {d24-d27}, [%[c_ptr4]]! \n" /* store out4 to cptr4 */
"vst1.32 {d28-d31}, [%[c_ptr5]]! \n" /* store out5 to cptr5 */
: [a_ptr] "+r"(a_ptr),
[b_ptr] "+r"(b_ptr),
[c_ptr0] "+r"(c_ptr0),
[c_ptr1] "+r"(c_ptr1),
[c_ptr2] "+r"(c_ptr2),
[c_ptr3] "+r"(c_ptr3),
[c_ptr4] "+r"(c_ptr4),
[c_ptr5] "+r"(c_ptr5),
[k] "+r"(k),
[tails] "+r"(tails)
: [bias_ptr] "r"(bias_local)
: "r0", "r1", "q0","q1","q2","q3","q4",
"q5","q6","q7","q8","q9","q10","q11",
"q12","q13","q14","q15","cc","memory");
// clang-format on
if (flag_p_remain && (xb == bblocks - 1)) {
for (int i = 0; i < remain; ++i) {
*pout0++ = cout0[i];
*pout1++ = cout1[i];
*pout2++ = cout2[i];
*pout3++ = cout3[i];
*pout4++ = cout4[i];
*pout5++ = cout5[i];
}
}
}
}
}
}
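
For reference, here is a minimal scalar sketch of the arithmetic one 6x8 output tile of the kernel above performs. It is not part of the commit; the helper name sgemm_tile_6x8_ref is hypothetical, and it assumes the packed layouts the assembly consumes (6 consecutive floats per k step in the A panel, 8 per k step in the B panel).

// Hypothetical scalar reference for one 6x8 output tile.
static void sgemm_tile_6x8_ref(const float* a_panel,  // 6 * K floats
                               const float* b_panel,  // K * 8 floats
                               float* c, int ldc, int K,
                               const float* bias, bool has_bias, bool relu) {
  float acc[6][8];
  for (int i = 0; i < 6; ++i) {
    float b0 = has_bias ? bias[i] : 0.f;  // bias is broadcast per output row
    for (int j = 0; j < 8; ++j) acc[i][j] = b0;
  }
  for (int k = 0; k < K; ++k) {
    const float* a = a_panel + k * 6;  // column k of the A panel: a0k..a5k
    const float* b = b_panel + k * 8;  // row k of the B panel: bk0..bk7
    for (int i = 0; i < 6; ++i) {
      for (int j = 0; j < 8; ++j) acc[i][j] += a[i] * b[j];
    }
  }
  for (int i = 0; i < 6; ++i) {
    for (int j = 0; j < 8; ++j) {
      float v = acc[i][j];
      c[i * ldc + j] = relu ? (v > 0.f ? v : 0.f) : v;  // optional fused relu
    }
  }
}
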
void sgemm_prepacked_4x8(bool is_transB,
int M,
int N,
...
@@ -50,11 +50,11 @@ class TopkComputeTester : public arena::TestCase {
out_dims[out_dims.size() - 1] = k_;
out_val->Resize(out_dims);
out_ind->Resize(out_dims);
- auto* out_val_data = out_val->mutable_data<T1>();
+ auto* out_val_data = out_val->template mutable_data<T1>();
- auto* out_ind_data = out_ind->mutable_data<T2>();
+ auto* out_ind_data = out_ind->template mutable_data<T2>();
auto* x = scope->FindTensor(x_);
- const auto* x_data = x->data<T1>();
+ const auto* x_data = x->template data<T1>();
int m = out_dims.production() / k_;
int n = x_dims_[x_dims_.size() - 1];
...