Unverified commit f7f65134 authored by yiicy, committed by GitHub

[ARM] improve sgemm performance with relu, relu6 and leaky relu, test=develop (#3181)

* [ARM] improve sgemm performance with relu, relu6 and leaky relu, test=develop

* [ARM] improve sgemm performance with relu, relu6 and leaky relu, test=develop
Parent 50822e44
@@ -2289,6 +2289,29 @@ void sgemm_prepacked_8x12(bool is_transB,
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK));
x_block /= NBLOCK;
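For reference, a minimal scalar sketch of what the fused epilogue selected by `flag_act` above computes, assuming the encoding in this commit (relu: 1, relu6: 2, leaky relu: 3; `alpha` broadcast to all four lanes). The helper name is hypothetical:

// Scalar reference for the fused activation epilogue (sketch, not the kernel code).
static inline float apply_act(float x, int flag_act, float alpha) {
  switch (flag_act) {
    case 1:  return x > 0.f ? x : 0.f;                        // relu
    case 2:  return x < 0.f ? 0.f : (x < alpha ? x : alpha);  // relu6: clamp to [0, alpha]
    case 3:  return x >= 0.f ? x : x * alpha;                 // leaky relu
    default: return x;                                        // 0: no activation
  }
}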
@@ -2837,7 +2860,172 @@ void sgemm_prepacked_8x12(bool is_transB,
"fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 =q7*/
"fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 =q7*/
"fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 =q7*/
"11: \n" /* check if relu */
"11: \n" /* check activation */
"cmp %w[flag_act], #1 \n" /* check if has relu */
"bne 12f \n" /* jump if no relu */
"movi v0.4s, #0 \n" /* for relu*/
"fmax v8.4s, v8.4s, v0.4s \n" /* relu*/
"fmax v9.4s, v9.4s, v0.4s \n" /* relu*/
"fmax v10.4s, v10.4s, v0.4s \n" /* relu*/
"fmax v11.4s, v11.4s, v0.4s \n" /* relu*/
"fmax v12.4s, v12.4s, v0.4s \n" /* relu*/
"fmax v13.4s, v13.4s, v0.4s \n" /* relu*/
"fmax v14.4s, v14.4s, v0.4s \n" /* relu*/
"fmax v15.4s, v15.4s, v0.4s \n" /* relu*/
"fmax v16.4s, v16.4s, v0.4s \n" /* relu*/
"fmax v17.4s, v17.4s, v0.4s \n" /* relu*/
"fmax v18.4s, v18.4s, v0.4s \n" /* relu*/
"fmax v19.4s, v19.4s, v0.4s \n" /* relu*/
"fmax v20.4s, v20.4s, v0.4s \n" /* relu*/
"fmax v21.4s, v21.4s, v0.4s \n" /* relu*/
"fmax v22.4s, v22.4s, v0.4s \n" /* relu*/
"fmax v23.4s, v23.4s, v0.4s \n" /* relu*/
"fmax v24.4s, v24.4s, v0.4s \n" /* relu*/
"fmax v25.4s, v25.4s, v0.4s \n" /* relu*/
"fmax v26.4s, v26.4s, v0.4s \n" /* relu*/
"fmax v27.4s, v27.4s, v0.4s \n" /* relu*/
"fmax v28.4s, v28.4s, v0.4s \n" /* relu*/
"fmax v29.4s, v29.4s, v0.4s \n" /* relu*/
"fmax v30.4s, v30.4s, v0.4s \n" /* relu*/
"fmax v31.4s, v31.4s, v0.4s \n" /* relu*/
"b 20f \n" /* relu end */
//! no act
"12: \n" /* no relu */
"cmp %w[flag_act], #0 \n" /* check no act */
"beq 20f \n" /* no act end */
//! relu6
"cmp %w[flag_act], #2 \n" /* check if has relu6 */
"bne 13f \n" /* jump if no relu6 */
"movi v0.4s, #0 \n" /* for relu6 */
"ld1 {v1.4s}, [%[alpha]] \n" /* relu6 alpha */
"fmax v8.4s, v8.4s, v0.4s \n" /* relu6 */
"fmax v9.4s, v9.4s, v0.4s \n" /* relu6 */
"fmax v10.4s, v10.4s, v0.4s \n" /* relu6 */
"fmax v11.4s, v11.4s, v0.4s \n" /* relu6 */
"fmax v12.4s, v12.4s, v0.4s \n" /* relu6 */
"fmax v13.4s, v13.4s, v0.4s \n" /* relu6 */
"fmax v14.4s, v14.4s, v0.4s \n" /* relu6 */
"fmax v15.4s, v15.4s, v0.4s \n" /* relu6 */
"fmax v16.4s, v16.4s, v0.4s \n" /* relu6 */
"fmax v17.4s, v17.4s, v0.4s \n" /* relu6 */
"fmax v18.4s, v18.4s, v0.4s \n" /* relu6 */
"fmax v19.4s, v19.4s, v0.4s \n" /* relu6 */
"fmax v20.4s, v20.4s, v0.4s \n" /* relu6 */
"fmax v21.4s, v21.4s, v0.4s \n" /* relu6 */
"fmax v22.4s, v22.4s, v0.4s \n" /* relu6 */
"fmax v23.4s, v23.4s, v0.4s \n" /* relu6 */
"fmax v24.4s, v24.4s, v0.4s \n" /* relu6 */
"fmax v25.4s, v25.4s, v0.4s \n" /* relu6 */
"fmax v26.4s, v26.4s, v0.4s \n" /* relu6 */
"fmax v27.4s, v27.4s, v0.4s \n" /* relu6 */
"fmax v28.4s, v28.4s, v0.4s \n" /* relu6 */
"fmax v29.4s, v29.4s, v0.4s \n" /* relu6 */
"fmax v30.4s, v30.4s, v0.4s \n" /* relu6 */
"fmax v31.4s, v31.4s, v0.4s \n" /* relu6 */
"fmin v8.4s, v8.4s, v1.4s \n" /* relu6 */
"fmin v9.4s, v9.4s, v1.4s \n" /* relu6 */
"fmin v10.4s, v10.4s, v1.4s \n" /* relu6 */
"fmin v11.4s, v11.4s, v1.4s \n" /* relu6 */
"fmin v12.4s, v12.4s, v1.4s \n" /* relu6 */
"fmin v13.4s, v13.4s, v1.4s \n" /* relu6 */
"fmin v14.4s, v14.4s, v1.4s \n" /* relu6 */
"fmin v15.4s, v15.4s, v1.4s \n" /* relu6 */
"fmin v16.4s, v16.4s, v1.4s \n" /* relu6 */
"fmin v17.4s, v17.4s, v1.4s \n" /* relu6 */
"fmin v18.4s, v18.4s, v1.4s \n" /* relu6 */
"fmin v19.4s, v19.4s, v1.4s \n" /* relu6 */
"fmin v20.4s, v20.4s, v1.4s \n" /* relu6 */
"fmin v21.4s, v21.4s, v1.4s \n" /* relu6 */
"fmin v22.4s, v22.4s, v1.4s \n" /* relu6 */
"fmin v23.4s, v23.4s, v1.4s \n" /* relu6 */
"fmin v24.4s, v24.4s, v1.4s \n" /* relu6 */
"fmin v25.4s, v25.4s, v1.4s \n" /* relu6 */
"fmin v26.4s, v26.4s, v1.4s \n" /* relu6 */
"fmin v27.4s, v27.4s, v1.4s \n" /* relu6 */
"fmin v28.4s, v28.4s, v1.4s \n" /* relu6 */
"fmin v29.4s, v29.4s, v1.4s \n" /* relu6 */
"fmin v30.4s, v30.4s, v1.4s \n" /* relu6 */
"fmin v31.4s, v31.4s, v1.4s \n" /* relu6 */
"b 20f \n" /* relu6 end */
//! leaky relu
"13: \n" /* otherwise is leaky relu */
"movi v0.4s, #0 \n" /* for leaky relu */
"ld1 {v1.4s}, [%[alpha]] \n" /* leaky relu alpha */
"fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */
"bif v8.16b, v3.16b, v2.16b \n" /* choose*/
"bif v9.16b, v5.16b, v4.16b \n" /* choose*/
"bif v10.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v11.4s, v1.4s \n" /* vmulq_f32 */
"bif v11.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v12.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v12.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v13.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v13.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v14.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v14.4s, v1.4s \n" /* vmulq_f32 */
"bif v12.16b, v3.16b, v2.16b \n" /* choose*/
"bif v13.16b, v5.16b, v4.16b \n" /* choose*/
"bif v14.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v15.4s, v1.4s \n" /* vmulq_f32 */
"bif v15.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v16.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v17.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v18.4s, v1.4s \n" /* vmulq_f32 */
"bif v16.16b, v3.16b, v2.16b \n" /* choose*/
"bif v17.16b, v5.16b, v4.16b \n" /* choose*/
"bif v18.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v19.4s, v1.4s \n" /* vmulq_f32 */
"bif v19.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v20.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v21.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v22.4s, v1.4s \n" /* vmulq_f32 */
"bif v20.16b, v3.16b, v2.16b \n" /* choose*/
"bif v21.16b, v5.16b, v4.16b \n" /* choose*/
"bif v22.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v23.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v23.4s, v1.4s \n" /* vmulq_f32 */
"bif v23.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v24.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v24.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v25.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v25.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v26.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v26.4s, v1.4s \n" /* vmulq_f32 */
"bif v24.16b, v3.16b, v2.16b \n" /* choose*/
"bif v25.16b, v5.16b, v4.16b \n" /* choose*/
"bif v26.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v27.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v27.4s, v1.4s \n" /* vmulq_f32 */
"bif v27.16b, v3.16b, v2.16b \n" /* choose*/
"fcmge v2.4s, v28.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v28.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v29.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v29.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v30.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v30.4s, v1.4s \n" /* vmulq_f32 */
"bif v28.16b, v3.16b, v2.16b \n" /* choose*/
"bif v29.16b, v5.16b, v4.16b \n" /* choose*/
"bif v30.16b, v7.16b, v6.16b \n" /* choose*/
"fcmge v2.4s, v31.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v31.4s, v1.4s \n" /* vmulq_f32 */
"bif v31.16b, v3.16b, v2.16b \n" /* choose*/
"20: \n" /* act end */
"st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */
"st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */
"st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */
@@ -2861,7 +3049,9 @@ void sgemm_prepacked_8x12(bool is_transB,
[c_ptr7] "+r"(c_ptr7)
: [bias_ptr] "r"(bias_local),
[has_beta] "r"(has_beta),
[beta] "r"(beta)
[beta] "r"(beta),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act)
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11","v12","v13",
@@ -2884,13 +3074,6 @@ void sgemm_prepacked_8x12(bool is_transB,
}
}
}
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float *dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
}
void sgemm_prepacked_4x4(bool is_transB,
@@ -2911,6 +3094,28 @@ void sgemm_prepacked_4x4(bool is_transB,
auto workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
const int n_block = 4;
const int m_block = 4;
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
@@ -3137,7 +3342,51 @@ void sgemm_prepacked_4x4(bool is_transB,
"fmla v10.4s, v6.4s, v2.s[2]\n" /* out2 = b2 * a20[2], b1 =q6*/
"fmla v11.4s, v6.4s, v2.s[3]\n" /* out3 = b2 * a20[3], b1 =q6*/
"11: \n" /* check if relu */
"11: \n" /* check activation */
"cmp %w[flag_act], #1 \n" /* check if has relu */
"bne 12f \n" /* jump if no relu */
"movi v0.4s, #0 \n" /* for relu*/
"fmax v8.4s, v8.4s, v0.4s \n" /* relu*/
"fmax v9.4s, v9.4s, v0.4s \n" /* relu*/
"fmax v10.4s, v10.4s, v0.4s \n" /* relu*/
"fmax v11.4s, v11.4s, v0.4s \n" /* relu*/
"b 20f \n" /* relu end */
//! no act
"12: \n" /* no relu */
"cmp %w[flag_act], #0 \n" /* check no act */
"beq 20f \n" /* no act end */
//! relu6
"cmp %w[flag_act], #2 \n" /* check if has relu6 */
"bne 13f \n" /* jump if no relu6 */
"movi v0.4s, #0 \n" /* for relu6 */
"ld1 {v1.4s}, [%[alpha]] \n" /* relu6 alpha */
"fmax v8.4s, v8.4s, v0.4s \n" /* relu6 */
"fmax v9.4s, v9.4s, v0.4s \n" /* relu6 */
"fmax v10.4s, v10.4s, v0.4s \n" /* relu6 */
"fmax v11.4s, v11.4s, v0.4s \n" /* relu6 */
"fmin v8.4s, v8.4s, v1.4s \n" /* relu6*/
"fmin v9.4s, v9.4s, v1.4s \n" /* relu6*/
"fmin v10.4s, v10.4s, v1.4s \n" /* relu6*/
"fmin v11.4s, v11.4s, v1.4s \n" /* relu6*/
"b 20f \n" /* relu6 end */
//! leaky relu
"13: \n" /* otherwise is leaky relu */
"movi v0.4s, #0 \n" /* for leaky relu */
"ld1 {v1.4s}, [%[alpha]] \n" /* leaky relu alpha */
"fcmge v2.4s, v8.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v3.4s, v8.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v4.4s, v9.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v5.4s, v9.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v6.4s, v10.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v7.4s, v10.4s, v1.4s \n" /* vmulq_f32 */
"fcmge v12.4s, v11.4s, v0.4s \n" /* vcgeq_f32 */
"fmul v13.4s, v11.4s, v1.4s \n" /* vmulq_f32 */
"bif v8.16b, v3.16b, v2.16b \n" /* choose*/
"bif v9.16b, v5.16b, v4.16b \n" /* choose*/
"bif v10.16b, v7.16b, v6.16b \n" /* choose*/
"bif v11.16b, v13.16b, v12.16b \n" /* choose*/
"20: \n" /* act end */
"st1 {v8.4s}, [%[c_ptr0]], #16\n" /* store r0 */
"st1 {v9.4s}, [%[c_ptr1]], #16\n" /* store r1 */
"st1 {v10.4s}, [%[c_ptr2]], #16\n" /* store r2 */
@@ -3153,7 +3402,9 @@ void sgemm_prepacked_4x4(bool is_transB,
[c_ptr3] "+r"(c_ptr3)
: [bias_ptr] "r"(bias_local),
[has_beta] "r"(has_beta),
[beta] "r"(beta)
[beta] "r"(beta),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act)
: "cc","memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11");
@@ -3169,13 +3420,6 @@ void sgemm_prepacked_4x4(bool is_transB,
}
}
}
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float *dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
}
#else // __aarch64__
/**
@@ -3206,6 +3450,28 @@ void sgemm_prepacked_6x8(bool is_transB,
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto* workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block =
(l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH));
@@ -3223,6 +3489,8 @@ void sgemm_prepacked_6x8(bool is_transB,
tail_pre = KBLOCK;
}
//! merge tail_pre and flag_act
tail_pre = (tail_pre << 2 | flag_act);
bool flag_p_remain = false;
int remain = 0;
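The ARMv7 kernel is short on free registers, so the k-tail count and the activation flag travel in a single asm operand: the flag occupies the low two bits and the tail count sits above them. That is why the tail code below decrements with `sub %[tails], %[tails], #4` and compares against `#4` instead of the old `subs`/`beq` pair, and why label `2:` can later test the bare flag value once the tail is drained. A sketch of the encoding, assuming the variable names used here:

// Merge the k-tail count with the activation flag (low two bits).
int tails = (tail_pre << 2) | flag_act;
// In the asm: "sub %[tails], %[tails], #4" decrements tail_pre by one
// without touching the flag bits; "cmp %[tails], #4" asks tail_pre >= 1.
// When the tail is exhausted, tails == flag_act, which label "2:" tests.
int flag = tails & 0x3;      // 0 none, 1 relu, 2 relu6, 3 leaky relu
int remaining = tails >> 2;  // tail iterations left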
@@ -3461,8 +3729,9 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n"
"bne 1b @ jump to main loop\n"
"0: @ process tail\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"beq 3f @ jump to tail = 1\n"
"sub %[tails], %[tails], #4 @ tail--\n"
"cmp %[tails], #4 @ cmp with act bits\n"
"blt 3f @ jump to tail = 1\n"
/* Unroll 0*/
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n"
@@ -3471,9 +3740,10 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"sub %[tails], %[tails], #4 @ tail--\n"
"vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"cmp %[tails], #4 @ cmp with act bits\n"
"vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n"
"vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n"
@@ -3482,16 +3752,17 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"beq 4f @ jump to tail==2\n"
"blt 4f @ jump to tail==2\n"
/* Unroll 1*/
"vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n"
"subs %[tails], %[tails], #1 @ tail--\n"
"sub %[tails], %[tails], #4 @ tail--\n"
"vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n"
"vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n"
"vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n"
"vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n"
"vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n"
"cmp %[tails], #4 @ cmp with act bits\n"
"vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n"
"vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n"
"vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n"
@@ -3500,8 +3771,9 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n"
"vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n"
"beq 5f @ jump to tail==3\n"
"blt 5f @ jump to tail==3\n"
/* Unroll 2 */
"sub %[tails], %[tails], #4 @ tail--\n"
"vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4,a5, a0,a1\n"
"vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n"
"vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n"
@@ -3579,7 +3851,99 @@ void sgemm_prepacked_6x8(bool is_transB,
"vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n"
"vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n"
"vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n"
"2: @ check relu\n"
"2: @ check activation\n"
//! relu
"cmp %[tails], #1 @ check if has relu\n"
"bne 6f @ jump if not relu \n"
"vmov.u32 q0, #0 @ for relu\n"
"vmax.f32 q4, q4, q0 @ for relu\n"
"vmax.f32 q5, q5, q0 @ for relu\n"
"vmax.f32 q6, q6, q0 @ for relu\n"
"vmax.f32 q7, q7, q0 @ for relu\n"
"vmax.f32 q8, q8, q0 @ for relu\n"
"vmax.f32 q9, q9, q0 @ for relu\n"
"vmax.f32 q10, q10, q0 @ for relu\n"
"vmax.f32 q11, q11, q0 @ for relu\n"
"vmax.f32 q12, q12, q0 @ for relu\n"
"vmax.f32 q13, q13, q0 @ for relu\n"
"vmax.f32 q14, q14, q0 @ for relu\n"
"vmax.f32 q15, q15, q0 @ for relu\n"
"b 10f @ relu end\n"
"6: @ no relu \n"
"cmp %[tails], #0 @ check no act\n"
"beq 10f @ no act end \n"
//! relu6
"cmp %[tails], #2 @ check if has relu6\n"
"bne 7f @ jump if no relu6 \n"
"vmov.u32 q0, #0 @ for relu6\n"
"vmax.f32 q4, q4, q0 @ for relu6\n"
"vmax.f32 q5, q5, q0 @ for relu6\n"
"vmax.f32 q6, q6, q0 @ for relu6\n"
"vmax.f32 q7, q7, q0 @ for relu6\n"
"vmax.f32 q8, q8, q0 @ for relu6\n"
"vmax.f32 q9, q9, q0 @ for relu6\n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load relu6 alpha\n"
"vmax.f32 q10, q10, q0 @ for relu6\n"
"vmax.f32 q11, q11, q0 @ for relu6\n"
"vmax.f32 q12, q12, q0 @ for relu6\n"
"vmax.f32 q13, q13, q0 @ for relu6\n"
"vmax.f32 q14, q14, q0 @ for relu6\n"
"vmax.f32 q15, q15, q0 @ for relu6\n"
"vmin.f32 q4, q4, q1 @ for relu6\n"
"vmin.f32 q5, q5, q1 @ for relu6\n"
"vmin.f32 q6, q6, q1 @ for relu6\n"
"vmin.f32 q7, q7, q1 @ for relu6\n"
"vmin.f32 q8, q8, q1 @ for relu6\n"
"vmin.f32 q9, q9, q1 @ for relu6\n"
"vmin.f32 q10, q10, q1 @ for relu6\n"
"vmin.f32 q11, q11, q1 @ for relu6\n"
"vmin.f32 q12, q12, q1 @ for relu6\n"
"vmin.f32 q13, q13, q1 @ for relu6\n"
"vmin.f32 q14, q14, q1 @ for relu6\n"
"vmin.f32 q15, q15, q1 @ for relu6\n"
"b 10f @ relu6 end \n"
//! leaky relu
"7: @ otherwise is leaky relu\n"
"vmov.u32 q0, #0 @ for leaky relu \n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load leaky relu alpha\n"
"vcge.f32 q2, q4, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q4, q1 @ vmulq_f32 \n"
"vbif q4, q3, q2 @ choose \n"
"vcge.f32 q2, q5, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q5, q1 @ vmulq_f32 \n"
"vbif q5, q3, q2 @ choose \n"
"vcge.f32 q2, q6, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q6, q1 @ vmulq_f32 \n"
"vbif q6, q3, q2 @ choose \n"
"vcge.f32 q2, q7, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q7, q1 @ vmulq_f32 \n"
"vbif q7, q3, q2 @ choose \n"
"vcge.f32 q2, q8, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q8, q1 @ vmulq_f32 \n"
"vbif q8, q3, q2 @ choose \n"
"vcge.f32 q2, q9, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q9, q1 @ vmulq_f32 \n"
"vbif q9, q3, q2 @ choose \n"
"vcge.f32 q2, q10, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q10, q1 @ vmulq_f32 \n"
"vbif q10, q3, q2 @ choose \n"
"vcge.f32 q2, q11, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q11, q1 @ vmulq_f32 \n"
"vbif q11, q3, q2 @ choose \n"
"vcge.f32 q2, q12, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q12, q1 @ vmulq_f32 \n"
"vbif q12, q3, q2 @ choose \n"
"vcge.f32 q2, q13, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q13, q1 @ vmulq_f32 \n"
"vbif q13, q3, q2 @ choose \n"
"vcge.f32 q2, q14, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q14, q1 @ vmulq_f32 \n"
"vbif q14, q3, q2 @ choose \n"
"vcge.f32 q2, q15, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q15, q1 @ vmulq_f32 \n"
"vbif q15, q3, q2 @ choose \n"
"10: @ act end \n"
"vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n"
"vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n"
@@ -3597,7 +3961,8 @@ void sgemm_prepacked_6x8(bool is_transB,
[k] "+r"(k),
[tails] "+r"(tails)
: [bias_ptr] "r"(bias_local),
[beta] "r"(beta)
[beta] "r"(beta),
[alpha] "r" (alpha)
: "q0","q1","q2","q3","q4",
"q5","q6","q7","q8","q9","q10","q11",
"q12","q13","q14","q15","cc","memory");
@@ -3616,13 +3981,6 @@ void sgemm_prepacked_6x8(bool is_transB,
}
}
}
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float* dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
}
void sgemm_prepacked_4x8(bool is_transB,
@@ -3642,6 +4000,28 @@ void sgemm_prepacked_4x8(bool is_transB,
size_t l2_cache = ctx->llc_size() > 0 ? ctx->llc_size() : 512 * 1024;
auto* workspace = ctx->workspace_data<float>();
int threads = ctx->threads();
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky relu: 3
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
//! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2
int x_block =
(l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73));
@@ -3922,6 +4302,74 @@ void sgemm_prepacked_4x8(bool is_transB,
/*aptr - 16*/
"sub %[a_ptr], %[a_ptr], #16 @ tail--\n"
"2: @ check relu\n"
//! relu
"cmp %[flag_act], #1 @ check if has relu\n"
"bne 6f @ jump if not relu \n"
"vmov.u32 q0, #0 @ for relu\n"
"vmax.f32 q8, q8, q0 @ for relu\n"
"vmax.f32 q9, q9, q0 @ for relu\n"
"vmax.f32 q10, q10, q0 @ for relu\n"
"vmax.f32 q11, q11, q0 @ for relu\n"
"vmax.f32 q12, q12, q0 @ for relu\n"
"vmax.f32 q13, q13, q0 @ for relu\n"
"vmax.f32 q14, q14, q0 @ for relu\n"
"vmax.f32 q15, q15, q0 @ for relu\n"
"b 10f @ relu end\n"
"6: @ no relu \n"
"cmp %[flag_act], #0 @ check no act\n"
"beq 10f @ no act end \n"
//! relu6
"cmp %[flag_act], #2 @ check if has relu6\n"
"bne 7f @ jump if no relu6 \n"
"vmov.u32 q0, #0 @ for relu6\n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load relu6 alpha\n"
"vmax.f32 q8, q8, q0 @ for relu6\n"
"vmax.f32 q9, q9, q0 @ for relu6\n"
"vmax.f32 q10, q10, q0 @ for relu6\n"
"vmax.f32 q11, q11, q0 @ for relu6\n"
"vmax.f32 q12, q12, q0 @ for relu6\n"
"vmax.f32 q13, q13, q0 @ for relu6\n"
"vmax.f32 q14, q14, q0 @ for relu6\n"
"vmax.f32 q15, q15, q0 @ for relu6\n"
"vmin.f32 q8, q8, q1 @ for relu6\n"
"vmin.f32 q9, q9, q1 @ for relu6\n"
"vmin.f32 q10, q10, q1 @ for relu6\n"
"vmin.f32 q11, q11, q1 @ for relu6\n"
"vmin.f32 q12, q12, q1 @ for relu6\n"
"vmin.f32 q13, q13, q1 @ for relu6\n"
"vmin.f32 q14, q14, q1 @ for relu6\n"
"vmin.f32 q15, q15, q1 @ for relu6\n"
"b 10f @ relu6 end \n"
//! leaky relu
"7: @ otherwise is leaky relu\n"
"vmov.u32 q0, #0 @ for leaky relu \n"
"vld1.f32 {d2-d3}, [%[alpha]] @ load leaky relu alpha\n"
"vcge.f32 q2, q8, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q8, q1 @ vmulq_f32 \n"
"vbif q8, q3, q2 @ choose \n"
"vcge.f32 q2, q9, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q9, q1 @ vmulq_f32 \n"
"vbif q9, q3, q2 @ choose \n"
"vcge.f32 q2, q10, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q10, q1 @ vmulq_f32 \n"
"vbif q10, q3, q2 @ choose \n"
"vcge.f32 q2, q11, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q11, q1 @ vmulq_f32 \n"
"vbif q11, q3, q2 @ choose \n"
"vcge.f32 q2, q12, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q12, q1 @ vmulq_f32 \n"
"vbif q12, q3, q2 @ choose \n"
"vcge.f32 q2, q13, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q13, q1 @ vmulq_f32 \n"
"vbif q13, q3, q2 @ choose \n"
"vcge.f32 q2, q14, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q14, q1 @ vmulq_f32 \n"
"vbif q14, q3, q2 @ choose \n"
"vcge.f32 q2, q15, q0 @ vcgeq_f32 \n"
"vmul.f32 q3, q15, q1 @ vmulq_f32 \n"
"vbif q15, q3, q2 @ choose \n"
"10: @ act end \n"
"vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n"
"vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n"
"vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n"
@@ -3935,7 +4383,9 @@ void sgemm_prepacked_4x8(bool is_transB,
[k] "+r"(k),
[tails] "+r"(tails)
: [bias_ptr] "r"(bias_local),
[beta] "r"(beta)
[beta] "r"(beta),
[alpha] "r"(alpha),
[flag_act] "r"(flag_act)
: "q0","q1","q2","q3",
"q4","q5","q6","q7","q8","q9","q10",
"q11","q12","q13","q14","q15","cc","memory");
@@ -3951,13 +4401,6 @@ void sgemm_prepacked_4x8(bool is_transB,
}
}
}
if (act_param.has_active) {
#pragma omp parallel for num_threads(threads)
for (unsigned int x = 0; x < M; x++) {
float* dst = C + x * ldc;
act_switch_process(dst, dst, N, &act_param);
}
}
}
#endif // __aarch64__