Commit f30ae5ff authored by chenjiaoAngel

add gemm+relu6

Parent a4770bd7
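This patch routes a fused activation through the int8 GEMM path instead of a bare relu flag: an integer code selects the activation, and a 4-lane alpha vector carries either the relu6 clip value or the leaky-relu slope. As a hedged scalar sketch of what each fused path computes (the helper below is illustrative, not part of the patch):

```cpp
#include <algorithm>

// Scalar reference for the fused activation selected by flag_act
// (0: none, 1: relu, 2: relu6, 3: leaky relu), matching the encoding
// set up in gemm_prepack_int8 below. "alpha" holds the relu6 clip
// value or the leaky-relu slope, broadcast to 4 lanes for the asm.
inline float apply_act(float x, int flag_act, float alpha) {
  switch (flag_act) {
    case 1: return std::max(x, 0.f);                   // relu
    case 2: return std::min(std::max(x, 0.f), alpha);  // relu6 (alpha = clip)
    case 3: return x >= 0.f ? x : x * alpha;           // leaky relu (alpha = slope)
    default: return x;                                 // no activation
  }
}
```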
@@ -264,6 +264,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
   }
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  auto act_param = param.activation_param;
   //! use gemv when the output channel size = 1
   for (int b = 0; b < num; ++b) {
     // dC
@@ -294,9 +295,9 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
                       n,
                       k,
                       flag_bias,
-                      flag_relu,
                       false,
                       scale_group,
+                      act_param,
                       ctx);
   }
 }
@@ -474,6 +475,8 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  auto act_param = param.activation_param;
   int hblock = get_hblock_int8(ctx);
   int k_roundup = ROUNDUP(k, KBLOCK_INT8);
   int m_roundup = ROUNDUP(m, hblock);
@@ -534,9 +537,9 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
                       n,
                       k,
                       flag_bias,
-                      flag_relu,
                       false,
                       scale_group,
+                      act_param,
                       ctx);
   }
 }
...
@@ -195,7 +195,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              Dtype*& c_ptr2,  // NOLINT
                              Dtype*& c_ptr3,  // NOLINT
                              const float* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem);
 // clang-format off
@@ -483,7 +484,10 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_RELU                               \
   /* do relu */                                      \
-  "cbz %w[is_relu], 9f\n"        /* skip relu */     \
+  "cmp %w[is_relu], #0\n"        /* skip act */      \
+  "beq 9f\n"                     /* no act end */    \
+  "cmp %w[is_relu], #1\n"        /* check relu */    \
+  "bne 10f\n"                    /* other act */     \
   "movi v0.4s, #0\n"             /* for relu */      \
   "fmax v16.4s, v16.4s, v0.4s\n" /* relu */          \
   "fmax v17.4s, v17.4s, v0.4s\n" /* relu */          \
@@ -501,6 +505,102 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
   "fmax v29.4s, v29.4s, v0.4s\n" /* relu */          \
   "fmax v30.4s, v30.4s, v0.4s\n" /* relu */          \
   "fmax v31.4s, v31.4s, v0.4s\n" /* relu */          \
+  "b 9f\n"                       /* relu end */
+#define GEMM_INT8_RELU6                              \
+  /* do relu6 */                                     \
+  "10: \n"                                           \
+  "cmp %w[is_relu], #2\n"        /* check relu6 */   \
+  "bne 11f\n"                    /* leaky relu */    \
+  "movi v0.4s, #0\n"             /* for relu6 */     \
+  "fmax v16.4s, v16.4s, v0.4s\n" /* relu */          \
+  "fmax v17.4s, v17.4s, v0.4s\n" /* relu */          \
+  "fmax v18.4s, v18.4s, v0.4s\n" /* relu */          \
+  "fmax v19.4s, v19.4s, v0.4s\n" /* relu */          \
+  "fmax v20.4s, v20.4s, v0.4s\n" /* relu */          \
+  "ld1 {v1.4s}, [%[alpha]]\n"    /* relu6 alpha */   \
+  "fmax v21.4s, v21.4s, v0.4s\n" /* relu */          \
+  "fmax v22.4s, v22.4s, v0.4s\n" /* relu */          \
+  "fmax v23.4s, v23.4s, v0.4s\n" /* relu */          \
+  "fmax v24.4s, v24.4s, v0.4s\n" /* relu */          \
+  "fmax v25.4s, v25.4s, v0.4s\n" /* relu */          \
+  "fmax v26.4s, v26.4s, v0.4s\n" /* relu */          \
+  "fmax v27.4s, v27.4s, v0.4s\n" /* relu */          \
+  "fmax v28.4s, v28.4s, v0.4s\n" /* relu */          \
+  "fmax v29.4s, v29.4s, v0.4s\n" /* relu */          \
+  "fmax v30.4s, v30.4s, v0.4s\n" /* relu */          \
+  "fmax v31.4s, v31.4s, v0.4s\n" /* relu */          \
+  "fmin v16.4s, v16.4s, v1.4s\n" /* relu6 */         \
+  "fmin v17.4s, v17.4s, v1.4s\n" /* relu6 */         \
+  "fmin v18.4s, v18.4s, v1.4s\n" /* relu6 */         \
+  "fmin v19.4s, v19.4s, v1.4s\n" /* relu6 */         \
+  "fmin v20.4s, v20.4s, v1.4s\n" /* relu6 */         \
+  "fmin v21.4s, v21.4s, v1.4s\n" /* relu6 */         \
+  "fmin v22.4s, v22.4s, v1.4s\n" /* relu6 */         \
+  "fmin v23.4s, v23.4s, v1.4s\n" /* relu6 */         \
+  "fmin v24.4s, v24.4s, v1.4s\n" /* relu6 */         \
+  "fmin v25.4s, v25.4s, v1.4s\n" /* relu6 */         \
+  "fmin v26.4s, v26.4s, v1.4s\n" /* relu6 */         \
+  "fmin v27.4s, v27.4s, v1.4s\n" /* relu6 */         \
+  "fmin v28.4s, v28.4s, v1.4s\n" /* relu6 */         \
+  "fmin v29.4s, v29.4s, v1.4s\n" /* relu6 */         \
+  "fmin v30.4s, v30.4s, v1.4s\n" /* relu6 */         \
+  "fmin v31.4s, v31.4s, v1.4s\n" /* relu6 */         \
+  "b 9f\n"                       /* relu6 end */
+#define GEMM_INT8_LEAKY_RELU                         \
+  /* do leaky relu */                                \
+  "11: \n"                                           \
+  "movi v0.4s, #0\n"              /* zero */         \
+  "ld1 {v1.4s}, [%[alpha]]\n"     /* leaky alpha */  \
+  "fcmge v2.4s, v16.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v3.4s, v16.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v4.4s, v17.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v5.4s, v17.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v6.4s, v18.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v7.4s, v18.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v8.4s, v19.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v9.4s, v19.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "bif v16.16b, v3.16b, v2.16b\n" /* choose */       \
+  "bif v17.16b, v5.16b, v4.16b\n" /* choose */       \
+  "bif v18.16b, v7.16b, v6.16b\n" /* choose */       \
+  "bif v19.16b, v9.16b, v8.16b\n" /* choose */       \
+  "fcmge v2.4s, v20.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v3.4s, v20.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v4.4s, v21.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v5.4s, v21.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v6.4s, v22.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v7.4s, v22.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v8.4s, v23.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v9.4s, v23.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "bif v20.16b, v3.16b, v2.16b\n" /* choose */       \
+  "bif v21.16b, v5.16b, v4.16b\n" /* choose */       \
+  "bif v22.16b, v7.16b, v6.16b\n" /* choose */       \
+  "bif v23.16b, v9.16b, v8.16b\n" /* choose */       \
+  "fcmge v2.4s, v24.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v3.4s, v24.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v4.4s, v25.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v5.4s, v25.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v6.4s, v26.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v7.4s, v26.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v8.4s, v27.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v9.4s, v27.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "bif v24.16b, v3.16b, v2.16b\n" /* choose */       \
+  "bif v25.16b, v5.16b, v4.16b\n" /* choose */       \
+  "bif v26.16b, v7.16b, v6.16b\n" /* choose */       \
+  "bif v27.16b, v9.16b, v8.16b\n" /* choose */       \
+  "fcmge v2.4s, v28.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v3.4s, v28.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v4.4s, v29.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v5.4s, v29.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v6.4s, v30.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v7.4s, v30.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "fcmge v8.4s, v31.4s, v0.4s\n"  /* vcgeq_f32 */    \
+  "fmul v9.4s, v31.4s, v1.4s\n"   /* vmulq_f32 */    \
+  "bif v28.16b, v3.16b, v2.16b\n" /* choose */       \
+  "bif v29.16b, v5.16b, v4.16b\n" /* choose */       \
+  "bif v30.16b, v7.16b, v6.16b\n" /* choose */       \
+  "bif v31.16b, v9.16b, v8.16b\n" /* choose */       \
   "9:\n"
 #define GEMM_TRANS_INT32_TO_FP32 \
@@ -559,6 +659,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_FP32_OUT   \
   GEMM_TRANS_INT32_TO_FP32   \
   GEMM_INT8_RELU             \
+  GEMM_INT8_RELU6            \
+  GEMM_INT8_LEAKY_RELU       \
   /* store result */                 \
   "stp q16, q17, [%[c_ptr0]], #32\n" \
   "stp q18, q19, [%[c_ptr0]], #32\n" \
@@ -571,7 +673,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_INT8_OUT   \
   GEMM_TRANS_INT32_TO_FP32   \
   GEMM_INT8_RELU             \
+  GEMM_INT8_RELU6            \
+  GEMM_INT8_LEAKY_RELU       \
   "ld1 {v8.4s}, [%[vmax]]\n" /* v8 = -127 */ \
   /* data >= -127 */                 \
   "fcmge v0.4s, v16.4s, v8.4s\n"     \
@@ -665,7 +768,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              float32_t*& c_ptr2,  // NOLINT
                              float32_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
 // clang-format off
@@ -678,6 +782,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
               [c_ptr3] "+r"(c_ptr3),
               [k] "+r"(k)
             : [is_relu] "r"(is_relu),
+              [alpha] "r"(alpha),
               [bias] "r"(bias),
               [rem] "r"(rem),
               [scale] "r"(scale)
@@ -698,7 +803,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              int8_t*& c_ptr2,  // NOLINT
                              int8_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
 // clang-format off
@@ -712,6 +818,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
               [c_ptr3] "+r"(c_ptr3),
               [k] "+r"(k)
             : [is_relu] "r"(is_relu),
+              [alpha] "r"(alpha),
               [bias] "r"(bias),
               [rem] "r"(rem),
               [scale] "r"(scale),
@@ -739,7 +846,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
                                   Dtype*& c_ptr6,  // NOLINT
                                   Dtype*& c_ptr7,  // NOLINT
                                   const float32_t* scale,
-                                  bool is_relu,
+                                  const float32_t* alpha,
+                                  int is_relu,
                                   int k,
                                   int rem);
 #if 0
@@ -1099,12 +1207,47 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #endif
 #define GEMM_SDOT_RELU                             \
-  "cbz %w[relu], 12f\n"          /* skip relu */   \
+  "cmp %w[relu], #0\n"           /* skip relu */   \
+  "beq 12f\n"                    /* no act end */  \
+  "cmp %w[relu], #1\n"           /* check relu */  \
+  "bne 13f\n"                    /* other act */   \
+  "movi v2.4s, #0\n"             /* for relu */    \
+  "fmax v8.4s, v8.4s, v2.4s\n"   /* relu */        \
+  "fmax v9.4s, v9.4s, v2.4s\n"   /* relu */        \
+  "fmax v10.4s, v10.4s, v2.4s\n" /* relu */        \
+  "fmax v11.4s, v11.4s, v2.4s\n" /* relu */        \
+  "fmax v12.4s, v12.4s, v2.4s\n" /* relu */        \
+  "fmax v13.4s, v13.4s, v2.4s\n" /* relu */        \
+  "fmax v14.4s, v14.4s, v2.4s\n" /* relu */        \
+  "fmax v15.4s, v15.4s, v2.4s\n" /* relu */        \
+  "fmax v16.4s, v16.4s, v2.4s\n" /* relu */        \
+  "fmax v17.4s, v17.4s, v2.4s\n" /* relu */        \
+  "fmax v18.4s, v18.4s, v2.4s\n" /* relu */        \
+  "fmax v19.4s, v19.4s, v2.4s\n" /* relu */        \
+  "fmax v20.4s, v20.4s, v2.4s\n" /* relu */        \
+  "fmax v21.4s, v21.4s, v2.4s\n" /* relu */        \
+  "fmax v22.4s, v22.4s, v2.4s\n" /* relu */        \
+  "fmax v23.4s, v23.4s, v2.4s\n" /* relu */        \
+  "fmax v24.4s, v24.4s, v2.4s\n" /* relu */        \
+  "fmax v25.4s, v25.4s, v2.4s\n" /* relu */        \
+  "fmax v26.4s, v26.4s, v2.4s\n" /* relu */        \
+  "fmax v27.4s, v27.4s, v2.4s\n" /* relu */        \
+  "fmax v28.4s, v28.4s, v2.4s\n" /* relu */        \
+  "fmax v29.4s, v29.4s, v2.4s\n" /* relu */        \
+  "fmax v30.4s, v30.4s, v2.4s\n" /* relu */        \
+  "fmax v31.4s, v31.4s, v2.4s\n" /* relu */        \
+  "b 12f\n"                      /* relu end */
+#define GEMM_SDOT_RELU6                            \
+  "13: \n"                                         \
+  "cmp %w[relu], #2\n"           /* check relu6 */ \
+  "bne 14f\n"                    /* leaky relu */  \
   "movi v2.4s, #0\n"             /* for relu */    \
   "fmax v8.4s, v8.4s, v2.4s\n"   /* relu */        \
   "fmax v9.4s, v9.4s, v2.4s\n"   /* relu */        \
   "fmax v10.4s, v10.4s, v2.4s\n" /* relu */        \
   "fmax v11.4s, v11.4s, v2.4s\n" /* relu */        \
+  "ld1 {v3.4s}, [%[alpha]]\n"    /* relu6 alpha */ \
   "fmax v12.4s, v12.4s, v2.4s\n" /* relu */        \
   "fmax v13.4s, v13.4s, v2.4s\n" /* relu */        \
   "fmax v14.4s, v14.4s, v2.4s\n" /* relu */        \
@@ -1125,6 +1268,108 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "fmax v29.4s, v29.4s, v2.4s\n" /* relu */        \
   "fmax v30.4s, v30.4s, v2.4s\n" /* relu */        \
   "fmax v31.4s, v31.4s, v2.4s\n" /* relu */        \
+  "fmin v8.4s, v8.4s, v3.4s\n"   /* relu6 */       \
+  "fmin v9.4s, v9.4s, v3.4s\n"   /* relu6 */       \
+  "fmin v10.4s, v10.4s, v3.4s\n" /* relu6 */       \
+  "fmin v11.4s, v11.4s, v3.4s\n" /* relu6 */       \
+  "fmin v12.4s, v12.4s, v3.4s\n" /* relu6 */       \
+  "fmin v13.4s, v13.4s, v3.4s\n" /* relu6 */       \
+  "fmin v14.4s, v14.4s, v3.4s\n" /* relu6 */       \
+  "fmin v15.4s, v15.4s, v3.4s\n" /* relu6 */       \
+  "fmin v16.4s, v16.4s, v3.4s\n" /* relu6 */       \
+  "fmin v17.4s, v17.4s, v3.4s\n" /* relu6 */       \
+  "fmin v18.4s, v18.4s, v3.4s\n" /* relu6 */       \
+  "fmin v19.4s, v19.4s, v3.4s\n" /* relu6 */       \
+  "fmin v20.4s, v20.4s, v3.4s\n" /* relu6 */       \
+  "fmin v21.4s, v21.4s, v3.4s\n" /* relu6 */       \
+  "fmin v22.4s, v22.4s, v3.4s\n" /* relu6 */       \
+  "fmin v23.4s, v23.4s, v3.4s\n" /* relu6 */       \
+  "fmin v24.4s, v24.4s, v3.4s\n" /* relu6 */       \
+  "fmin v25.4s, v25.4s, v3.4s\n" /* relu6 */       \
+  "fmin v26.4s, v26.4s, v3.4s\n" /* relu6 */       \
+  "fmin v27.4s, v27.4s, v3.4s\n" /* relu6 */       \
+  "fmin v28.4s, v28.4s, v3.4s\n" /* relu6 */       \
+  "fmin v29.4s, v29.4s, v3.4s\n" /* relu6 */       \
+  "fmin v30.4s, v30.4s, v3.4s\n" /* relu6 */       \
+  "fmin v31.4s, v31.4s, v3.4s\n" /* relu6 */       \
+  "b 12f\n"                      /* relu6 end */
+#define GEMM_SDOT_LEAKY_RELU                        \
+  "14: \n"                                          \
+  "movi v2.4s, #0\n"              /* for leaky */   \
+  "ld1 {v3.4s}, [%[alpha]]\n"     /* leaky alpha */ \
+  "fcmge v4.4s, v8.4s, v2.4s\n"   /* vcgeq_f32 */   \
+  "fmul v5.4s, v8.4s, v3.4s\n"    /* vmulq_f32 */   \
+  "fcmge v6.4s, v9.4s, v2.4s\n"   /* vcgeq_f32 */   \
+  "fmul v7.4s, v9.4s, v3.4s\n"    /* vmulq_f32 */   \
+  "bif v8.16b, v5.16b, v4.16b\n"  /* choose */      \
+  "bif v9.16b, v7.16b, v6.16b\n"  /* choose */      \
+  "fcmge v4.4s, v10.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v10.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v11.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v11.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v10.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v11.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v12.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v12.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v13.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v13.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v12.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v13.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v14.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v14.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v15.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v15.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v14.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v15.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v16.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v16.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v17.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v17.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v16.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v17.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v18.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v18.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v19.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v19.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v18.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v19.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v20.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v20.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v21.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v21.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v20.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v21.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v22.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v22.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v23.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v23.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v22.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v23.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v24.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v24.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v25.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v25.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v24.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v25.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v26.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v26.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v27.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v27.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v26.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v27.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v28.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v28.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v29.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v29.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v28.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v29.16b, v7.16b, v6.16b\n" /* choose */      \
+  "fcmge v4.4s, v30.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v30.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v31.4s, v2.4s\n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v31.4s, v3.4s\n"   /* vmulq_f32 */   \
+  "bif v30.16b, v5.16b, v4.16b\n" /* choose */      \
+  "bif v31.16b, v7.16b, v6.16b\n" /* choose */      \
   "12: \n"
 #define GEMM_SDOT_CVT_INT32_TO_FP32 \
@@ -1206,6 +1451,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #define GEMM_SDOT_FP32_OUT      \
   GEMM_SDOT_CVT_INT32_TO_FP32   \
   GEMM_SDOT_RELU                \
+  GEMM_SDOT_RELU6               \
+  GEMM_SDOT_LEAKY_RELU          \
   "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n"  /* store r0 */ \
   "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \
   "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \
@@ -1218,6 +1465,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #define GEMM_SDOT_INT8_OUT      \
   GEMM_SDOT_CVT_INT32_TO_FP32   \
   GEMM_SDOT_RELU                \
+  GEMM_SDOT_RELU6               \
+  GEMM_SDOT_LEAKY_RELU          \
   "ld1 {v6.4s}, [%[vmax]]\n" /* v6 = -127.f */ \
   /* data >= -127 */              \
   "fcmge v0.4s, v8.4s, v6.4s\n"   \
@@ -1371,7 +1620,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
                                   float32_t*& c_ptr6,  // NOLINT
                                   float32_t*& c_ptr7,  // NOLINT
                                   const float32_t* scale,
-                                  bool is_relu,
+                                  const float32_t* alpha,
+                                  int is_relu,
                                   int k,
                                   int tail) {
 // clang-format off
@@ -1389,7 +1639,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
               [c_ptr5] "+r"(c_ptr5),
               [c_ptr6] "+r"(c_ptr6),
               [c_ptr7] "+r"(c_ptr7)
-            : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu)
+            : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu),
+              [alpha] "r"(alpha)
             : "cc","memory","v0","v1","v2",
               "v3","v4","v5","v6","v7","v8","v9","v10",
               "v11","v12","v13","v14","v15","v16","v17",
@@ -1410,7 +1661,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
                                   int8_t*& c_ptr6,  // NOLINT
                                   int8_t*& c_ptr7,  // NOLINT
                                   const float32_t* scale,
-                                  bool is_relu,
+                                  const float32_t* alpha,
+                                  int is_relu,
                                   int k,
                                   int tail) {
 // clang-format off
@@ -1428,7 +1680,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
               [c_ptr5] "+r"(c_ptr5),
               [c_ptr6] "+r"(c_ptr6),
               [c_ptr7] "+r"(c_ptr7)
-            : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu), [vmax] "r"(vmax)
+            : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu), [vmax] "r"(vmax),
+              [alpha] "r"(alpha)
             : "cc","memory","v0","v1","v2","v3",
               "v4","v5","v6","v7","v8","v9","v10",
               "v11","v12","v13","v14","v15","v16","v17",
@@ -1654,6 +1907,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
   /* do relu */                            \
   "cmp %[is_relu], #0\n" /* skip relu */   \
   "beq 9f\n"             /* no act end */  \
+  "cmp %[is_relu], #1\n" /* check relu */  \
+  "bne 10f\n"            /* other act */   \
   "vmov.i32 q15, #0\n"   /* for relu */    \
   "vmax.f32 q8, q8, q15\n" /* relu */      \
   "vmax.f32 q9, q9, q15\n" /* relu */      \
@@ -1663,12 +1918,69 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
   "vmax.f32 q3, q3, q15\n" /* relu */      \
   "vmax.f32 q4, q4, q15\n" /* relu */      \
   "vmax.f32 q5, q5, q15\n" /* relu */      \
-  "9:\n"
+  "b 9f\n"               /* relu end */
+#define GEMM_INT8_RELU6                                    \
+  /* do relu6 */                                           \
+  "10: \n"                                                 \
+  "cmp %[is_relu], #2\n"   /* check relu6 */               \
+  "bne 11f\n"              /* leaky relu */                \
+  "vmov.i32 q15, #0\n"     /* for relu */                  \
+  "vmax.f32 q8, q8, q15\n" /* relu */                      \
+  "vmax.f32 q9, q9, q15\n" /* relu */                      \
+  "vmax.f32 q0, q0, q15\n" /* relu */                      \
+  "vmax.f32 q1, q1, q15\n" /* relu */                      \
+  "vld1.f32 {d28-d29}, [%[alpha]] @ load relu6 alpha\n"    \
+  "vmax.f32 q2, q2, q15\n" /* relu */                      \
+  "vmax.f32 q3, q3, q15\n" /* relu */                      \
+  "vmax.f32 q4, q4, q15\n" /* relu */                      \
+  "vmax.f32 q5, q5, q15\n" /* relu */                      \
+  "vmin.f32 q8, q8, q14\n" /* relu6 */                     \
+  "vmin.f32 q9, q9, q14\n" /* relu6 */                     \
+  "vmin.f32 q0, q0, q14\n" /* relu6 */                     \
+  "vmin.f32 q1, q1, q14\n" /* relu6 */                     \
+  "vmin.f32 q2, q2, q14\n" /* relu6 */                     \
+  "vmin.f32 q3, q3, q14\n" /* relu6 */                     \
+  "vmin.f32 q4, q4, q14\n" /* relu6 */                     \
+  "vmin.f32 q5, q5, q14\n" /* relu6 */                     \
+  "b 9f\n"                 /* relu6 end */
+#define GEMM_INT8_LEAKY_RELU                               \
+  /* do leaky relu */                                      \
+  "11: \n"                                                 \
+  "vmov.i32 q15, #0\n"     /* zero */                      \
+  "vld1.f32 {d28-d29}, [%[alpha]] @ load leaky alpha\n"    \
+  "vcge.f32 q6, q8, q15   @ vcgeq_f32 \n"                  \
+  "vmul.f32 q7, q8, q14   @ vmulq_f32 \n"                  \
+  "vcge.f32 q10, q9, q15  @ vcgeq_f32 \n"                  \
+  "vmul.f32 q11, q9, q14  @ vmulq_f32 \n"                  \
+  "vcge.f32 q12, q0, q15  @ vcgeq_f32 \n"                  \
+  "vmul.f32 q13, q0, q14  @ vmulq_f32 \n"                  \
+  "vbif q8, q7, q6        @ choose \n"                     \
+  "vbif q9, q11, q10      @ choose \n"                     \
+  "vbif q0, q13, q12      @ choose \n"                     \
+  "vcge.f32 q6, q1, q15   @ vcgeq_f32 \n"                  \
+  "vmul.f32 q7, q1, q14   @ vmulq_f32 \n"                  \
+  "vcge.f32 q10, q2, q15  @ vcgeq_f32 \n"                  \
+  "vmul.f32 q11, q2, q14  @ vmulq_f32 \n"                  \
+  "vcge.f32 q12, q3, q15  @ vcgeq_f32 \n"                  \
+  "vmul.f32 q13, q3, q14  @ vmulq_f32 \n"                  \
+  "vbif q1, q7, q6        @ choose \n"                     \
+  "vbif q2, q11, q10      @ choose \n"                     \
+  "vbif q3, q13, q12      @ choose \n"                     \
+  "vcge.f32 q6, q4, q15   @ vcgeq_f32 \n"                  \
+  "vmul.f32 q7, q4, q14   @ vmulq_f32 \n"                  \
+  "vcge.f32 q10, q5, q15  @ vcgeq_f32 \n"                  \
+  "vmul.f32 q11, q5, q14  @ vmulq_f32 \n"                  \
+  "vbif q4, q7, q6        @ choose \n"                     \
+  "vbif q5, q11, q10      @ choose \n"                     \
+  "9: \n"
 #define GEMM_INT8_FP32_OUT        \
   GEMM_INT8_TRANS_INT32_TO_FP32   \
   GEMM_INT8_RELU                  \
+  GEMM_INT8_RELU6                 \
+  GEMM_INT8_LEAKY_RELU            \
   "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \
   "vst1.32 {d0-d3}, [%[c_ptr1]]!\n"   /* write r1, float32x4 x2 */ \
   "vst1.32 {d4-d7}, [%[c_ptr2]]!\n"   /* write r2, float32x4 x2 */ \
@@ -1678,6 +1990,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_INT8_OUT        \
   GEMM_INT8_TRANS_INT32_TO_FP32   \
   GEMM_INT8_RELU                  \
+  GEMM_INT8_RELU6                 \
+  GEMM_INT8_LEAKY_RELU            \
   "vmov.f32 q7, #-0.5\n"  /* neg offset */ \
   "vmov.f32 q10, #0.5\n"  /* pos offset */ \
   "vmov.f32 q11, #0.5\n"  /* pos offset */ \
@@ -1765,7 +2079,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              float32_t*& c_ptr2,  // NOLINT
                              float32_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
   asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT
@@ -1778,6 +2093,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                  [k] "+r"(k)
                : [is_relu] "r"(is_relu),
                  [bias] "r"(bias),
+                 [alpha] "r"(alpha),
                  [rem] "r"(rem),
                  [scale] "r"(scale)
                : "q0",
@@ -1810,7 +2126,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              int8_t*& c_ptr2,  // NOLINT
                              int8_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
   float vmax[4] = {-127.0, -127.0, -127.0, -127.0};
@@ -1823,6 +2140,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                  [c_ptr3] "+r"(c_ptr3),
                  [k] "+r"(k)
                : [is_relu] "r"(is_relu),
+                 [alpha] "r"(alpha),
                  [bias] "r"(bias),
                  [rem] "r"(rem),
                  [vmax] "r"(vmax),
@@ -1859,9 +2177,10 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
                            int N,
                            int K,
                            bool is_bias,
-                           bool is_relu,
+                           int flag_act,
                            bool is_transB,
                            const float* scale,
+                           const float* alpha,
                            ARMContext* ctx) {
   const int KUP = ROUNDUP(K, KBLOCK_INT8);
   size_t llc_size = ctx->llc_size() / 4;
@@ -1969,7 +2288,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
                     c_ptr2,
                     c_ptr3,
                     scale_local,
-                    is_relu,
+                    alpha,
+                    flag_act,
                     k,
                     k_rem);
       if (flag_rem && (xb == bblocks - 1)) {
@@ -3090,9 +3410,10 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
                             int N,
                             int K,
                             bool is_bias,
-                            bool is_relu,
+                            int is_relu,
                             bool is_transB,
                             const float* scale,
+                            const float* alpha,
                             ARMContext* ctx) {
   size_t llc_size = ctx->llc_size() / 4;
   auto workspace = ctx->workspace_data<int8_t>();
@@ -3250,6 +3571,7 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
                     c_ptr6,
                     c_ptr7,
                     scale_local,
+                    alpha,
                     is_relu,
                     k,
                     tail);
@@ -3871,21 +4193,43 @@ void gemm_prepack_int8(const int8_t* A_packed,
                        int N,
                        int K,
                        bool is_bias,
-                       bool is_relu,
                        bool is_transB,
                        const float* scale,
+                       const operators::ActivationParam act_param,
                        ARMContext* ctx) {
+  auto act_type = act_param.active_type;
+  float alpha[4] = {0.f, 0.f, 0.f, 0.f};
+  int flag_act = 0x00;  // relu: 1, relu6: 2, leaky: 3
+  if (act_param.has_active) {
+    if (act_type == lite_api::ActivationType::kRelu) {
+      flag_act = 0x01;
+    } else if (act_type == lite_api::ActivationType::kRelu6) {
+      flag_act = 0x02;
+      float local_alpha = act_param.Relu_clipped_coef;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    } else if (act_type == lite_api::ActivationType::kLeakyRelu) {
+      flag_act = 0x03;
+      float local_alpha = act_param.Leaky_relu_alpha;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    }
+  }
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     gemm_prepack_sdot_int8<float32_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
   } else {
     gemm_prepack_oth_int8<float32_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
   }
 #else
   gemm_prepack_oth_int8<float32_t>(
-      A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+      A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
 #endif
 }
@@ -3898,21 +4242,43 @@ void gemm_prepack_int8(const int8_t* A_packed,
                        int N,
                        int K,
                        bool is_bias,
-                       bool is_relu,
                        bool is_transB,
                        const float* scale,
+                       const operators::ActivationParam act_param,
                        ARMContext* ctx) {
+  auto act_type = act_param.active_type;
+  float alpha[4] = {0.f, 0.f, 0.f, 0.f};
+  int flag_act = 0x00;  // relu: 1, relu6: 2, leaky: 3
+  if (act_param.has_active) {
+    if (act_type == lite_api::ActivationType::kRelu) {
+      flag_act = 0x01;
+    } else if (act_type == lite_api::ActivationType::kRelu6) {
+      flag_act = 0x02;
+      float local_alpha = act_param.Relu_clipped_coef;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    } else if (act_type == lite_api::ActivationType::kLeakyRelu) {
+      flag_act = 0x03;
+      float local_alpha = act_param.Leaky_relu_alpha;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    }
+  }
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     gemm_prepack_sdot_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
   } else {
     gemm_prepack_oth_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
   }
 #else
   gemm_prepack_oth_int8<int8_t>(
-      A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+      A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
 #endif
 }
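A hedged call-site sketch for the new entry point: the caller now hands over an ActivationParam instead of a relu flag. Field and enum names are the ones used in this diff; buffer setup is elided and assumed unchanged from before the patch:

```cpp
// Sketch: fuse relu6 into the int8 GEMM (float-output overload).
// A_packed, B, bias, C, scale, and ctx are prepared exactly as before.
operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu6;
act_param.Relu_clipped_coef = 6.f;  // broadcast into alpha[4] internally
gemm_prepack_int8(A_packed, B, bias, C, M, N, K,
                  /*is_bias=*/true, /*is_transB=*/false,
                  scale, act_param, ctx);
```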
...