Unverified commit c7dd1458, authored by HappyAngel, committed by GitHub

[arm] add int8 gemm/gemv + relu6/leakyRelu fusion (#3572)

* improve 3x3s1 direct profile

* add gemv+relu6/leakyRelu

* fix relu6 problem, test=develop

* fix format, test=develop

* add six / scale, test=develop
Parent fcf6fa0c
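The change replaces the single relu flag in the int8 GEMM/GEMV entry points with an activation descriptor (an integer flag plus a broadcast alpha vector), so that relu6 and leaky relu can be fused as well. A minimal C++ sketch of how the new gemm_prepack_int8 maps operators::ActivationParam onto the flag_act/alpha pair consumed by the assembly kernels follows; this is illustrative only, map_act_param is a made-up helper name, and the mapping itself mirrors the diff below.

// Sketch of the activation mapping performed inside the new
// gemm_prepack_int8 (see the diff below); map_act_param is hypothetical.
static void map_act_param(const operators::ActivationParam& act_param,
                          int* flag_act,     // 0: none, 1: relu, 2: relu6, 3: leaky relu
                          float alpha[4]) {  // clip value or leaky slope, broadcast x4
  *flag_act = 0;
  alpha[0] = alpha[1] = alpha[2] = alpha[3] = 0.f;
  if (!act_param.has_active) return;
  if (act_param.active_type == lite_api::ActivationType::kRelu) {
    *flag_act = 1;
  } else if (act_param.active_type == lite_api::ActivationType::kRelu6) {
    *flag_act = 2;
    alpha[0] = alpha[1] = alpha[2] = alpha[3] = act_param.Relu_clipped_coef;
  } else if (act_param.active_type == lite_api::ActivationType::kLeakyRelu) {
    *flag_act = 3;
    alpha[0] = alpha[1] = alpha[2] = alpha[3] = act_param.Leaky_relu_alpha;
  }
}
// flag_act and alpha[] are then forwarded to gemm_prepack_sdot_int8 /
// gemm_prepack_oth_int8 together with the existing bias/scale arguments.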
...@@ -264,6 +264,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, ...@@ -264,6 +264,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
} }
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
auto act_param = param.activation_param;
//! use gemv when the output channel size = 1 //! use gemv when the output channel size = 1
for (int b = 0; b < num; ++b) { for (int b = 0; b < num; ++b) {
// dC // dC
...@@ -283,8 +284,11 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, ...@@ -283,8 +284,11 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
scale_group, scale_group,
flag_bias, flag_bias,
bias_group, bias_group,
flag_relu, act_param.has_active,
ctx); act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else { } else {
gemm_prepack_int8(weights_group, gemm_prepack_int8(weights_group,
din_group, din_group,
...@@ -294,9 +298,9 @@ void conv1x1s1_gemm_int8(const int8_t* i_data, ...@@ -294,9 +298,9 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
n, n,
k, k,
flag_bias, flag_bias,
flag_relu,
false, false,
scale_group, scale_group,
act_param,
ctx); ctx);
} }
} }
...@@ -474,6 +478,8 @@ void conv_im2col_gemm_int8(const int8_t* i_data, ...@@ -474,6 +478,8 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
auto act_param = param.activation_param;
int hblock = get_hblock_int8(ctx); int hblock = get_hblock_int8(ctx);
int k_roundup = ROUNDUP(k, KBLOCK_INT8); int k_roundup = ROUNDUP(k, KBLOCK_INT8);
int m_roundup = ROUNDUP(m, hblock); int m_roundup = ROUNDUP(m, hblock);
...@@ -523,8 +529,11 @@ void conv_im2col_gemm_int8(const int8_t* i_data, ...@@ -523,8 +529,11 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
scale_group, scale_group,
flag_bias, flag_bias,
bias_group, bias_group,
flag_relu, act_param.has_active,
ctx); act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else { } else {
gemm_prepack_int8(weights_group, gemm_prepack_int8(weights_group,
dB, dB,
...@@ -534,9 +543,9 @@ void conv_im2col_gemm_int8(const int8_t* i_data, ...@@ -534,9 +543,9 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
n, n,
k, k,
flag_bias, flag_bias,
flag_relu,
false, false,
scale_group, scale_group,
act_param,
ctx); ctx);
} }
} }
......
...@@ -195,7 +195,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -195,7 +195,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
Dtype*& c_ptr2, // NOLINT Dtype*& c_ptr2, // NOLINT
Dtype*& c_ptr3, // NOLINT Dtype*& c_ptr3, // NOLINT
const float* scale, const float* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int rem); int rem);
// clang-format off // clang-format off
...@@ -483,7 +484,10 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -483,7 +484,10 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
#define GEMM_INT8_RELU \ #define GEMM_INT8_RELU \
/* do relu */ \ /* do relu */ \
"cbz %w[is_relu], 9f\n" /* skip relu */ \ "cmp %w[is_relu], #0\n" /* skip relu */ \
"beq 9f \n" /* no act end */ \
"cmp %w[is_relu], #1\n" /* skip relu */ \
"bne 10f \n" /* other act */ \
"movi v0.4s, #0\n" /* for relu */ \ "movi v0.4s, #0\n" /* for relu */ \
"fmax v16.4s, v16.4s, v0.4s\n" /* relu */ \ "fmax v16.4s, v16.4s, v0.4s\n" /* relu */ \
"fmax v17.4s, v17.4s, v0.4s\n" /* relu */ \ "fmax v17.4s, v17.4s, v0.4s\n" /* relu */ \
...@@ -501,6 +505,102 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -501,6 +505,102 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
"fmax v29.4s, v29.4s, v0.4s\n" /* relu */ \ "fmax v29.4s, v29.4s, v0.4s\n" /* relu */ \
"fmax v30.4s, v30.4s, v0.4s\n" /* relu */ \ "fmax v30.4s, v30.4s, v0.4s\n" /* relu */ \
"fmax v31.4s, v31.4s, v0.4s\n" /* relu */ \ "fmax v31.4s, v31.4s, v0.4s\n" /* relu */ \
"b 9f \n" /* relu end */
#define GEMM_INT8_RELU6 \
/* do relu6 */ \
"10: \n" \
"cmp %w[is_relu], #2 \n" /* check relu6 */ \
"bne 11f \n" /* no act end */ \
"movi v0.4s, #0\n" /* for relu6 */ \
"fmax v16.4s, v16.4s, v0.4s\n" /* relu */ \
"fmax v17.4s, v17.4s, v0.4s\n" /* relu */ \
"fmax v18.4s, v18.4s, v0.4s\n" /* relu */ \
"fmax v19.4s, v19.4s, v0.4s\n" /* relu */ \
"fmax v20.4s, v20.4s, v0.4s\n" /* relu */ \
"ld1 {v1.4s}, [%[alpha]] \n" /* relu6 alpha */ \
"fmax v21.4s, v21.4s, v0.4s\n" /* relu */ \
"fmax v22.4s, v22.4s, v0.4s\n" /* relu */ \
"fmax v23.4s, v23.4s, v0.4s\n" /* relu */ \
"fmax v24.4s, v24.4s, v0.4s\n" /* relu */ \
"fmax v25.4s, v25.4s, v0.4s\n" /* relu */ \
"fmax v26.4s, v26.4s, v0.4s\n" /* relu */ \
"fmax v27.4s, v27.4s, v0.4s\n" /* relu */ \
"fmax v28.4s, v28.4s, v0.4s\n" /* relu */ \
"fmax v29.4s, v29.4s, v0.4s\n" /* relu */ \
"fmax v30.4s, v30.4s, v0.4s\n" /* relu */ \
"fmax v31.4s, v31.4s, v0.4s\n" /* relu */ \
"fmin v16.4s, v16.4s, v1.4s\n" /* relu6 */ \
"fmin v17.4s, v17.4s, v1.4s\n" /* relu6 */ \
"fmin v18.4s, v18.4s, v1.4s\n" /* relu6 */ \
"fmin v19.4s, v19.4s, v1.4s\n" /* relu6 */ \
"fmin v20.4s, v20.4s, v0.4s\n" /* relu6 */ \
"fmin v21.4s, v21.4s, v0.4s\n" /* relu6 */ \
"fmin v22.4s, v22.4s, v0.4s\n" /* relu6 */ \
"fmin v23.4s, v23.4s, v0.4s\n" /* relu6 */ \
"fmin v24.4s, v24.4s, v0.4s\n" /* relu6 */ \
"fmin v25.4s, v25.4s, v0.4s\n" /* relu6 */ \
"fmin v26.4s, v26.4s, v0.4s\n" /* relu6 */ \
"fmin v27.4s, v27.4s, v0.4s\n" /* relu6 */ \
"fmin v28.4s, v28.4s, v0.4s\n" /* relu6 */ \
"fmin v29.4s, v29.4s, v0.4s\n" /* relu6 */ \
"fmin v30.4s, v30.4s, v0.4s\n" /* relu6 */ \
"fmin v31.4s, v31.4s, v0.4s\n" /* relu6 */ \
"b 9f \n" /* relu end */
#define GEMM_INT8_LEAKY_RELU \
  /* do leaky relu */ \
  "11: \n" \
  "movi v0.4s, #0\n" /* zero for leaky relu */ \
  "ld1 {v1.4s}, [%[alpha]] \n" /* leaky relu alpha */ \
"fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v3.4s, v16.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v4.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v17.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v18.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v8.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v19.4s, v1.4s \n" /* vmulq_f32 */ \
"bif v16.16b, v3.16b, v2.16b \n" /* choose*/ \
"bif v17.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v18.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v19.16b, v9.16b, v8.16b \n" /* choose*/ \
"fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v3.4s, v20.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v4.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v21.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v22.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v8.4s, v23.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v23.4s, v1.4s \n" /* vmulq_f32 */ \
"bif v20.16b, v3.16b, v2.16b \n" /* choose*/ \
"bif v21.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v22.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v23.16b, v9.16b, v8.16b \n" /* choose*/ \
"fcmge v2.4s, v24.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v3.4s, v24.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v4.4s, v25.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v25.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v26.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v26.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v8.4s, v27.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v27.4s, v1.4s \n" /* vmulq_f32 */ \
"bif v24.16b, v3.16b, v2.16b \n" /* choose*/ \
"bif v25.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v26.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v27.16b, v9.16b, v8.16b \n" /* choose*/ \
"fcmge v2.4s, v28.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v3.4s, v28.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v4.4s, v29.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v29.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v30.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v30.4s, v1.4s \n" /* vmulq_f32 */ \
"fcmge v8.4s, v31.4s, v0.4s \n" /* vcgeq_f32 */ \
"fmul v9.4s, v31.4s, v1.4s \n" /* vmulq_f32 */ \
"bif v28.16b, v3.16b, v2.16b \n" /* choose*/ \
"bif v29.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v30.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v31.16b, v9.16b, v8.16b \n" /* choose*/ \
"9:\n" "9:\n"
#define GEMM_TRANS_INT32_TO_FP32 \ #define GEMM_TRANS_INT32_TO_FP32 \
...@@ -559,6 +659,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -559,6 +659,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
#define GEMM_INT8_FP32_OUT \ #define GEMM_INT8_FP32_OUT \
GEMM_TRANS_INT32_TO_FP32 \ GEMM_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \ GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
GEMM_INT8_LEAKY_RELU \
/* store result */ \ /* store result */ \
"stp q16, q17, [%[c_ptr0]], #32\n" \ "stp q16, q17, [%[c_ptr0]], #32\n" \
"stp q18, q19, [%[c_ptr0]], #32\n" \ "stp q18, q19, [%[c_ptr0]], #32\n" \
...@@ -572,6 +674,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -572,6 +674,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
#define GEMM_INT8_INT8_OUT \ #define GEMM_INT8_INT8_OUT \
GEMM_TRANS_INT32_TO_FP32 \ GEMM_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \ GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
GEMM_INT8_LEAKY_RELU \
"ld1 {v8.4s}, [%[vmax]] \n" /* v8 = -127 */ \ "ld1 {v8.4s}, [%[vmax]] \n" /* v8 = -127 */ \
/* data >= -127 */ \ /* data >= -127 */ \
"fcmge v0.4s, v16.4s, v8.4s\n" \ "fcmge v0.4s, v16.4s, v8.4s\n" \
...@@ -665,7 +769,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -665,7 +769,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
float32_t*& c_ptr2, // NOLINT float32_t*& c_ptr2, // NOLINT
float32_t*& c_ptr3, // NOLINT float32_t*& c_ptr3, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int rem) { int rem) {
// clang-format off // clang-format off
...@@ -678,6 +783,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -678,6 +783,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[c_ptr3] "+r"(c_ptr3), [c_ptr3] "+r"(c_ptr3),
[k] "+r"(k) [k] "+r"(k)
: [is_relu] "r"(is_relu), : [is_relu] "r"(is_relu),
[alpha] "r"(alpha),
[bias] "r"(bias), [bias] "r"(bias),
[rem] "r"(rem), [rem] "r"(rem),
[scale] "r"(scale) [scale] "r"(scale)
...@@ -698,7 +804,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -698,7 +804,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
int8_t*& c_ptr2, // NOLINT int8_t*& c_ptr2, // NOLINT
int8_t*& c_ptr3, // NOLINT int8_t*& c_ptr3, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int rem) { int rem) {
// clang-format off // clang-format off
...@@ -712,6 +819,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -712,6 +819,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[c_ptr3] "+r"(c_ptr3), [c_ptr3] "+r"(c_ptr3),
[k] "+r"(k) [k] "+r"(k)
: [is_relu] "r"(is_relu), : [is_relu] "r"(is_relu),
[alpha] "r"(alpha),
[bias] "r"(bias), [bias] "r"(bias),
[rem] "r"(rem), [rem] "r"(rem),
[scale] "r"(scale), [scale] "r"(scale),
...@@ -739,7 +847,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -739,7 +847,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
Dtype*& c_ptr6, // NOLINT Dtype*& c_ptr6, // NOLINT
Dtype*& c_ptr7, // NOLINT Dtype*& c_ptr7, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int rem); int rem);
#if 0 #if 0
...@@ -1099,7 +1208,10 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1099,7 +1208,10 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
#endif #endif
#define GEMM_SDOT_RELU \ #define GEMM_SDOT_RELU \
"cbz %w[relu], 12f\n" /* skip relu */ \ "cmp %w[relu], #0\n" /* skip relu */ \
"beq 12f\n" \
"cmp %w[relu], #1\n" /* skip relu */ \
"bne 13f\n" /* other act */ \
"movi v2.4s, #0\n" /* for relu*/ \ "movi v2.4s, #0\n" /* for relu*/ \
"fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ \ "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ \
"fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ \ "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ \
...@@ -1125,6 +1237,140 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1125,6 +1237,140 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
"fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ \ "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ \
"fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ \ "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ \
"fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ \ "fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ \
"b 12f \n" /* relu end */
#define GEMM_SDOT_RELU6 \
"13: \n" \
"cmp %w[relu], #2\n" /* skip relu6 */ \
"bne 14f\n" \
"movi v2.4s, #0\n" /* for relu*/ \
"fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ \
"fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ \
"fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ \
"fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ \
"ld1 {v3.4s}, [%[alpha]] \n" /* relu6 alpha */ \
"fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ \
"fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ \
"fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ \
"fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ \
"fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ \
"fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ \
"fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ \
"fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ \
"fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ \
"fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ \
"fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ \
"fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ \
"fmax v24.4s, v24.4s, v2.4s\n" /* relu*/ \
"fmax v25.4s, v25.4s, v2.4s\n" /* relu*/ \
"fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ \
"fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ \
"fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ \
"fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ \
"fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ \
"fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ \
"fmin v8.4s, v8.4s, v3.4s\n" /* relu6*/ \
"fmin v9.4s, v9.4s, v3.4s\n" /* relu6*/ \
"fmin v10.4s, v10.4s, v3.4s\n" /* relu6*/ \
"fmin v11.4s, v11.4s, v3.4s\n" /* relu6*/ \
"fmin v12.4s, v12.4s, v3.4s\n" /* relu6*/ \
"fmin v13.4s, v13.4s, v3.4s\n" /* relu6*/ \
"fmin v14.4s, v14.4s, v3.4s\n" /* relu6*/ \
"fmin v15.4s, v15.4s, v3.4s\n" /* relu6*/ \
"fmin v16.4s, v16.4s, v3.4s\n" /* relu6*/ \
"fmin v17.4s, v17.4s, v3.4s\n" /* relu6*/ \
"fmin v18.4s, v18.4s, v3.4s\n" /* relu6*/ \
"fmin v19.4s, v19.4s, v3.4s\n" /* relu6*/ \
"fmin v20.4s, v20.4s, v3.4s\n" /* relu6*/ \
"fmin v21.4s, v21.4s, v3.4s\n" /* relu6*/ \
"fmin v22.4s, v22.4s, v3.4s\n" /* relu6*/ \
"fmin v23.4s, v23.4s, v3.4s\n" /* relu6*/ \
"fmin v24.4s, v24.4s, v3.4s\n" /* relu6*/ \
"fmin v25.4s, v25.4s, v3.4s\n" /* relu6*/ \
"fmin v26.4s, v26.4s, v3.4s\n" /* relu6*/ \
"fmin v27.4s, v27.4s, v3.4s\n" /* relu6*/ \
"fmin v28.4s, v28.4s, v3.4s\n" /* relu6*/ \
"fmin v29.4s, v29.4s, v3.4s\n" /* relu6*/ \
"fmin v30.4s, v30.4s, v3.4s\n" /* relu6*/ \
"fmin v31.4s, v31.4s, v3.4s\n" /* relu6*/ \
"b 12f \n" /* relu end */
#define GEMM_SDOT_LEAKY_RELU \
"14: \n" \
"movi v2.4s, #0\n" /* for leakyrelu*/ \
"ld1 {v3.4s}, [%[alpha]]\n" /* leakyrelu alpha */ \
"fcmge v4.4s, v8.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v8.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v9.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v9.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v8.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v9.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v10.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v10.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v11.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v11.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v10.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v11.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v12.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v12.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v13.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v13.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v12.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v13.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v14.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v14.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v15.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v15.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v14.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v15.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v16.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v16.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v17.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v17.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v16.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v17.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v18.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v18.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v19.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v19.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v18.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v19.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v20.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v20.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v21.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v21.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v20.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v21.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v22.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v22.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v23.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v23.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v22.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v23.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v24.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v24.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v25.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v25.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v24.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v25.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v26.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v26.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v27.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v27.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v26.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v27.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v28.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v28.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v29.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v29.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v28.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v29.16b, v7.16b, v6.16b \n" /* choose*/ \
"fcmge v4.4s, v30.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v5.4s, v30.4s, v3.4s \n" /* vmulq_f32 */ \
"fcmge v6.4s, v31.4s, v2.4s \n" /* vcgeq_f32 */ \
"fmul v7.4s, v31.4s, v3.4s \n" /* vmulq_f32 */ \
"bif v30.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v31.16b, v7.16b, v6.16b \n" /* choose*/ \
"12: \n" "12: \n"
#define GEMM_SDOT_CVT_INT32_TO_FP32 \ #define GEMM_SDOT_CVT_INT32_TO_FP32 \
...@@ -1206,6 +1452,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1206,6 +1452,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
#define GEMM_SDOT_FP32_OUT \ #define GEMM_SDOT_FP32_OUT \
GEMM_SDOT_CVT_INT32_TO_FP32 \ GEMM_SDOT_CVT_INT32_TO_FP32 \
GEMM_SDOT_RELU \ GEMM_SDOT_RELU \
GEMM_SDOT_RELU6 \
GEMM_SDOT_LEAKY_RELU \
"st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ \ "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ \
"st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \ "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ \
"st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \ "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ \
...@@ -1218,6 +1466,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1218,6 +1466,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
#define GEMM_SDOT_INT8_OUT \ #define GEMM_SDOT_INT8_OUT \
GEMM_SDOT_CVT_INT32_TO_FP32 \ GEMM_SDOT_CVT_INT32_TO_FP32 \
GEMM_SDOT_RELU \ GEMM_SDOT_RELU \
GEMM_SDOT_RELU6 \
GEMM_SDOT_LEAKY_RELU \
"ld1 {v6.4s}, [%[vmax]]\n" /* v8 = -127.f */ \ "ld1 {v6.4s}, [%[vmax]]\n" /* v8 = -127.f */ \
/* data >= -127 */ \ /* data >= -127 */ \
"fcmge v0.4s, v8.4s, v6.4s\n" \ "fcmge v0.4s, v8.4s, v6.4s\n" \
...@@ -1371,7 +1621,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1371,7 +1621,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
float32_t*& c_ptr6, // NOLINT float32_t*& c_ptr6, // NOLINT
float32_t*& c_ptr7, // NOLINT float32_t*& c_ptr7, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int tail) { int tail) {
// clang-format off // clang-format off
...@@ -1389,7 +1640,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1389,7 +1640,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
[c_ptr5] "+r"(c_ptr5), [c_ptr5] "+r"(c_ptr5),
[c_ptr6] "+r"(c_ptr6), [c_ptr6] "+r"(c_ptr6),
[c_ptr7] "+r"(c_ptr7) [c_ptr7] "+r"(c_ptr7)
: [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu) : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu),
[alpha] "r"(alpha)
: "cc","memory","v0","v1","v2", : "cc","memory","v0","v1","v2",
"v3","v4","v5","v6","v7","v8","v9","v10", "v3","v4","v5","v6","v7","v8","v9","v10",
"v11","v12","v13","v14","v15","v16","v17", "v11","v12","v13","v14","v15","v16","v17",
...@@ -1410,7 +1662,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1410,7 +1662,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
int8_t*& c_ptr6, // NOLINT int8_t*& c_ptr6, // NOLINT
int8_t*& c_ptr7, // NOLINT int8_t*& c_ptr7, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int tail) { int tail) {
// clang-format off // clang-format off
...@@ -1428,7 +1681,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1428,7 +1681,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
[c_ptr5] "+r"(c_ptr5), [c_ptr5] "+r"(c_ptr5),
[c_ptr6] "+r"(c_ptr6), [c_ptr6] "+r"(c_ptr6),
[c_ptr7] "+r"(c_ptr7) [c_ptr7] "+r"(c_ptr7)
: [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu), [vmax] "r"(vmax) : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu), [vmax] "r"(vmax),
[alpha] "r"(alpha)
: "cc","memory","v0","v1","v2","v3", : "cc","memory","v0","v1","v2","v3",
"v4","v5","v6","v7","v8","v9","v10", "v4","v5","v6","v7","v8","v9","v10",
"v11","v12","v13","v14","v15","v16","v17", "v11","v12","v13","v14","v15","v16","v17",
...@@ -1534,9 +1788,9 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1534,9 +1788,9 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
"vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c31 */ \ "vmlal.s8 q7, d3, d7\n" /* a1 * b1 = c31 */ \
"cmp %[rem], #0\n" /* skip remain */ \ "cmp %[rem], #0\n" /* skip remain */ \
"beq 5f\n" \ "beq 5f\n" \
"mov r0, #32\n" /* address offset */ \ "mov %[k], #32\n" /* address offset */ \
"vld1.8 {d0}, [%[a_ptr]]\n" /* load a to d0, final */ \ "vld1.8 {d0}, [%[a_ptr]]\n" /* load a to d0, final */ \
"vld1.8 {d4-d5}, [%[b_ptr]], r0\n" /* load b to d4, d5 */ \ "vld1.8 {d4-d5}, [%[b_ptr]], %[k]\n" /* load b to d4, d5 */ \
"5:\n" /* skip rem */ \ "5:\n" /* skip rem */ \
"vpadal.s16 q12, q4\n" /* pair add and accumulate, c20 */ \ "vpadal.s16 q12, q4\n" /* pair add and accumulate, c20 */ \
"vpadal.s16 q13, q5\n" /* pair add and accumulate, c21 */ \ "vpadal.s16 q13, q5\n" /* pair add and accumulate, c21 */ \
...@@ -1654,6 +1908,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1654,6 +1908,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
/* do relu */ \ /* do relu */ \
"cmp %[is_relu], #0\n" /* skip relu */ \ "cmp %[is_relu], #0\n" /* skip relu */ \
"beq 9f\n" /* skip relu */ \ "beq 9f\n" /* skip relu */ \
"cmp %[is_relu], #1\n" /* check if has relu6 */ \
"bne 10f\n" /* skip relu */ \
"vmov.i32 q15, #0\n" /* for relu */ \ "vmov.i32 q15, #0\n" /* for relu */ \
"vmax.f32 q8, q8, q15\n" /* relu */ \ "vmax.f32 q8, q8, q15\n" /* relu */ \
"vmax.f32 q9, q9, q15\n" /* relu */ \ "vmax.f32 q9, q9, q15\n" /* relu */ \
...@@ -1663,12 +1919,69 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1663,12 +1919,69 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
"vmax.f32 q3,q3, q15\n" /* relu */ \ "vmax.f32 q3,q3, q15\n" /* relu */ \
"vmax.f32 q4,q4, q15\n" /* relu */ \ "vmax.f32 q4,q4, q15\n" /* relu */ \
"vmax.f32 q5,q5, q15\n" /* relu */ \ "vmax.f32 q5,q5, q15\n" /* relu */ \
"9:\n" "b 9f\n"
#define GEMM_INT8_RELU6 \
/* do relu6 */ \
"10: \n" \
"cmp %[is_relu], #2\n" /*heck if has relu6*/ \
"bne 11f\n" /* skip relu */ \
"vmov.i32 q15, #0\n" /* for relu */ \
"vmax.f32 q8, q8, q15\n" /* relu */ \
"vmax.f32 q9, q9, q15\n" /* relu */ \
"vmax.f32 q0,q0, q15\n" /* relu */ \
"vmax.f32 q1,q1, q15\n" /* relu */ \
"vld1.f32 {d28-d29}, [%[alpha]] @ load relu6 alpha\n" \
"vmax.f32 q2,q2, q15\n" /* relu */ \
"vmax.f32 q3,q3, q15\n" /* relu */ \
"vmax.f32 q4,q4, q15\n" /* relu */ \
"vmax.f32 q5,q5, q15\n" /* relu */ \
"vmin.f32 q8, q8, q14\n" /* relu6 */ \
"vmin.f32 q9, q9, q14\n" /* relu6 */ \
"vmin.f32 q0,q0, q14\n" /* relu6 */ \
"vmin.f32 q1,q1, q14\n" /* relu6 */ \
"vmin.f32 q2,q2, q14\n" /* relu6 */ \
"vmin.f32 q3,q3, q14\n" /* relu6 */ \
"vmin.f32 q4,q4, q14\n" /* relu6 */ \
"vmin.f32 q5,q5, q14\n" /* relu6 */ \
"b 9f\n"
#define GEMM_INT8_LEAKY_RELU \
  /* do leaky relu */ \
  "11: \n" \
  "vmov.i32 q15, #0\n" /* zero for leaky relu */ \
  "vld1.f32 {d28-d29}, [%[alpha]] @ load leaky relu alpha\n" \
"vcge.f32 q6, q8, q15 @ vcgeq_u32 \n" \
"vmul.f32 q7, q8, q14 @ vmulq_f32 \n" \
"vcge.f32 q10, q9, q15 @ vcgeq_u32 \n" \
"vmul.f32 q11, q9, q14 @ vmulq_f32 \n" \
"vcge.f32 q12, q0, q15 @ vcgeq_u32 \n" \
"vmul.f32 q13, q0, q14 @ vmulq_f32 \n" \
"vbif q8, q7, q6 @ choose \n" \
"vbif q9, q11, q10 @ choose \n" \
"vbif q0, q13, q12 @ choose \n" \
"vcge.f32 q6, q1, q15 @ vcgeq_u32 \n" \
"vmul.f32 q7, q1, q14 @ vmulq_f32 \n" \
"vcge.f32 q10, q2, q15 @ vcgeq_u32 \n" \
"vmul.f32 q11, q2, q14 @ vmulq_f32 \n" \
"vcge.f32 q12, q3, q15 @ vcgeq_u32 \n" \
"vmul.f32 q13, q3, q14 @ vmulq_f32 \n" \
"vbif q1, q7, q6 @ choose \n" \
"vbif q2, q11, q10 @ choose \n" \
"vbif q3, q13, q12 @ choose \n" \
"vcge.f32 q6, q4, q15 @ vcgeq_u32 \n" \
"vmul.f32 q7, q4, q14 @ vmulq_f32 \n" \
"vcge.f32 q10, q5, q15 @ vcgeq_u32 \n" \
"vmul.f32 q11, q5, q14 @ vmulq_f32 \n" \
"vbif q4, q7, q6 @ choose \n" \
"vbif q5, q11, q10 @ choose \n" \
"9: \n"
#define GEMM_INT8_FP32_OUT \ #define GEMM_INT8_FP32_OUT \
GEMM_INT8_TRANS_INT32_TO_FP32 \ GEMM_INT8_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \ GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
GEMM_INT8_LEAKY_RELU \
"vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \ "vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \
"vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write r1, float32x4 x2 */ \ "vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write r1, float32x4 x2 */ \
"vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write r2, float32x4 x2 */ \ "vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write r2, float32x4 x2 */ \
...@@ -1678,6 +1991,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1678,6 +1991,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
#define GEMM_INT8_INT8_OUT \ #define GEMM_INT8_INT8_OUT \
GEMM_INT8_TRANS_INT32_TO_FP32 \ GEMM_INT8_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \ GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
GEMM_INT8_LEAKY_RELU \
"vmov.f32 q7, #-0.5\n" /* neg offset */ \ "vmov.f32 q7, #-0.5\n" /* neg offset */ \
"vmov.f32 q10, #0.5\n" /* pos offset */ \ "vmov.f32 q10, #0.5\n" /* pos offset */ \
"vmov.f32 q11, #0.5\n" /* pos offset */ \ "vmov.f32 q11, #0.5\n" /* pos offset */ \
...@@ -1707,12 +2022,14 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr, ...@@ -1707,12 +2022,14 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
"vcgt.f32 q15, q5, #0\n" /* get pos mask */ \ "vcgt.f32 q15, q5, #0\n" /* get pos mask */ \
"vbif.f32 q12, q7, q14\n" /* get right offset */ \ "vbif.f32 q12, q7, q14\n" /* get right offset */ \
"vbif.f32 q13, q7, q15\n" /* get right offset */ \ "vbif.f32 q13, q7, q15\n" /* get right offset */ \
"add %[alpha], #16 \n" \
"vadd.f32 q2, q10, q2\n" /* r20, add offset */ \ "vadd.f32 q2, q10, q2\n" /* r20, add offset */ \
"vadd.f32 q3, q11, q3\n" /* r21, add offset */ \ "vadd.f32 q3, q11, q3\n" /* r21, add offset */ \
"vadd.f32 q4, q12, q4\n" /* r30, add offset */ \ "vadd.f32 q4, q12, q4\n" /* r30, add offset */ \
"vadd.f32 q5, q13, q5\n" /* r31, add offset */ \ "vadd.f32 q5, q13, q5\n" /* r31, add offset */ \
"vld1.32 {d12-d13}, [%[vmax]]\n" /* set q4 = -127 \n"*/ \ "vld1.f32 {d12-d13}, [%[alpha]] \n" \
"vcge.f32 q7, q8, q6\n" /* @ q8 >= -127 \n */ \ "sub %[alpha], #16 \n" \
"vcge.f32 q7, q8, q6\n" /* @ q8 >= -127 \n */ \
"vcge.f32 q10, q9, q6\n" /* @ q8 >= -127 \n */ \ "vcge.f32 q10, q9, q6\n" /* @ q8 >= -127 \n */ \
"vcge.f32 q11, q0, q6\n" /* @ q8 >= -127 \n */ \ "vcge.f32 q11, q0, q6\n" /* @ q8 >= -127 \n */ \
"vcge.f32 q12, q1, q6\n" /* @ q8 >= -127 \n */ \ "vcge.f32 q12, q1, q6\n" /* @ q8 >= -127 \n */ \
...@@ -1765,7 +2082,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -1765,7 +2082,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
float32_t*& c_ptr2, // NOLINT float32_t*& c_ptr2, // NOLINT
float32_t*& c_ptr3, // NOLINT float32_t*& c_ptr3, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int rem) { int rem) {
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT
...@@ -1778,6 +2096,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -1778,6 +2096,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[k] "+r"(k) [k] "+r"(k)
: [is_relu] "r"(is_relu), : [is_relu] "r"(is_relu),
[bias] "r"(bias), [bias] "r"(bias),
[alpha] "r"(alpha),
[rem] "r"(rem), [rem] "r"(rem),
[scale] "r"(scale) [scale] "r"(scale)
: "q0", : "q0",
...@@ -1796,7 +2115,6 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -1796,7 +2115,6 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
"q13", "q13",
"q14", "q14",
"q15", "q15",
"r0",
"cc", "cc",
"memory"); "memory");
} }
...@@ -1810,10 +2128,12 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -1810,10 +2128,12 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
int8_t*& c_ptr2, // NOLINT int8_t*& c_ptr2, // NOLINT
int8_t*& c_ptr3, // NOLINT int8_t*& c_ptr3, // NOLINT
const float32_t* scale, const float32_t* scale,
bool is_relu, const float32_t* alpha,
int is_relu,
int k, int k,
int rem) { int rem) {
float vmax[4] = {-127.0, -127.0, -127.0, -127.0}; float new_ptr[8] = {
alpha[0], alpha[1], alpha[2], alpha[3], -127.0, -127.0, -127.0, -127.0};
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT
: [a_ptr] "+r"(a_ptr), : [a_ptr] "+r"(a_ptr),
[b_ptr] "+r"(b_ptr), [b_ptr] "+r"(b_ptr),
...@@ -1823,9 +2143,9 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -1823,9 +2143,9 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[c_ptr3] "+r"(c_ptr3), [c_ptr3] "+r"(c_ptr3),
[k] "+r"(k) [k] "+r"(k)
: [is_relu] "r"(is_relu), : [is_relu] "r"(is_relu),
[alpha] "r"(new_ptr),
[bias] "r"(bias), [bias] "r"(bias),
[rem] "r"(rem), [rem] "r"(rem),
[vmax] "r"(vmax),
[scale] "r"(scale) [scale] "r"(scale)
: "q0", : "q0",
"q1", "q1",
...@@ -1843,7 +2163,6 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, ...@@ -1843,7 +2163,6 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
"q13", "q13",
"q14", "q14",
"q15", "q15",
"r0",
"cc", "cc",
"memory"); "memory");
} }
...@@ -1859,9 +2178,10 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, ...@@ -1859,9 +2178,10 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
int N, int N,
int K, int K,
bool is_bias, bool is_bias,
bool is_relu, int flag_act,
bool is_transB, bool is_transB,
const float* scale, const float* scale,
const float* alpha,
ARMContext* ctx) { ARMContext* ctx) {
const int KUP = ROUNDUP(K, KBLOCK_INT8); const int KUP = ROUNDUP(K, KBLOCK_INT8);
size_t llc_size = ctx->llc_size() / 4; size_t llc_size = ctx->llc_size() / 4;
...@@ -1969,7 +2289,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, ...@@ -1969,7 +2289,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
c_ptr2, c_ptr2,
c_ptr3, c_ptr3,
scale_local, scale_local,
is_relu, alpha,
flag_act,
k, k,
k_rem); k_rem);
if (flag_rem && (xb == bblocks - 1)) { if (flag_rem && (xb == bblocks - 1)) {
...@@ -3090,9 +3411,10 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed, ...@@ -3090,9 +3411,10 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
int N, int N,
int K, int K,
bool is_bias, bool is_bias,
bool is_relu, int is_relu,
bool is_transB, bool is_transB,
const float* scale, const float* scale,
const float* alpha,
ARMContext* ctx) { ARMContext* ctx) {
size_t llc_size = ctx->llc_size() / 4; size_t llc_size = ctx->llc_size() / 4;
auto workspace = ctx->workspace_data<int8_t>(); auto workspace = ctx->workspace_data<int8_t>();
...@@ -3250,6 +3572,7 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed, ...@@ -3250,6 +3572,7 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
c_ptr6, c_ptr6,
c_ptr7, c_ptr7,
scale_local, scale_local,
alpha,
is_relu, is_relu,
k, k,
tail); tail);
...@@ -3871,21 +4194,76 @@ void gemm_prepack_int8(const int8_t* A_packed, ...@@ -3871,21 +4194,76 @@ void gemm_prepack_int8(const int8_t* A_packed,
int N, int N,
int K, int K,
bool is_bias, bool is_bias,
bool is_relu,
bool is_transB, bool is_transB,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky: 3
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) { if (ctx->has_dot()) {
gemm_prepack_sdot_int8<float32_t>( gemm_prepack_sdot_int8<float32_t>(A_packed,
A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); B,
bias,
C,
M,
N,
K,
is_bias,
flag_act,
is_transB,
scale,
alpha,
ctx);
} else { } else {
gemm_prepack_oth_int8<float32_t>( gemm_prepack_oth_int8<float32_t>(A_packed,
A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); B,
bias,
C,
M,
N,
K,
is_bias,
flag_act,
is_transB,
scale,
alpha,
ctx);
} }
#else #else
gemm_prepack_oth_int8<float32_t>( gemm_prepack_oth_int8<float32_t>(A_packed,
A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); B,
bias,
C,
M,
N,
K,
is_bias,
flag_act,
is_transB,
scale,
alpha,
ctx);
#endif #endif
} }
...@@ -3898,21 +4276,76 @@ void gemm_prepack_int8(const int8_t* A_packed, ...@@ -3898,21 +4276,76 @@ void gemm_prepack_int8(const int8_t* A_packed,
int N, int N,
int K, int K,
bool is_bias, bool is_bias,
bool is_relu,
bool is_transB, bool is_transB,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
auto act_type = act_param.active_type;
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
int flag_act = 0x00; // relu: 1, relu6: 2, leaky: 3
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 0x01;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 0x02;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 0x03;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) { if (ctx->has_dot()) {
gemm_prepack_sdot_int8<int8_t>( gemm_prepack_sdot_int8<int8_t>(A_packed,
A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); B,
bias,
C,
M,
N,
K,
is_bias,
flag_act,
is_transB,
scale,
alpha,
ctx);
} else { } else {
gemm_prepack_oth_int8<int8_t>( gemm_prepack_oth_int8<int8_t>(A_packed,
A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); B,
bias,
C,
M,
N,
K,
is_bias,
flag_act,
is_transB,
scale,
alpha,
ctx);
} }
#else #else
gemm_prepack_oth_int8<int8_t>( gemm_prepack_oth_int8<int8_t>(A_packed,
A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); B,
bias,
C,
M,
N,
K,
is_bias,
flag_act,
is_transB,
scale,
alpha,
ctx);
#endif #endif
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <cmath> #include <cmath>
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -80,9 +81,9 @@ void gemm_prepack_int8(const int8_t* A_packed, ...@@ -80,9 +81,9 @@ void gemm_prepack_int8(const int8_t* A_packed,
int N, int N,
int K, int K,
bool is_bias, bool is_bias,
bool is_relu,
bool is_transB, bool is_transB,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx); ARMContext* ctx);
#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) #define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b))
......
...@@ -30,8 +30,8 @@ void gemm_s8(bool is_transA, ...@@ -30,8 +30,8 @@ void gemm_s8(bool is_transA,
Dtype* C, Dtype* C,
const float* bias, const float* bias,
bool is_bias, bool is_bias,
bool is_relu,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
int hblock = get_hblock_int8(ctx); int hblock = get_hblock_int8(ctx);
int m_roundup = hblock * ((M + hblock - 1) / hblock); int m_roundup = hblock * ((M + hblock - 1) / hblock);
...@@ -42,7 +42,7 @@ void gemm_s8(bool is_transA, ...@@ -42,7 +42,7 @@ void gemm_s8(bool is_transA,
prepackA_int8(packed_A, A, lda, 0, M, 0, K, is_transA, ctx); prepackA_int8(packed_A, A, lda, 0, M, 0, K, is_transA, ctx);
gemm_prepack_int8( gemm_prepack_int8(
packed_A, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx); packed_A, B, bias, C, M, N, K, is_bias, is_transB, scale, act_param, ctx);
TargetFree(TargetType::kARM, packed_A); TargetFree(TargetType::kARM, packed_A);
} }
...@@ -56,8 +56,8 @@ template void gemm_s8<float>(bool is_transA, ...@@ -56,8 +56,8 @@ template void gemm_s8<float>(bool is_transA,
float* C, float* C,
const float* bias, const float* bias,
bool is_bias, bool is_bias,
bool is_relu,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx); ARMContext* ctx);
template void gemm_s8<int8_t>(bool is_transA, template void gemm_s8<int8_t>(bool is_transA,
...@@ -70,8 +70,8 @@ template void gemm_s8<int8_t>(bool is_transA, ...@@ -70,8 +70,8 @@ template void gemm_s8<int8_t>(bool is_transA,
int8_t* C, int8_t* C,
const float* bias, const float* bias,
bool is_bias, bool is_bias,
bool is_relu,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx); ARMContext* ctx);
} // namespace math } // namespace math
......
...@@ -34,8 +34,8 @@ void gemm_s8(bool is_transA, ...@@ -34,8 +34,8 @@ void gemm_s8(bool is_transA,
Dtype* C, Dtype* C,
const float* bias, const float* bias,
bool is_bias, bool is_bias,
bool is_relu,
const float* scale, const float* scale,
const operators::ActivationParam act_param,
ARMContext* ctx); ARMContext* ctx);
} // namespace math } // namespace math
......
...@@ -27,7 +27,10 @@ inline void write_gemv_out(const int* in, ...@@ -27,7 +27,10 @@ inline void write_gemv_out(const int* in,
const float* scale, const float* scale,
const float* bias, const float* bias,
int size, int size,
bool is_relu); bool flag_act,
lite_api::ActivationType act,
float six,
float alpha);
template <> template <>
inline void write_gemv_out(const int* in, inline void write_gemv_out(const int* in,
...@@ -35,7 +38,10 @@ inline void write_gemv_out(const int* in, ...@@ -35,7 +38,10 @@ inline void write_gemv_out(const int* in,
const float* scale, const float* scale,
const float* bias, const float* bias,
int size, int size,
bool is_relu) { bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
int i = 0; int i = 0;
float32x4_t vzero = vdupq_n_f32(0.f); float32x4_t vzero = vdupq_n_f32(0.f);
for (; i < size - 7; i += 8) { for (; i < size - 7; i += 8) {
...@@ -49,9 +55,25 @@ inline void write_gemv_out(const int* in, ...@@ -49,9 +55,25 @@ inline void write_gemv_out(const int* in,
float32x4_t vinf1 = vcvtq_f32_s32(vin1); float32x4_t vinf1 = vcvtq_f32_s32(vin1);
vout0 = vmlaq_f32(vout0, vinf0, vscale0); vout0 = vmlaq_f32(vout0, vinf0, vscale0);
vout1 = vmlaq_f32(vout1, vinf1, vscale1); vout1 = vmlaq_f32(vout1, vinf1, vscale1);
if (is_relu) { if (flag_act) {
vout0 = vmaxq_f32(vout0, vzero); if (act == lite_api::ActivationType::kRelu) {
vout1 = vmaxq_f32(vout1, vzero); vout0 = vmaxq_f32(vout0, vzero);
vout1 = vmaxq_f32(vout1, vzero);
} else if (act == lite_api::ActivationType::kRelu6) {
float32x4_t vsix = vdupq_n_f32(six);
vout0 = vmaxq_f32(vout0, vzero);
vout1 = vmaxq_f32(vout1, vzero);
vout0 = vminq_f32(vout0, vsix);
vout1 = vminq_f32(vout1, vsix);
} else if (act == lite_api::ActivationType::kLeakyRelu) {
float32x4_t valpha = vdupq_n_f32(alpha);
uint32x4_t maska = vcgeq_f32(vout0, vzero);
uint32x4_t maskb = vcgeq_f32(vout1, vzero);
float32x4_t suma = vmulq_f32(vout0, valpha);
float32x4_t sumb = vmulq_f32(vout1, valpha);
vout0 = vbslq_f32(maska, vout0, suma);
vout1 = vbslq_f32(maskb, vout1, sumb);
}
} }
vst1q_f32(out, vout0); vst1q_f32(out, vout0);
vst1q_f32(out + 4, vout1); vst1q_f32(out + 4, vout1);
...@@ -63,7 +85,15 @@ inline void write_gemv_out(const int* in, ...@@ -63,7 +85,15 @@ inline void write_gemv_out(const int* in,
for (; i < size; ++i) { for (; i < size; ++i) {
out[0] = *(in++) * *(scale)++; out[0] = *(in++) * *(scale)++;
out[0] += bias ? *(bias++) : 0.f; out[0] += bias ? *(bias++) : 0.f;
out[0] = is_relu ? (out[0] > 0.f ? out[0] : 0.f) : out[0]; if (flag_act) {
if (act == lite_api::ActivationType::kRelu) {
out[0] = out[0] > 0.f ? out[0] : 0.f;
} else if (act == lite_api::ActivationType::kRelu6) {
out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
} else if (act == lite_api::ActivationType::kLeakyRelu) {
out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
}
}
out++; out++;
} }
} }
...@@ -74,24 +104,40 @@ inline void write_gemv_out(const int* in, ...@@ -74,24 +104,40 @@ inline void write_gemv_out(const int* in,
const float* scale, const float* scale,
const float* bias, const float* bias,
int size, int size,
bool flag_relu) { bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (bias) { if (bias) {
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
out[0] = float tmp = *(in++) * *(scale++) + *(bias++);
saturate_cast<signed char>(roundf(*(in++) * *(scale++) + *(bias++))); if (flag_act) {
out[0] = out[0] < -127 ? -127 : out[0]; // -127 - 127 if (act == lite_api::ActivationType::kRelu) {
if (flag_relu) { tmp = tmp > 0.f ? tmp : 0.f;
out[0] = out[0] > 0 ? out[0] : 0; } else if (act == lite_api::ActivationType::kRelu6) {
tmp = tmp > 0.f ? (tmp > six ? six : tmp) : 0.f;
} else if (act == lite_api::ActivationType::kLeakyRelu) {
tmp = tmp > 0.f ? tmp : (tmp * alpha);
}
} }
out[0] = saturate_cast<signed char>(roundf(tmp));
out[0] = out[0] < -127 ? -127 : out[0]; // -127 - 127
out++; out++;
} }
} else { } else {
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
out[0] = saturate_cast<signed char>(roundf(*(in++) * *(scale++))); float tmp = *(in++) * *(scale++);
out[0] = out[0] < -127 ? -127 : out[0]; // -127 - 127 if (flag_act) {
if (flag_relu) { if (act == lite_api::ActivationType::kRelu) {
out[0] = out[0] > 0 ? out[0] : 0; tmp = tmp > 0.f ? tmp : 0.f;
} else if (act == lite_api::ActivationType::kRelu6) {
tmp = tmp > 0.f ? (tmp > six ? six : tmp) : 0.f;
} else if (act == lite_api::ActivationType::kLeakyRelu) {
tmp = tmp > 0.f ? tmp : tmp * alpha;
}
} }
out[0] = saturate_cast<signed char>(roundf(tmp));
out[0] = out[0] < -127 ? -127 : out[0]; // -127 - 127
out++; out++;
} }
} }
...@@ -107,7 +153,10 @@ bool gemv_int8_oth(const int8_t* A, ...@@ -107,7 +153,10 @@ bool gemv_int8_oth(const int8_t* A,
const float* scale, const float* scale,
bool is_bias, bool is_bias,
const float* bias, const float* bias,
bool is_relu) { bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (transA) { if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false; return false;
...@@ -260,7 +309,8 @@ bool gemv_int8_oth(const int8_t* A, ...@@ -260,7 +309,8 @@ bool gemv_int8_oth(const int8_t* A,
ptr_out[7] += ptr_in[i] * ptr_w7[i]; ptr_out[7] += ptr_in[i] * ptr_w7[i];
} }
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu); write_gemv_out(
ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
} }
//! deal with remains //! deal with remains
...@@ -304,7 +354,8 @@ bool gemv_int8_oth(const int8_t* A, ...@@ -304,7 +354,8 @@ bool gemv_int8_oth(const int8_t* A,
for (int i = 0; i < tail; ++i) { for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i]; ptr_out[0] += ptr_in[i] * ptr_w0[i];
} }
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu); write_gemv_out(
ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
} }
#else // __aarch64__ #else // __aarch64__
int out_cnt = M >> 2; int out_cnt = M >> 2;
...@@ -398,7 +449,8 @@ bool gemv_int8_oth(const int8_t* A, ...@@ -398,7 +449,8 @@ bool gemv_int8_oth(const int8_t* A,
ptr_out[2] += ptr_in[i] * ptr_w2[i]; ptr_out[2] += ptr_in[i] * ptr_w2[i];
ptr_out[3] += ptr_in[i] * ptr_w3[i]; ptr_out[3] += ptr_in[i] * ptr_w3[i];
} }
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, is_relu); write_gemv_out(
ptr_out, out_ptr, scale_ptr, bias_ptr, 4, flag_act, act, six, alpha);
} }
//! deal with remains //! deal with remains
#pragma omp parallel for #pragma omp parallel for
...@@ -439,7 +491,8 @@ bool gemv_int8_oth(const int8_t* A, ...@@ -439,7 +491,8 @@ bool gemv_int8_oth(const int8_t* A,
for (int i = 0; i < tail; ++i) { for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i]; ptr_out[0] += ptr_in[i] * ptr_w0[i];
} }
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu); write_gemv_out(
ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
} }
#endif // __aarch64__ #endif // __aarch64__
return true; return true;
...@@ -456,7 +509,10 @@ bool gemv_int8_sdot(const int8_t* A, ...@@ -456,7 +509,10 @@ bool gemv_int8_sdot(const int8_t* A,
const float* scale, const float* scale,
bool is_bias, bool is_bias,
const float* bias, const float* bias,
bool is_relu) { bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (transA) { if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false; return false;
...@@ -594,7 +650,8 @@ bool gemv_int8_sdot(const int8_t* A, ...@@ -594,7 +650,8 @@ bool gemv_int8_sdot(const int8_t* A,
ptr_out[6] += ptr_in[i] * ptr_w6[i]; ptr_out[6] += ptr_in[i] * ptr_w6[i];
ptr_out[7] += ptr_in[i] * ptr_w7[i]; ptr_out[7] += ptr_in[i] * ptr_w7[i];
} }
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu); write_gemv_out(
ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
} }
//! deal with remains //! deal with remains
#pragma omp parallel for #pragma omp parallel for
...@@ -634,7 +691,8 @@ bool gemv_int8_sdot(const int8_t* A, ...@@ -634,7 +691,8 @@ bool gemv_int8_sdot(const int8_t* A,
for (int i = 0; i < tail; ++i) { for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i]; ptr_out[0] += ptr_in[i] * ptr_w0[i];
} }
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu); write_gemv_out(
ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
} }
return true; return true;
} }
...@@ -650,19 +708,22 @@ bool gemv_int8<float>(const int8_t* A, ...@@ -650,19 +708,22 @@ bool gemv_int8<float>(const int8_t* A,
const float* scale, const float* scale,
bool is_bias, bool is_bias,
const float* bias, const float* bias,
bool is_relu, bool flag_act,
const ARMContext* ctx) { lite_api::ActivationType act,
const ARMContext* ctx,
float six,
float alpha) {
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) { if (ctx->has_dot()) {
return gemv_int8_sdot<float>( return gemv_int8_sdot<float>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu); A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
} else { } else {
return gemv_int8_oth<float>( return gemv_int8_oth<float>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu); A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
} }
#else #else
return gemv_int8_oth<float>( return gemv_int8_oth<float>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu); A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
#endif #endif
} }
...@@ -676,19 +737,22 @@ bool gemv_int8<int8_t>(const int8_t* A, ...@@ -676,19 +737,22 @@ bool gemv_int8<int8_t>(const int8_t* A,
const float* scale, const float* scale,
bool is_bias, bool is_bias,
const float* bias, const float* bias,
bool is_relu, bool flag_act,
const ARMContext* ctx) { lite_api::ActivationType act,
const ARMContext* ctx,
float six,
float alpha) {
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD) #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) { if (ctx->has_dot()) {
return gemv_int8_sdot<int8_t>( return gemv_int8_sdot<int8_t>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu); A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
} else { } else {
return gemv_int8_oth<int8_t>( return gemv_int8_oth<int8_t>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu); A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
} }
#else #else
return gemv_int8_oth<int8_t>( return gemv_int8_oth<int8_t>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu); A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
#endif #endif
} }
......
...@@ -32,8 +32,11 @@ bool gemv_int8(const int8_t* A, ...@@ -32,8 +32,11 @@ bool gemv_int8(const int8_t* A,
const float* scale, const float* scale,
bool is_bias, bool is_bias,
const float* bias, const float* bias,
bool is_relu, bool flag_act,
const ARMContext* ctx); lite_api::ActivationType act,
const ARMContext* ctx,
float six = 6.f,
float alpha = 1.f);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
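Given the gemv_int8 declaration above, a fused int8 GEMV with relu6 would be called roughly as follows. This is a usage sketch only: buffer and scale setup are omitted, and all variable names other than gemv_int8 and ActivationType::kRelu6 are placeholders.

// Usage sketch for the new gemv_int8 signature.
bool ok = paddle::lite::arm::math::gemv_int8<float>(
    A_int8,   // packed int8 weights, M x N
    x_int8,   // int8 input vector, length N
    y_fp32,   // float output, length M
    false,    // transA (transposed A is not supported)
    M,
    N,
    scale,    // dequant scales
    true,     // is_bias
    bias,     // bias pointer
    true,     // flag_act: an activation is fused
    paddle::lite_api::ActivationType::kRelu6,
    ctx,      // const ARMContext*
    6.f,      // six: relu6 clip value
    1.f);     // alpha: leaky relu slope (ignored for relu6)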
...@@ -79,6 +79,11 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -79,6 +79,11 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
} }
flag_trans_bias_ = true; flag_trans_bias_ = true;
} }
//! update relu6 parameter
if (param.activation_param.active_type == lite_api::ActivationType::kRelu6) {
param.activation_param.Relu_clipped_coef =
param.activation_param.Relu_clipped_coef / param.output_scale;
}
} }
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
......
...@@ -156,7 +156,11 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -156,7 +156,11 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
bool flag_relu = false; bool flag_relu = false;
operators::ActivationParam act_param;
lite_api::ActivationType act;
act_param.has_active = false;
if (param.activation_type == "relu") { if (param.activation_type == "relu") {
act = lite_api::ActivationType::kRelu;
flag_relu = true; flag_relu = true;
} }
if (flag_gemm_) { if (flag_gemm_) {
...@@ -170,8 +174,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -170,8 +174,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data, o_data,
nullptr, nullptr,
false, false,
false,
scale_.data(), scale_.data(),
act_param,
&ctx); &ctx);
if (param.bias) { if (param.bias) {
CHECK_EQ(param.bias->numel(), n_); CHECK_EQ(param.bias->numel(), n_);
...@@ -191,6 +195,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -191,6 +195,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
param.bias != nullptr, param.bias != nullptr,
b_data, b_data,
flag_relu, flag_relu,
act,
&ctx); &ctx);
} }
} }
...@@ -210,8 +215,14 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -210,8 +215,14 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
} }
bool flag_relu = false; bool flag_relu = false;
operators::ActivationParam act_param;
act_param.has_active = false;
lite_api::ActivationType act;
if (param.activation_type == "relu") { if (param.activation_type == "relu") {
flag_relu = true; flag_relu = true;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu;
act = lite_api::ActivationType::kRelu;
} }
if (flag_gemm_) { if (flag_gemm_) {
CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel " CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
...@@ -226,8 +237,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -226,8 +237,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
o_data, o_data,
nullptr, nullptr,
false, false,
flag_relu,
scale_.data(), scale_.data(),
act_param,
&ctx); &ctx);
} else { } else {
for (int i = 0; i < m_; ++i) { for (int i = 0; i < m_; ++i) {
...@@ -243,6 +254,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -243,6 +254,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
param.bias != nullptr, param.bias != nullptr,
b_data, b_data,
flag_relu, flag_relu,
act,
&ctx); &ctx);
} }
} }
......
...@@ -53,12 +53,15 @@ DEFINE_int32(stride_w, 1, "stride width"); ...@@ -53,12 +53,15 @@ DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height"); DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width"); DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_relu, true, "do relu"); DEFINE_bool(flag_act, true, "do act");
DEFINE_bool(flag_bias, true, "with bias"); DEFINE_bool(flag_bias, true, "with bias");
DEFINE_double(clipped_coef, 1.0, "clipped relu coef");
DEFINE_double(leakey_relu_alpha, 8.88, "leakey relu alpha");
typedef paddle::lite::DDim DDim; typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor; typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam; typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer; using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in, DDim compute_out_dim(const DDim& dim_in,
...@@ -129,9 +132,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -129,9 +132,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
const std::vector<int>& pads, const std::vector<int>& pads,
const std::vector<int>& dilas, const std::vector<int>& dilas,
bool flag_bias, bool flag_bias,
bool flag_relu, int flag_act,
const std::vector<int>& thread_num, const std::vector<int>& thread_num,
const std::vector<int>& power_mode) { const std::vector<int>& power_mode,
const float six = 6.f,
const float alpha = 1.f) {
paddle::lite::DeviceInfo::Init(); paddle::lite::DeviceInfo::Init();
ConvParam param_int8_out; ConvParam param_int8_out;
ConvParam param_fp32_out; ConvParam param_fp32_out;
...@@ -142,7 +147,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -142,7 +147,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads, pads,
dilas, dilas,
flag_bias, flag_bias,
flag_relu, flag_act > 0,
&param_int8_out); &param_int8_out);
get_conv_param<PRECISION(kFloat)>(weight_dim, get_conv_param<PRECISION(kFloat)>(weight_dim,
...@@ -151,7 +156,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -151,7 +156,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads, pads,
dilas, dilas,
flag_bias, flag_bias,
flag_relu, flag_act > 0,
&param_fp32_out); &param_fp32_out);
Tensor weight_fp32; Tensor weight_fp32;
Tensor bias_fp32; Tensor bias_fp32;
...@@ -165,6 +170,22 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -165,6 +170,22 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias); param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias);
bias_fp32.CopyDataFrom(*param_int8_out.bias); bias_fp32.CopyDataFrom(*param_int8_out.bias);
} }
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param_fp32_out.fuse_relu = true;
param_int8_out.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = alpha;
}
param_fp32_out.activation_param = act_param;
param_int8_out.activation_param = act_param;
}
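The integer flag_act replaces the old boolean: 0 means no activation, 1 plain ReLU (which also sets fuse_relu on both params), 2 ReLU6 (carrying the clip value six through Relu_clipped_coef), and 4 LeakyReLU (carrying alpha through Leaky_relu_alpha). The same wiring as a sketch against a local mirror of the fields the patch touches (TestActParam is illustrative, not the real ActivationParam class):

// Local mirror of the fields this test fills in.
struct TestActParam {
  bool  has_active{false};
  int   active_type{0};        // 1: relu, 2: relu6, 4: leaky relu
  float Relu_clipped_coef{6.f};
  float Leaky_relu_alpha{1.f};
};

TestActParam make_act_param(int flag_act, float six, float alpha) {
  TestActParam p;
  if (flag_act <= 0) return p;                     // 0: nothing to fuse
  p.has_active  = true;
  p.active_type = flag_act;
  if (flag_act == 2) p.Relu_clipped_coef = six;    // relu6 clip value
  if (flag_act == 4) p.Leaky_relu_alpha  = alpha;  // negative-slope factor
  return p;
}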
std::vector<float> scale_in{1.f / 127}; std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f}; std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
...@@ -291,7 +312,9 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -291,7 +312,9 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads[2], pads[2],
pads[0], pads[0],
flag_bias, flag_bias,
static_cast<int>(flag_relu)); flag_act,
six,
alpha);
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32, paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32,
dout_basic_int8, dout_basic_int8,
scale_out.data(), scale_out.data(),
...@@ -299,7 +322,6 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -299,7 +322,6 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
1, 1,
dim_out.production()); dim_out.production());
} }
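The reference result is computed in fp32 and then quantized with fp32_to_int8 so it can be compared against the fused int8-output kernel. A sketch of the symmetric per-tensor quantization this relies on (the exact rounding and saturation behavior of the helper is an assumption here, included only to make the example self-contained):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric quantization: q = clamp(round(x / scale), -127, 127).
std::vector<int8_t> quantize_fp32_to_int8(const std::vector<float>& src,
                                          float scale) {
  std::vector<int8_t> dst(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    float q = std::round(src[i] / scale);
    q = std::max(-127.f, std::min(127.f, q));
    dst[i] = static_cast<int8_t>(q);
  }
  return dst;
}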
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] * double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
weight_dim[3] / group; weight_dim[3] / group;
/// warm up /// warm up
...@@ -364,9 +386,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -364,9 +386,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< ", dila_: " << dilas[0] << ", " << dilas[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group << ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", act: " << flag_act << ", threads: " << th
<< ", threads: " << th << ", power_mode: " << cls << ", power_mode: " << cls << " failed!!\n";
<< " failed!!\n";
} }
} }
} }
...@@ -423,9 +444,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -423,9 +444,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< ", stride: " << strides[0] << ", " << strides[1] << ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", act: " << flag_act << ", threads: " << th
<< ", threads: " << th << ", power_mode: " << cls << ", power_mode: " << cls << " failed!!\n";
<< " failed!!\n";
} }
} }
} }
...@@ -435,9 +455,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -435,9 +455,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< ", " << pads[3] << ", stride: " << strides[0] << ", " << ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", act: " << flag_act << ", threads: " << th
<< ", threads: " << th << ", power_mode: " << cls << ", power_mode: " << cls << " successed!!\n";
<< " successed!!\n";
} }
} }
} }
...@@ -452,9 +471,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -452,9 +471,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
const std::vector<int>& pads, const std::vector<int>& pads,
const std::vector<int>& dilas, const std::vector<int>& dilas,
bool flag_bias, bool flag_bias,
bool flag_relu, int flag_act,
const std::vector<int>& thread_num, const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {} const std::vector<int>& power_mode,
float six = 6.f,
float alpha = 1.f) {}
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
#if 1 /// 3x3dw #if 1 /// 3x3dw
...@@ -463,7 +484,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -463,7 +484,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1}) { for (auto& pad : {0, 1}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) { for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3}); DDim weights_dim({c, 1, 3, 3});
...@@ -479,9 +500,11 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -479,9 +500,11 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
{pad, pad, pad, pad}, {pad, pad, pad, pad},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -497,7 +520,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -497,7 +520,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2, 3, 4}) { for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1}) {
for (auto& c : {1, 5, 15, 33}) { for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5}); DDim weights_dim({c, 1, 5, 5});
...@@ -513,9 +536,11 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -513,9 +536,11 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
{pad, pad, pad, pad}, {pad, pad, pad, pad},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{1, 4}, {1, 4},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -532,7 +557,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -532,7 +557,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
for (auto& cout : {1, 5, 17}) { for (auto& cout : {1, 5, 17}) {
for (auto& g : {1, 2}) { for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims; std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
continue; continue;
...@@ -550,9 +575,11 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -550,9 +575,11 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
{0, 0, 0, 0}, {0, 0, 0, 0},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -572,7 +599,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -572,7 +599,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) { for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
...@@ -587,9 +614,11 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -587,9 +614,11 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -612,7 +641,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -612,7 +641,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
for (auto& pad_left : {1, 2}) { for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) { for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1}) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3}); DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
...@@ -627,9 +656,11 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -627,9 +656,11 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{1, 1}, {1, 1},
flag_bias, flag_bias,
flag_relu, flag_act,
{4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -657,7 +688,7 @@ TEST(TestConvRandInt8, test_conv_rand) { ...@@ -657,7 +688,7 @@ TEST(TestConvRandInt8, test_conv_rand) {
for (auto& pad_right : {0, 1, 2}) { for (auto& pad_right : {0, 1, 2}) {
for (auto& dila : {1, 2}) { for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_act : {0, 1, 2, 4}) {
if (cin % g != 0 || cout % g != 0) { if (cin % g != 0 || cout % g != 0) {
break; break;
} }
...@@ -676,9 +707,11 @@ TEST(TestConvRandInt8, test_conv_rand) { ...@@ -676,9 +707,11 @@ TEST(TestConvRandInt8, test_conv_rand) {
{pad_top, pad_bottom, pad_left, pad_right}, {pad_top, pad_bottom, pad_left, pad_right},
{dila, dila}, {dila, dila},
flag_bias, flag_bias,
flag_relu, flag_act,
{4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
} }
} }
...@@ -713,8 +746,10 @@ TEST(TestConvCustomInt8, test_conv_custom_size) { ...@@ -713,8 +746,10 @@ TEST(TestConvCustomInt8, test_conv_custom_size) {
{FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w}, {FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w},
{FLAGS_dila_h, FLAGS_dila_w}, {FLAGS_dila_h, FLAGS_dila_w},
FLAGS_flag_bias, FLAGS_flag_bias,
FLAGS_flag_relu, FLAGS_flag_act,
{FLAGS_threads}, {FLAGS_threads},
{FLAGS_power_mode}); {FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
} }
#endif // custom #endif // custom
...@@ -22,10 +22,12 @@ ...@@ -22,10 +22,12 @@
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/profile/timer.h" #include "lite/core/profile/timer.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/tensor_utils.h" #include "lite/tests/utils/tensor_utils.h"
typedef paddle::lite::Tensor Tensor; typedef paddle::lite::Tensor Tensor;
using paddle::lite::profile::Timer; using paddle::lite::profile::Timer;
typedef paddle::lite::operators::ActivationParam ActivationParam;
DEFINE_int32(power_mode, DEFINE_int32(power_mode,
3, 3,
...@@ -92,6 +94,11 @@ bool test_gemm_int8(bool tra, ...@@ -92,6 +94,11 @@ bool test_gemm_int8(bool tra,
std::vector<float> scale_c = {k / 127.f}; std::vector<float> scale_c = {k / 127.f};
std::vector<float> scale_merge_fp32(static_cast<size_t>(m)); std::vector<float> scale_merge_fp32(static_cast<size_t>(m));
std::vector<float> scale_merge_int8(static_cast<size_t>(m)); std::vector<float> scale_merge_int8(static_cast<size_t>(m));
ActivationParam act_param;
act_param.has_active = has_relu;
if (has_relu) {
act_param.active_type = (paddle::lite_api::ActivationType)1;
}
for (int j = 0; j < m; ++j) { for (int j = 0; j < m; ++j) {
scale_merge_fp32[j] = scale_a[j] * scale_b[0]; scale_merge_fp32[j] = scale_a[j] * scale_b[0];
scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0]; scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0];
...@@ -178,9 +185,9 @@ bool test_gemm_int8(bool tra, ...@@ -178,9 +185,9 @@ bool test_gemm_int8(bool tra,
n, n,
k, k,
has_bias, has_bias,
has_relu,
trb, trb,
scale_merge_fp32.data(), scale_merge_fp32.data(),
act_param,
&ctx); &ctx);
} }
...@@ -202,9 +209,9 @@ bool test_gemm_int8(bool tra, ...@@ -202,9 +209,9 @@ bool test_gemm_int8(bool tra,
n, n,
k, k,
has_bias, has_bias,
has_relu,
trb, trb,
scale_merge_int8.data(), scale_merge_int8.data(),
act_param,
&ctx); &ctx);
t0.Stop(); t0.Stop();
} }
...@@ -229,9 +236,9 @@ bool test_gemm_int8(bool tra, ...@@ -229,9 +236,9 @@ bool test_gemm_int8(bool tra,
n, n,
k, k,
has_bias, has_bias,
has_relu,
trb, trb,
scale_merge_fp32.data(), scale_merge_fp32.data(),
act_param,
&ctx); &ctx);
t0.Stop(); t0.Stop();
} }
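The test pre-merges the quantization scales so the kernel applies a single multiplier per output row: the fp32-output path scales the int32 accumulator by scale_a[j] * scale_b[0], and the int8-output path further divides that product by the output scale scale_c[0]. A minimal sketch of that arithmetic (standalone; the names are local to the example):

#include <cstdint>
#include <vector>

// Per-row multipliers: fp32 out uses sa*sb; int8 out additionally
// folds in 1/sc so the kernel can requantize in a single multiply.
void merge_scales(const std::vector<float>& scale_a,  // per-row A scales
                  float scale_b,                      // B scale
                  float scale_c,                      // output scale
                  std::vector<float>* merge_fp32,
                  std::vector<float>* merge_int8) {
  merge_fp32->resize(scale_a.size());
  merge_int8->resize(scale_a.size());
  for (size_t j = 0; j < scale_a.size(); ++j) {
    (*merge_fp32)[j] = scale_a[j] * scale_b;
    (*merge_int8)[j] = (*merge_fp32)[j] / scale_c;
  }
}

// Dequantize one accumulator: fp32 result ~= acc * sa * sb.
inline float dequant(int32_t acc, float merged_fp32_scale) {
  return static_cast<float>(acc) * merged_fp32_scale;
}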
......
...@@ -45,11 +45,20 @@ DEFINE_int32(N, 512, "gemv: N"); ...@@ -45,11 +45,20 @@ DEFINE_int32(N, 512, "gemv: N");
DEFINE_bool(traA, false, "gemv: A transpose"); DEFINE_bool(traA, false, "gemv: A transpose");
DEFINE_bool(flag_relu, false, "do relu"); DEFINE_int32(flag_act, 0, "do act");
DEFINE_bool(flag_bias, false, "with bias"); DEFINE_bool(flag_bias, false, "with bias");
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_double(clipped_coef, 6.0, "clipped relu coef");
bool test_gemv_int8( bool test_gemv_int8(bool tra,
bool tra, int m, int n, bool has_bias, bool has_relu, int cls, int ths) { int m,
int n,
bool has_bias,
int flag_act,
int cls,
int ths,
float six = 6.f,
float alpha = 1.f) {
Tensor ta; Tensor ta;
Tensor tb; Tensor tb;
Tensor tc_int8; Tensor tc_int8;
...@@ -89,8 +98,7 @@ bool test_gemv_int8( ...@@ -89,8 +98,7 @@ bool test_gemv_int8(
} }
LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n
<< ", transA: " << (tra ? "true" : "false") << ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false"); << ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
auto da = ta.mutable_data<int8_t>(); auto da = ta.mutable_data<int8_t>();
...@@ -101,6 +109,16 @@ bool test_gemv_int8( ...@@ -101,6 +109,16 @@ bool test_gemv_int8(
auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>(); auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>();
auto dbias = tbias.mutable_data<float>(); auto dbias = tbias.mutable_data<float>();
paddle::lite_api::ActivationType act =
paddle::lite_api::ActivationType::kIndentity;
if (flag_act == 1) {
act = paddle::lite_api::ActivationType::kRelu;
} else if (flag_act == 2) {
act = paddle::lite_api::ActivationType::kRelu6;
} else if (flag_act == 4) {
act = paddle::lite_api::ActivationType::kLeakyRelu;
}
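The gemv test translates the integer flag into the framework's ActivationType before calling the kernel; only codes 0, 1, 2 and 4 are exercised here, and anything else falls back to the identity type. The same mapping against a local mirror enum (the enumerator names follow the ones used above, but this is not the real paddle::lite_api type):

// Local mirror of the activation types this test selects between.
enum class TestActType { kIdentity = 0, kRelu = 1, kRelu6 = 2, kLeakyRelu = 4 };

TestActType act_from_flag(int flag_act) {
  switch (flag_act) {
    case 1:  return TestActType::kRelu;
    case 2:  return TestActType::kRelu6;
    case 4:  return TestActType::kLeakyRelu;
    default: return TestActType::kIdentity;  // 0 or anything unhandled
  }
}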
if (FLAGS_check_result) { if (FLAGS_check_result) {
Tensor ta_fp32; Tensor ta_fp32;
Tensor tb_fp32; Tensor tb_fp32;
...@@ -126,7 +144,9 @@ bool test_gemv_int8( ...@@ -126,7 +144,9 @@ bool test_gemv_int8(
0.f, 0.f,
false, false,
has_bias, has_bias,
has_relu); flag_act,
six,
alpha);
paddle::lite::arm::math::fp32_to_int8(dc_basic_fp32, paddle::lite::arm::math::fp32_to_int8(dc_basic_fp32,
dc_basic_int8, dc_basic_int8,
scale_c.data(), scale_c.data(),
...@@ -152,8 +172,11 @@ bool test_gemv_int8( ...@@ -152,8 +172,11 @@ bool test_gemv_int8(
scale_merge_fp32.data(), scale_merge_fp32.data(),
has_bias, has_bias,
dbias, dbias,
has_relu, flag_act > 0,
&ctx); act,
&ctx,
six,
alpha);
} }
/// int8 output compute /// int8 output compute
...@@ -175,8 +198,11 @@ bool test_gemv_int8( ...@@ -175,8 +198,11 @@ bool test_gemv_int8(
scale_merge_fp32.data(), scale_merge_fp32.data(),
has_bias, has_bias,
dbias, dbias,
has_relu, flag_act > 0,
&ctx); act,
&ctx,
six,
alpha);
t0.Stop(); t0.Stop();
} }
LOG(INFO) << "gemv_int8_int8 output: M: " << m << ", N: " << n LOG(INFO) << "gemv_int8_int8 output: M: " << m << ", N: " << n
...@@ -201,8 +227,11 @@ bool test_gemv_int8( ...@@ -201,8 +227,11 @@ bool test_gemv_int8(
scale_merge_int8.data(), scale_merge_int8.data(),
has_bias, has_bias,
dbias_int8, dbias_int8,
has_relu, flag_act > 0,
&ctx); act,
&ctx,
six / scale_c[0],
alpha);
t0.Stop(); t0.Stop();
} }
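For the int8-output run the ReLU6 clip value is passed as six / scale_c[0] rather than six: the fused kernel clips after requantization, i.e. in the output's quantized domain, so the fp32 threshold has to be divided by the output scale. Ignoring rounding, min(x, 6) / scale equals min(x / scale, 6 / scale); a small numeric check of that identity (a standalone sketch, not the kernel code):

#include <algorithm>
#include <cstdio>

int main() {
  const float six = 6.f;
  const float scale_c = 512.f / 127.f;       // same form as the test's k/127
  const float x = 9.3f;                      // a pre-activation fp32 value
  // Clip in fp32 first, then requantize...
  float clip_then_quant = std::min(x, six) / scale_c;
  // ...or requantize first and clip with the scaled threshold.
  float quant_then_clip = std::min(x / scale_c, six / scale_c);
  std::printf("%f vs %f\n", clip_then_quant, quant_then_clip);  // identical
  return 0;
}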
LOG(INFO) << "gemm_int8_fp32 output: M: " << m << ", N: " << n LOG(INFO) << "gemm_int8_fp32 output: M: " << m << ", N: " << n
...@@ -291,18 +320,27 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) { ...@@ -291,18 +320,27 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) {
for (auto& has_bias : {false, true}) { for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) { for (auto& has_relu : {false, true}) {
for (auto& th : {1, 2, 4}) { for (auto& th : {1, 2, 4}) {
auto flag = test_gemv_int8( float six = 6.f;
tra, m, n, has_bias, has_relu, FLAGS_power_mode, th); float alpha = 8.88f;
auto flag = test_gemv_int8(tra,
m,
n,
has_bias,
has_relu > 0,
FLAGS_power_mode,
th,
six,
alpha);
if (flag) { if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n LOG(INFO) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false") << ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false") << ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false") << ", trans A: " << (tra ? "true" : "false")
<< " passed\n"; << " passed\n";
} else { } else {
LOG(FATAL) << "test m = " << m << ", n=" << n LOG(FATAL) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false") << ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false") << ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false") << ", trans A: " << (tra ? "true" : "false")
<< " failed\n"; << " failed\n";
} }
...@@ -323,15 +361,17 @@ TEST(TestGemvInt8Custom, gemv_prepacked_int8_custom) { ...@@ -323,15 +361,17 @@ TEST(TestGemvInt8Custom, gemv_prepacked_int8_custom) {
FLAGS_M, FLAGS_M,
FLAGS_N, FLAGS_N,
FLAGS_flag_bias, FLAGS_flag_bias,
FLAGS_flag_relu, FLAGS_flag_act,
FLAGS_power_mode, FLAGS_power_mode,
FLAGS_threads); FLAGS_threads,
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
if (!flag) { if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!"; << ", act: " << FLAGS_flag_act << " failed!!";
} }
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " passed!!"; << ", act: " << FLAGS_flag_act << " passed!!";
} }
...@@ -203,8 +203,8 @@ static void basic_gemv(int m, ...@@ -203,8 +203,8 @@ static void basic_gemv(int m,
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6 } else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
c[i] = c[i] < six ? c[i] : six; c[i] = c[i] < six ? c[i] : six; // ut compute
} else if (flag_act == 4) { // leakey relu } else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
} }
} else { } else {
......
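The basic_gemv epilogue above is the ground truth both the fp32 and int8 fused kernels are checked against: after the bias add, flag_act selects plain ReLU, ReLU6 with the clip value six, or LeakyReLU with slope alpha. A standalone restatement of that per-element rule (sketch only, mirroring the branches shown above):

#include <algorithm>

// Apply the activation selected by flag_act to one accumulator value.
// 0: identity, 1: relu, 2: relu6 (clip at six), 4: leaky relu (slope alpha).
inline float apply_act(float v, int flag_act, float six, float alpha) {
  switch (flag_act) {
    case 1:  return std::max(v, 0.f);
    case 2:  return std::min(std::max(v, 0.f), six);
    case 4:  return v < 0.f ? v * alpha : v;
    default: return v;
  }
}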