未验证 提交 9e361a4d 编写于 作者: Y yiicy 提交者: GitHub

[ARM] int8 direct_conv, dw_conv add relu6 and leaky relu fusion, test=develop (#3737)

int8 direct_conv, dw_conv add relu6 and leaky relu fusion
上级 cba42f0d
......@@ -36,7 +36,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -434,7 +435,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......@@ -450,7 +452,8 @@ template void conv_depthwise_3x3s1_int8<int8_t>(int8_t* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -467,7 +470,8 @@ template void conv_depthwise_3x3s1_int8<float>(float* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......
......@@ -42,8 +42,30 @@ void conv_3x3s1_direct_int8(const int8_t* din,
Context<TARGET(kARM)>* ctx,
const float* scale) {
auto paddings = *param.paddings;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias;
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
int pad_h = paddings[0];
int pad_w = paddings[2];
......@@ -442,7 +464,8 @@ void conv_3x3s1_direct_int8(const int8_t* din,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......
......@@ -36,7 +36,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -447,7 +448,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......@@ -463,7 +465,8 @@ template void conv_depthwise_3x3s2_int8<int8_t>(int8_t* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -480,7 +483,8 @@ template void conv_depthwise_3x3s2_int8<float>(float* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......
......@@ -47,8 +47,30 @@ void conv_3x3s2_direct_int8(const int8_t* din,
//! prepack input to tmp buffer
//! write output to tmp buffer
auto paddings = *param.paddings;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias;
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
int pad_h = paddings[0];
int pad_w = paddings[2];
......@@ -442,7 +464,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......@@ -474,8 +497,30 @@ void conv_3x3s2_direct_int8(const int8_t* din,
//! prepack input to tmp buffer
//! write output to tmp buffer
auto paddings = *param.paddings;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias;
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
int pad_h = paddings[0];
int pad_w = paddings[2];
const int threads = ctx->threads();
......@@ -698,7 +743,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......
......@@ -36,7 +36,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -726,7 +727,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......@@ -742,7 +744,8 @@ template void conv_depthwise_5x5s1_int8<int8_t>(int8_t* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -759,7 +762,8 @@ template void conv_depthwise_5x5s1_int8<float>(float* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......
......@@ -36,7 +36,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -746,7 +747,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout,
chout,
hout,
wout,
flag_relu,
flag_act,
alpha,
bias_local,
flag_bias,
ptr_write,
......@@ -762,7 +764,8 @@ template void conv_depthwise_5x5s2_int8<int8_t>(int8_t* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -779,7 +782,8 @@ template void conv_depthwise_5x5s2_int8<float>(float* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......
......@@ -2643,48 +2643,81 @@ inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT
int cnt,
float32x4_t scale,
float32x4_t bias,
bool is_relu);
int flag_act,
float* alpha);
#ifdef __aarch64__
// NCHWC4_TRANS_INT32 (aarch64): inline-asm fragment that transposes a 4x4
// tile of int32 values, converts it to fp32, applies per-lane scale and bias,
// and optionally clamps the result at zero (ReLU).
//
// Operands expected from the enclosing asm statement:
//   %[ptr_din] - input pointer, post-incremented by 32 bytes per ldp
//   %[scale]   - vector register whose lanes 0..3 are the multipliers
//   %[bias]    - vector register whose lanes 0..3 are the addends
//   %[relu]    - 32-bit flag; zero skips the fmax block
// Results are left in v16 (lane 0), v18 (lane 1), v17 (lane 2), v19 (lane 3).
// The fragment ends at label "2:"; storing the results and branching back to
// label "1:" is presumably done by the asm statement that uses this macro.
#define NCHWC4_TRANS_INT32 \
"ldp q0, q1, [%[ptr_din]], #32\n" /* preload first 32 bytes */ \
"ldp q2, q3, [%[ptr_din]], #32\n" /* preload next 32 bytes */ \
"movi v20.4s, #0\n" /* zero vector used by relu */ \
"1:\n" /* start of the per-tile body */ \
"trn1 v8.4s, v0.4s, v1.4s\n" /* 32-bit interleave of q0, q1 */ \
"trn2 v9.4s, v0.4s, v1.4s\n" \
"ldp q0, q1, [%[ptr_din]], #32\n" /* preload data for the next tile */ \
"trn1 v10.4s, v2.4s, v3.4s\n" /* 32-bit interleave of q2, q3 */ \
"trn2 v11.4s, v2.4s, v3.4s\n" \
"ldp q2, q3, [%[ptr_din]], #32\n" /* preload data for the next tile */ \
"trn1 v16.2d, v8.2d, v10.2d\n" /* 64-bit interleave completes the transpose */ \
"trn2 v17.2d, v8.2d, v10.2d\n" \
"trn1 v18.2d, v9.2d, v11.2d\n" \
"trn2 v19.2d, v9.2d, v11.2d\n" /* int32 --> fp32 */ \
"scvtf v4.4s, v16.4s\n" \
"scvtf v5.4s, v17.4s\n" \
"scvtf v6.4s, v18.4s\n" \
"scvtf v7.4s, v19.4s\n" /* add bias */ \
"dup v16.4s, %[bias].s[0]\n" /* accumulators start at the lane's bias */ \
"dup v17.4s, %[bias].s[2]\n" \
"dup v18.4s, %[bias].s[1]\n" \
"dup v19.4s, %[bias].s[3]\n" /* mul scale */ \
"fmla v16.4s, v4.4s, %[scale].s[0]\n" /* acc = bias + value * scale */ \
"fmla v17.4s, v5.4s, %[scale].s[2]\n" \
"fmla v18.4s, v6.4s, %[scale].s[1]\n" \
"fmla v19.4s, v7.4s, %[scale].s[3]\n" /* relu */ \
"cbz %w[relu], 2f\n" /* flag == 0: leave values unclamped */ \
"fmax v16.4s, v16.4s, v20.4s \n" /* max(x, 0) */ \
"fmax v17.4s, v17.4s, v20.4s \n" \
"fmax v18.4s, v18.4s, v20.4s \n" \
"fmax v19.4s, v19.4s, v20.4s \n" \
"2:\n" /* activation done */
// NCHWC4_TRANS_INT32 (aarch64): inline-asm fragment that transposes a 4x4
// tile of int32 values, converts it to fp32, applies per-lane scale and bias,
// and then applies the activation selected by %[flag_act].
//
// Operands expected from the enclosing asm statement:
//   %[ptr_din]  - input pointer, post-incremented by 32 bytes per ldp
//   %[scale]    - vector register whose lanes 0..3 are the multipliers
//   %[bias]     - vector register whose lanes 0..3 are the addends
//   %[flag_act] - 0: no activation, 1: relu, 2: relu6, anything else: leaky relu
//   %[alpha]    - address of four floats; read only for relu6 (upper clamp)
//                 and leaky relu (negative slope)
// Results are left in v16 (lane 0), v18 (lane 1), v17 (lane 2), v19 (lane 3).
// Besides v0-v3 and v16-v19, the activation paths overwrite v4-v13 and v20.
// The fragment ends at label "2:"; storing the results and branching back to
// label "1:" is presumably done by the asm statement that uses this macro.
#define NCHWC4_TRANS_INT32 \
"ldp q0, q1, [%[ptr_din]], #32\n" /* preload first 32 bytes */ \
"ldp q2, q3, [%[ptr_din]], #32\n" /* preload next 32 bytes */ \
"1:\n" /* start of the per-tile body */ \
"trn1 v8.4s, v0.4s, v1.4s\n" /* 32-bit interleave of q0, q1 */ \
"trn2 v9.4s, v0.4s, v1.4s\n" \
"ldp q0, q1, [%[ptr_din]], #32\n" /* preload data for the next tile */ \
"trn1 v10.4s, v2.4s, v3.4s\n" /* 32-bit interleave of q2, q3 */ \
"trn2 v11.4s, v2.4s, v3.4s\n" \
"ldp q2, q3, [%[ptr_din]], #32\n" /* preload data for the next tile */ \
"trn1 v16.2d, v8.2d, v10.2d\n" /* 64-bit interleave completes the transpose */ \
"trn2 v17.2d, v8.2d, v10.2d\n" \
"trn1 v18.2d, v9.2d, v11.2d\n" \
"trn2 v19.2d, v9.2d, v11.2d\n" /* int32 --> fp32 */ \
"scvtf v4.4s, v16.4s\n" \
"scvtf v5.4s, v17.4s\n" \
"scvtf v6.4s, v18.4s\n" \
"scvtf v7.4s, v19.4s\n" /* add bias */ \
"dup v16.4s, %[bias].s[0]\n" /* accumulators start at the lane's bias */ \
"dup v17.4s, %[bias].s[2]\n" \
"dup v18.4s, %[bias].s[1]\n" \
"dup v19.4s, %[bias].s[3]\n" /* mul scale */ \
"fmla v16.4s, v4.4s, %[scale].s[0]\n" /* acc = bias + value * scale */ \
"fmla v17.4s, v5.4s, %[scale].s[2]\n" \
"fmla v18.4s, v6.4s, %[scale].s[1]\n" \
"fmla v19.4s, v7.4s, %[scale].s[3]\n" /* activation dispatch starts here */ \
"cmp %w[flag_act], #1\n" /* flag_act == 1 -> relu */ \
"bne 12f \n" /* not relu: test the other activations */ \
"movi v20.4s, #0 \n" /* zero vector for relu */ \
"fmax v16.4s, v16.4s, v20.4s \n" /* max(x, 0) */ \
"fmax v17.4s, v17.4s, v20.4s \n" \
"fmax v18.4s, v18.4s, v20.4s \n" \
"fmax v19.4s, v19.4s, v20.4s \n" \
"b 2f \n" /* relu end */ \
"12: \n" /* not relu */ \
"cmp %w[flag_act], #0 \n" /* flag_act == 0 -> no activation */ \
"beq 2f \n" /* no activation: done */ \
"cmp %w[flag_act], #2 \n" /* flag_act == 2 -> relu6 */ \
"bne 13f \n" /* not relu6: fall to leaky relu */ \
"movi v8.4s, #0 \n" /* relu6 lower bound (0) */ \
"ld1 {v9.4s}, [%[alpha]] \n" /* relu6 upper bound, 4 floats from alpha */ \
"fmax v16.4s, v16.4s, v8.4s \n" /* relu6: max(x, 0) */ \
"fmax v17.4s, v17.4s, v8.4s \n" /* relu6: max(x, 0) */ \
"fmax v18.4s, v18.4s, v8.4s \n" /* relu6: max(x, 0) */ \
"fmax v19.4s, v19.4s, v8.4s \n" /* relu6: max(x, 0) */ \
"fmin v16.4s, v16.4s, v9.4s \n" /* relu6: min(x, alpha) */ \
"fmin v17.4s, v17.4s, v9.4s \n" /* relu6: min(x, alpha) */ \
"fmin v18.4s, v18.4s, v9.4s \n" /* relu6: min(x, alpha) */ \
"fmin v19.4s, v19.4s, v9.4s \n" /* relu6: min(x, alpha) */ \
"b 2f \n" /* relu6 end */ \
"13: \n" /* leaky relu */ \
"movi v12.4s, #0 \n" /* zero vector for the sign test */ \
"ld1 {v13.4s}, [%[alpha]] \n" /* leaky relu slope, 4 floats from alpha */ \
"fcmge v4.4s, v16.4s, v12.4s \n" /* mask = (x >= 0) */ \
"fmul v5.4s, v16.4s, v13.4s \n" /* x * alpha */ \
"fcmge v6.4s, v17.4s, v12.4s \n" /* mask = (x >= 0) */ \
"fmul v7.4s, v17.4s, v13.4s \n" /* x * alpha */ \
"fcmge v8.4s, v18.4s, v12.4s \n" /* mask = (x >= 0) */ \
"fmul v9.4s, v18.4s, v13.4s \n" /* x * alpha */ \
"fcmge v10.4s, v19.4s, v12.4s \n" /* mask = (x >= 0) */ \
"fmul v11.4s, v19.4s, v13.4s \n" /* x * alpha */ \
"bif v16.16b, v5.16b, v4.16b \n" /* keep x where mask set, else x * alpha */ \
"bif v17.16b, v7.16b, v6.16b \n" /* keep x where mask set, else x * alpha */ \
"bif v18.16b, v9.16b, v8.16b \n" /* keep x where mask set, else x * alpha */ \
"bif v19.16b, v11.16b, v10.16b \n" /* keep x where mask set, else x * alpha */ \
"2: \n" /* act end */
#else
#define NCHWC4_TRANS_INT32 \
"vld1.32 {d4-d7}, [%[ptr_din]]!\n" \
"vld1.32 {d8-d11}, [%[ptr_din]]!\n" \
"vmov.u32 q15, #0\n" \
"1:\n" /* transpose */ \
"vtrn.32 q2, q3\n" \
"vtrn.32 q4, q5\n" \
......@@ -2701,13 +2734,44 @@ inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT
"vmla.f32 q10, q6, %e[scale][0]\n" \
"vmla.f32 q11, q7, %e[scale][1]\n" \
"vmla.f32 q12, q8, %f[scale][0]\n" \
"vmla.f32 q13, q9, %f[scale][1]\n" /* relu */ \
"cmp %[relu], #0\n" \
"beq 2f\n" \
"vmax.f32 q10, q10, q15\n" \
"vmax.f32 q11, q11, q15\n" \
"vmax.f32 q12, q12, q15\n" \
"vmax.f32 q13, q13, q15\n" \
"vmla.f32 q13, q9, %f[scale][1]\n" \
"vmov.u32 q15, #0 \n" \
"cmp %[flag_act], #1 \n" \
"bne 12f \n" \
"vmax.f32 q10, q10, q15 \n" \
"vmax.f32 q11, q11, q15 \n" \
"vmax.f32 q12, q12, q15 \n" \
"vmax.f32 q13, q13, q15 \n" \
"b 2f \n" \
"12: \n" \
"cmp %[flag_act], #0 \n" \
"beq 2f \n" \
"cmp %[flag_act], #2 \n" \
"bne 13f \n" \
"vld1.f32 {d14-d15}, [%[alpha]] \n" \
"vmax.f32 q10, q10, q15 \n" \
"vmax.f32 q11, q11, q15 \n" \
"vmax.f32 q12, q12, q15 \n" \
"vmax.f32 q13, q13, q15 \n" \
"vmin.f32 q10, q10, q7 \n" \
"vmin.f32 q11, q11, q7 \n" \
"vmin.f32 q12, q12, q7 \n" \
"vmin.f32 q13, q13, q7 \n" \
"b 2f \n" \
"13: \n" \
"vld1.f32 {d6-d7}, [%[alpha]] \n" \
"vcge.f32 q6, q10, q15 \n" \
"vmul.f32 q7, q10, q3 \n" \
"vcge.f32 q8, q11, q15 \n" \
"vmul.f32 q9, q11, q3 \n" \
"vbif q10, q7, q6 \n" \
"vbif q11, q9, q8 \n" \
"vcge.f32 q6, q12, q15 \n" \
"vmul.f32 q7, q12, q3 \n" \
"vcge.f32 q8, q13, q15 \n" \
"vmul.f32 q9, q13, q3 \n" \
"vbif q12, q7, q6 \n" \
"vbif q13, q9, q8 \n" \
"2:\n"
#endif
......@@ -2721,7 +2785,8 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT
int cnt,
float32x4_t scale,
float32x4_t bias,
bool is_relu) {
int flag_act,
float* alpha) {
#ifdef __aarch64__
asm volatile(NCHWC4_TRANS_INT32
"subs %w[cnt], %w[cnt], #1\n"
......@@ -2737,7 +2802,10 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
: [scale] "w"(scale),
[bias] "w"(bias),
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc",
"memory",
"v0",
......@@ -2779,7 +2847,10 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT
[doutc3r0] "+r"(dout3),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
: [scale] "w"(scale),
[bias] "w"(bias),
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc",
"memory",
"q2",
......@@ -2808,7 +2879,8 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT
int cnt,
float32x4_t scale,
float32x4_t bias,
bool is_relu) {
int flag_act,
float* alpha) {
#ifdef __aarch64__
float32x4_t vmax = vdupq_n_f32(-127.f);
asm volatile(NCHWC4_TRANS_INT32
......@@ -2852,7 +2924,8 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT
: [scale] "w"(scale),
[vmax] "w"(vmax),
[bias] "w"(bias),
[relu] "r"(is_relu)
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc",
"memory",
"v0",
......@@ -2942,8 +3015,9 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT
[cnt] "+r"(cnt)
: [scale] "w"(scale),
[bias] "w"(bias),
[relu] "r"(is_relu),
[vmax] "r"(vmax)
[vmax] "r"(vmax),
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc",
"memory",
"q2",
......@@ -2963,139 +3037,48 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT
#endif
}
// int32 specialization of int32_nchwc4_kernel.
//
// Each loop iteration reads 16 int32 values (64 bytes) from `din`, performs a
// 4x4 transpose of the 32-bit lanes, optionally clamps every value at zero
// (ReLU), and stores 4 int32 values to each of the four output pointers. The
// data stays int32: no scaling, bias, or float conversion is done here.
//
// Parameters:
//   dout0..dout3 - output row pointers (by reference); each is advanced by
//                  16 bytes per iteration.
//   din          - input pointer (by reference). The loop preloads the next
//                  group before testing the counter, so it ends up advanced
//                  by 64 * (cnt + 1) bytes; that many bytes are read.
//   cnt          - number of iterations; values <= 0 do nothing.
//   scale, bias  - unused by this specialization (kept for a uniform
//                  signature with the other specializations).
//   is_relu      - when true, negative values are replaced by 0.
template <>
inline void int32_nchwc4_kernel(int32_t*& dout0,      // NOLINT
                                int32_t*& dout1,      // NOLINT
                                int32_t*& dout2,      // NOLINT
                                int32_t*& dout3,      // NOLINT
                                const int32_t*& din,  // NOLINT
                                int cnt,
                                float32x4_t scale,
                                float32x4_t bias,
                                bool is_relu) {
  // The asm loops are do-while shaped: with cnt <= 0 the counter would wrap
  // instead of terminating, so bail out before touching any memory.
  if (cnt <= 0) {
    return;
  }
  // Hand the flag to the asm as a full-width int so the 32-bit tests below
  // (cbz %w / cmp) never depend on how a bool is held in a register.
  const int relu_flag = is_relu ? 1 : 0;
#ifdef __aarch64__
  asm volatile(
      "ldp q0, q1, [%[ptr_din]], #32 \n" /* preload first 32 bytes */
      "ldp q2, q3, [%[ptr_din]], #32 \n" /* preload next 32 bytes */
      "movi v20.4s, #0 \n"               /* zero vector for relu */
      "1: \n"                            /* main loop */
      "trn1 v8.4s, v0.4s, v1.4s \n"      /* 32-bit interleave of q0, q1 */
      "trn2 v9.4s, v0.4s, v1.4s \n"
      "ldp q0, q1, [%[ptr_din]], #32 \n" /* preload next group */
      "trn1 v10.4s, v2.4s, v3.4s \n"     /* 32-bit interleave of q2, q3 */
      "trn2 v11.4s, v2.4s, v3.4s \n"
      "ldp q2, q3, [%[ptr_din]], #32 \n" /* preload next group */
      "trn1 v16.2d, v8.2d, v10.2d \n"    /* 64-bit interleave -> row 0 */
      "trn2 v17.2d, v8.2d, v10.2d \n"    /* row 2 */
      "trn1 v18.2d, v9.2d, v11.2d \n"    /* row 1 */
      "trn2 v19.2d, v9.2d, v11.2d \n"    /* row 3 */
      "cbz %w[relu], 2f\n"               /* relu disabled: skip the clamp */
      "smax v16.4s, v16.4s, v20.4s \n"   /* relu */
      "smax v17.4s, v17.4s, v20.4s \n"   /* relu */
      "smax v18.4s, v18.4s, v20.4s \n"   /* relu */
      "smax v19.4s, v19.4s, v20.4s \n"   /* relu */
      "2:\n"
      "str q16, [%[doutc0r0]], #16 \n"   /* store c0r0 */
      "str q17, [%[doutc2r0]], #16 \n"   /* store c2r0 */
      "str q18, [%[doutc1r0]], #16 \n"   /* store c1r0 */
      "str q19, [%[doutc3r0]], #16 \n"   /* store c3r0 */
      "subs %w[cnt], %w[cnt], #1 \n"     /* loop count - 1 */
      "bne 1b \n"                        /* jump to main loop */
      : [doutc0r0] "+r"(dout0),
        [doutc1r0] "+r"(dout1),
        [doutc2r0] "+r"(dout2),
        [doutc3r0] "+r"(dout3),
        [ptr_din] "+r"(din),
        [cnt] "+r"(cnt)
      : [relu] "r"(relu_flag)
      : "cc",
        "memory",
        "v0",
        "v1",
        "v2",
        "v3",
        "v4",
        "v5",
        "v6",
        "v7",
        "v8",
        "v9",
        "v10",
        "v11",
        "v12",
        "v13",
        "v14",
        "v15",
        "v16",
        "v17",
        "v18",
        "v19",
        "v20");
#else
  asm volatile(
      "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
      "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
      "vmov.u32 q15, #0 @ dump zero\n"
      "1: @ main loop\n"
      "vtrn.32 q0, q1 @ trans q0, q1 \n"
      "vtrn.32 q2, q3 @ trans q2, q3 \n"
      "vswp.32 d1, d4 @ swap d1, d4 \n"
      "vswp.32 d3, d6 @ swap d3, d6 \n"
      "cmp %[relu], #0\n"
      /* Skip the clamp only when relu is disabled (flag == 0). The branch */
      /* used to be "bne", which applied relu exactly when it was NOT      */
      /* requested and disagreed with the aarch64 path above.              */
      "beq 2f\n"
      "vmax.s32 q0, q0, q15 @ relu\n"
      "vmax.s32 q1, q1, q15 @ relu\n"
      "vmax.s32 q2, q2, q15 @ relu\n"
      "vmax.s32 q3, q3, q15 @ relu\n"
      "2:\n"
      "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
      "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n"
      "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n"
      "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n"
      "subs %[cnt], %[cnt], #1 @ loop count - 1\n"
      "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
      "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
      "bne 1b @ jump to main loop\n"
      : [doutc0r0] "+r"(dout0),
        [doutc1r0] "+r"(dout1),
        [doutc2r0] "+r"(dout2),
        [doutc3r0] "+r"(dout3),
        [ptr_din] "+r"(din),
        [cnt] "+r"(cnt)
      : [relu] "r"(relu_flag)
      : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15");
#endif
}
template <typename Dtype>
inline Dtype cvt_kernel(int din, float scale, float bias, bool flag_relu);
inline Dtype cvt_kernel(
int din, float scale, float bias, int flag_act, float alpha);
template <>
inline float cvt_kernel(int din, float scale, float bias, bool flag_relu) {
if (flag_relu) {
inline float cvt_kernel(
int din, float scale, float bias, int flag_act, float alpha) {
if (flag_act == 1) {
return LITEMAX(din * scale + bias, 0);
} else if (flag_act == 0) {
return din * scale + bias;
} else if (flag_act == 2) {
float max = LITEMAX(din * scale + bias, 0);
return LITEMIN(max, alpha);
} else {
float result = din * scale + bias;
return result > 0 ? result : alpha * result;
}
return din * scale + bias;
}
template <>
inline int8_t cvt_kernel(int din, float scale, float bias, bool flag_relu) {
if (flag_relu) {
return saturate_cast<int8_t>(round(LITEMAX(din * scale + bias, 0)));
} else {
inline int8_t cvt_kernel(
int din, float scale, float bias, int flag_act, float alpha) {
if (flag_act == 1) {
auto tmp = saturate_cast<int8_t>(round(LITEMAX(din * scale + bias, 0)));
return tmp < -127 ? -127 : tmp;
} else if (flag_act == 0) {
auto tmp = saturate_cast<int8_t>(round(din * scale + bias));
return tmp < -127 ? -127 : tmp;
} else if (flag_act == 2) {
float max = LITEMAX(din * scale + bias, 0);
float relu6_result = LITEMIN(max, alpha);
auto tmp = saturate_cast<int8_t>(round(relu6_result));
return tmp < -127 ? -127 : tmp;
} else {
float result = din * scale + bias;
float leaky_result = result > 0 ? result : alpha * result;
auto tmp = saturate_cast<int8_t>(round(leaky_result));
return tmp < -127 ? -127 : tmp;
}
}
template <>
inline int32_t cvt_kernel(int din, float scale, float bias, bool flag_relu) {
if (flag_relu) {
return LITEMAX(din, 0);
}
return din;
}
template <typename Dtype>
inline void write_int32_nchwc4_to_nchw(const int* din,
Dtype* dout,
......@@ -3108,7 +3091,8 @@ inline void write_int32_nchwc4_to_nchw(const int* din,
int channel,
int height,
int width,
bool flag_relu,
int flag_act,
float* alpha,
float* bias,
bool flag_bias,
Dtype* trash_ptr,
......@@ -3160,21 +3144,22 @@ inline void write_int32_nchwc4_to_nchw(const int* din,
cnt,
w_scale,
w_bias,
flag_relu);
flag_act,
alpha);
}
if (we > width) {
int offset = 16 * (valid_w / 4 - 1);
din_hei_ptr = din + index + offset;
int j = we - 4;
for (; j < width; ++j) {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
*(doutc1_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], bias[1], flag_relu);
*(doutc2_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], bias[2], flag_relu);
*(doutc3_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], bias[3], flag_relu);
*(doutc0_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]);
*(doutc1_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[1], scale[1], bias[1], flag_act, alpha[0]);
*(doutc2_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[2], scale[2], bias[2], flag_act, alpha[0]);
*(doutc3_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[3], scale[3], bias[3], flag_act, alpha[0]);
din_hei_ptr += 4;
}
}
......@@ -3196,7 +3181,8 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu);
int flag_act,
float* alpha);
// clang-format off
#ifdef __aarch64__
......@@ -3205,7 +3191,6 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT
"ldp q2, q3, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \
"ldp q4, q5, [%[ptr_din]], #32\n" /* load r00, r01 to q0, q1 */ \
"ldp q6, q7, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \
"movi v31.4s, #0\n" /* main loop*/ \
"1:\n" \
"trn1 v8.4s, v0.4s, v2.4s\n" /* trans q0, q1*/ \
"trn2 v9.4s, v0.4s, v2.4s\n" /* trans q0, q1*/ \
......@@ -3256,17 +3241,71 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT
"fmla v9.4s, v11.4s, %[scale1].s[2]\n" \
"fmla v12.4s, v14.4s, %[scale1].s[1]\n" \
"fmla v13.4s, v15.4s, %[scale1].s[3]\n" \
/* relu */ \
"cbz %w[relu], 2f\n" \
"fmax v16.4s, v16.4s, v31.4s\n" /*relu*/ \
"fmax v17.4s, v17.4s, v31.4s\n" /*relu*/ \
"fmax v18.4s, v18.4s, v31.4s\n" /*relu*/ \
"fmax v19.4s, v19.4s, v31.4s\n" /*relu*/ \
"fmax v8.4s, v8.4s, v31.4s\n" /*relu*/ \
"fmax v9.4s, v9.4s, v31.4s\n" /*relu*/ \
"fmax v12.4s, v12.4s, v31.4s\n" /*relu*/ \
"fmax v13.4s, v13.4s, v31.4s\n" /*relu*/ \
"2:\n"
/* activation */ \
"cmp %w[flag_act], #1\n" \
"bne 12f \n" \
"movi v31.4s, #0 \n" /* for relu*/ \
"fmax v16.4s, v16.4s, v31.4s \n" /*relu*/ \
"fmax v17.4s, v17.4s, v31.4s \n" /*relu*/ \
"fmax v18.4s, v18.4s, v31.4s \n" /*relu*/ \
"fmax v19.4s, v19.4s, v31.4s \n" /*relu*/ \
"fmax v8.4s, v8.4s, v31.4s \n" /*relu*/ \
"fmax v9.4s, v9.4s, v31.4s \n" /*relu*/ \
"fmax v12.4s, v12.4s, v31.4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, v31.4s \n" /*relu*/ \
"b 2f \n" /* relu end */ \
"12: \n" /* no relu */ \
"cmp %w[flag_act], #0 \n" /* check no act */ \
"beq 2f \n" /* no act end */ \
"cmp %w[flag_act], #2 \n" /* check relu6 */ \
"bne 13f \n" /* jump no relu6*/ \
"movi v20.4s, #0 \n" /* for relu6 */ \
"ld1 {v21.4s}, [%[alpha]] \n" /* relu6 alpha */ \
"fmax v16.4s, v16.4s, v20.4s \n" /* relu6 */ \
"fmax v17.4s, v17.4s, v20.4s \n" /* relu6 */ \
"fmax v18.4s, v18.4s, v20.4s \n" /* relu6 */ \
"fmax v19.4s, v19.4s, v20.4s \n" /* relu6 */ \
"fmax v8.4s, v8.4s, v20.4s \n" /* relu6 */ \
"fmax v9.4s, v9.4s, v20.4s \n" /* relu6 */ \
"fmax v12.4s, v12.4s, v20.4s \n" /* relu6 */ \
"fmax v13.4s, v13.4s, v20.4s \n" /* relu6 */ \
"fmin v16.4s, v16.4s, v21.4s \n" /* relu6 */ \
"fmin v17.4s, v17.4s, v21.4s \n" /* relu6 */ \
"fmin v18.4s, v18.4s, v21.4s \n" /* relu6 */ \
"fmin v19.4s, v19.4s, v21.4s \n" /* relu6 */ \
"fmin v8.4s, v8.4s, v21.4s \n" /* relu6 */ \
"fmin v9.4s, v9.4s, v21.4s \n" /* relu6 */ \
"fmin v12.4s, v12.4s, v21.4s \n" /* relu6 */ \
"fmin v13.4s, v13.4s, v21.4s \n" /* relu6 */ \
"b 2f \n" /* relu6 end */ \
"13: \n" /* leakey relu */ \
"movi v20.4s, #0 \n" /* for leakey relu */ \
"ld1 {v21.4s}, [%[alpha]] \n" /* leakey relu alpha */ \
"fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v11.4s, v16.4s, v21.4s \n" /* vmulq_f32 */ \
"fcmge v14.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v15.4s, v17.4s, v21.4s \n" /* vmulq_f32 */ \
"fcmge v22.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v23.4s, v18.4s, v21.4s \n" /* vmulq_f32 */ \
"fcmge v24.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v25.4s, v19.4s, v21.4s \n" /* vmulq_f32 */ \
"bif v16.16b, v11.16b, v10.16b \n" /* choose*/ \
"bif v17.16b, v15.16b, v14.16b \n" /* choose*/ \
"bif v18.16b, v23.16b, v22.16b \n" /* choose*/ \
"bif v19.16b, v25.16b, v24.16b \n" /* choose*/ \
"fcmge v10.4s, v8.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v11.4s, v8.4s, v21.4s \n" /* vmulq_f32 */ \
"fcmge v14.4s, v9.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v15.4s, v9.4s, v21.4s \n" /* vmulq_f32 */ \
"fcmge v22.4s, v12.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v23.4s, v12.4s, v21.4s \n" /* vmulq_f32 */ \
"fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_f32 */ \
"fmul v25.4s, v13.4s, v21.4s \n" /* vmulq_f32 */ \
"bif v8.16b, v11.16b, v10.16b \n" /* choose*/ \
"bif v9.16b, v15.16b, v14.16b \n" /* choose*/ \
"bif v12.16b, v23.16b, v22.16b \n" /* choose*/ \
"bif v13.16b, v25.16b, v24.16b \n" /* choose*/ \
"2: \n" /* act end */
#else
#define INT32_NCHWC8_TO_NCHW_FP32 \
......@@ -3312,18 +3351,68 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT
"vswp d5, d12\n" /* q2: b0-b3, q6: d0-d3 */ \
"vswp d3, d10\n" /* q1: e0-e3, q5: g0-g3 */ \
"vswp d7, d14\n" /* q3: f0-f3, q7: h0-h3 */ \
/* relu */ \
"vmov.i32 q8, #0\n" \
"cmp %[relu], #0\n" \
"beq 2f\n" \
"vmax.f32 q0, q0, q8\n" /*relu*/ \
"vmax.f32 q2, q2, q8\n" /*relu*/ \
"vmax.f32 q4, q4, q8\n" /*relu*/ \
"vmax.f32 q6, q6, q8\n" /*relu*/ \
"vmax.f32 q1, q1, q8\n" /*relu*/ \
"vmax.f32 q3, q3, q8\n" /*relu*/ \
"vmax.f32 q5, q5, q8\n" /*relu*/ \
"vmax.f32 q7, q7, q8\n" /*relu*/ \
/* activation */ \
"vmov.u32 q8, #0 \n" \
"cmp %[flag_act], #1 \n" \
"bne 12f \n" \
"vmax.f32 q0, q0, q8 \n" /*relu*/ \
"vmax.f32 q2, q2, q8 \n" /*relu*/ \
"vmax.f32 q4, q4, q8 \n" /*relu*/ \
"vmax.f32 q6, q6, q8 \n" /*relu*/ \
"vmax.f32 q1, q1, q8 \n" /*relu*/ \
"vmax.f32 q3, q3, q8 \n" /*relu*/ \
"vmax.f32 q5, q5, q8 \n" /*relu*/ \
"vmax.f32 q7, q7, q8 \n" /*relu*/ \
"b 2f \n" \
"12: \n" \
"cmp %[flag_act], #0 \n" \
"beq 2f \n" \
"cmp %[flag_act], #2 \n" \
"bne 13f \n" \
"vld1.f32 {d18-d19}, [%[alpha]] \n" \
"vmax.f32 q0, q0, q8 \n" \
"vmax.f32 q2, q2, q8 \n" \
"vmax.f32 q4, q4, q8 \n" \
"vmax.f32 q6, q6, q8 \n" \
"vmax.f32 q1, q1, q8 \n" \
"vmax.f32 q3, q3, q8 \n" \
"vmax.f32 q5, q5, q8 \n" \
"vmax.f32 q7, q7, q8 \n" \
"vmin.f32 q0, q0, q9 \n" \
"vmin.f32 q2, q2, q9 \n" \
"vmin.f32 q4, q4, q9 \n" \
"vmin.f32 q6, q6, q9 \n" \
"vmin.f32 q1, q1, q9 \n" \
"vmin.f32 q3, q3, q9 \n" \
"vmin.f32 q5, q5, q9 \n" \
"vmin.f32 q7, q7, q9 \n" \
"b 2f \n" \
"13: \n" \
"vld1.f32 {d18-d19}, [%[alpha]] \n" \
"vcge.f32 q10, q0, q8 \n" \
"vmul.f32 q11, q0, q9 \n" \
"vbif q0, q11, q10 \n" \
"vcge.f32 q10, q2, q8 \n" \
"vmul.f32 q11, q2, q9 \n" \
"vbif q2, q11, q10 \n" \
"vcge.f32 q10, q4, q8 \n" \
"vmul.f32 q11, q4, q9 \n" \
"vbif q4, q11, q10 \n" \
"vcge.f32 q10, q6, q8 \n" \
"vmul.f32 q11, q6, q9 \n" \
"vbif q6, q11, q10 \n" \
"vcge.f32 q10, q1, q8 \n" \
"vmul.f32 q11, q1, q9 \n" \
"vbif q1, q11, q10 \n" \
"vcge.f32 q10, q3, q8 \n" \
"vmul.f32 q11, q3, q9 \n" \
"vbif q3, q11, q10 \n" \
"vcge.f32 q10, q5, q8 \n" \
"vmul.f32 q11, q5, q9 \n" \
"vbif q5, q11, q10 \n" \
"vcge.f32 q10, q7, q8 \n" \
"vmul.f32 q11, q7, q9 \n" \
"vbif q7, q11, q10 \n" \
"2:\n"
#endif
......@@ -3344,7 +3433,9 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu) {
int flag_act,
float* alpha) {
// clang-format off
#ifdef __aarch64__
asm volatile(INT32_NCHWC8_TO_NCHW_FP32
"subs %w[cnt], %w[cnt], #1\n" /* loop count -1*/
......@@ -3371,31 +3462,13 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT
[scale1] "w"(scale1),
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v31");
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v31"
);
#else
asm volatile(INT32_NCHWC8_TO_NCHW_FP32
"subs %[cnt], #1\n" /* loop count -1*/
......@@ -3422,22 +3495,13 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT
[scale1] "w"(scale1),
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[relu] "r"(is_relu)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11");
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11"
);
#endif
// clang-format on
}
template <>
......@@ -3455,7 +3519,9 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu) {
int flag_act,
float* alpha) {
// clang-format off
#ifdef __aarch64__
float32x4_t vmax = vdupq_n_f32(-127.f);
asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* fp32-int32 */
......@@ -3529,34 +3595,13 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[vmax] "w"(vmax),
[relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v31");
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
"v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
"v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24",
"v25", "v31"
);
#else
float vmax[4] = {-127.f, -127.f, -127.f, -127.f};
asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* set +-0.5 offset */
......@@ -3669,175 +3714,13 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT
[bias0] "w"(bias0),
[bias1] "w"(bias1),
[vmax] "r"(vmax),
[relu] "r"(is_relu)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11");
#endif
}
template <>
inline void int32_nchwc8_kernel(int32_t*& dout0, // NOLINT
int32_t*& dout1, // NOLINT
int32_t*& dout2, // NOLINT
int32_t*& dout3, // NOLINT
int32_t*& dout4, // NOLINT
int32_t*& dout5, // NOLINT
int32_t*& dout6, // NOLINT
int32_t*& dout7, // NOLINT
const int32_t*& din, // NOLINT
int cnt,
float32x4_t scale0,
float32x4_t scale1,
float32x4_t bias0,
float32x4_t bias1,
bool is_relu) {
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"movi v20.4s, #0 \n" /* for relu */
"1: \n" /* main loop*/
"trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/
"trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/
"ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/
"trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/
"ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/
"trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/
"trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/
"trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/
"ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */
"trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/
"trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/
"trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/
"trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/
"ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */
"cbz %w[relu], 2f\n"
"smax v16.4s, v16.4s, v20.4s \n" /*relu*/
"smax v17.4s, v17.4s, v20.4s \n" /*relu*/
"smax v18.4s, v18.4s, v20.4s \n" /*relu*/
"smax v19.4s, v19.4s, v20.4s \n" /*relu*/
"smax v8.4s, v8.4s, v20.4s \n" /*relu*/
"smax v9.4s, v9.4s, v20.4s \n" /*relu*/
"smax v12.4s, v12.4s, v20.4s \n" /*relu*/
"smax v13.4s, v13.4s, v20.4s \n" /*relu*/
"2:\n"
"str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/
"str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/
"str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/
"str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/
"subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/
"str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/
"str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/
"str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/
"str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/
"bne 1b \n" /* jump to main loop*/
: [doutc0r0] "+r"(dout0),
[doutc1r0] "+r"(dout1),
[doutc2r0] "+r"(dout2),
[doutc3r0] "+r"(dout3),
[doutc4r0] "+r"(dout4),
[doutc5r0] "+r"(dout5),
[doutc6r0] "+r"(dout6),
[doutc7r0] "+r"(dout7),
[ptr_din] "+r"(din),
[cnt] "+r"(cnt)
: [relu] "r"(is_relu)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
#else
// ARMv7 NEON path. Each loop iteration consumes eight q registers (32 int32
// values) from ptr_din, transposes them with vtrn/vswp, optionally clamps them
// at zero, and stores four values to each of the eight output row pointers.
//
// Fixes relative to the previous version of this statement:
//  * The relu test was inverted (`bne 2f` skipped the clamp when relu != 0).
//    It now skips the clamp only when relu == 0, matching the AArch64 path,
//    which uses `cbz %w[relu], 2f`.
//  * %[cnt] is decremented by `subs`, so it is declared read-write ("+r")
//    instead of input-only, as the AArch64 path already does.
//  * q5-q7 (d10-d15) are loaded and written below, so they are now listed as
//    clobbers together with q0-q4 and q15.
//
// Notes on unchanged behavior:
//  * The loop is do-while shaped: cnt is decremented and tested after the
//    body, so the caller must enter with cnt >= 1.
//  * The loads for the next iteration are issued before the loop test, so one
//    extra 128-byte group is read from ptr_din after the final iteration.
asm volatile(
    "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
    "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
    "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
    "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
    "vmov.s32 q15, #0 @ dump zero\n"
    "1: @ main loop\n"
    "vtrn.32 q0, q2 @ trans q0, q2 \n"
    "vtrn.32 q4, q6 @ trans q4, q6 \n"
    "vswp.32 d1, d8 @ swap d1, d8 \n"
    "vswp.32 d5, d12 @ swap d5, d12\n"
    "vtrn.32 q1, q3 @ trans q1, q3 \n"
    "vtrn.32 q5, q7 @ trans q5, q7 \n"
    "vswp.32 d3, d10 @ swap d3, d10\n"
    "vswp.32 d7, d14 @ swap d7, d14\n"
    "cmp %[relu], #0\n"
    "beq 2f\n"  // relu == 0: jump over the clamp block
    "vmax.s32 q0, q0, q15 @ relu\n"
    "vmax.s32 q1, q1, q15 @ relu\n"
    "vmax.s32 q2, q2, q15 @ relu\n"
    "vmax.s32 q3, q3, q15 @ relu\n"
    "vmax.s32 q4, q4, q15 @ relu\n"
    "vmax.s32 q5, q5, q15 @ relu\n"
    "vmax.s32 q6, q6, q15 @ relu\n"
    "vmax.s32 q7, q7, q15 @ relu\n"
    "2:\n"
    // The flags set by this subs are consumed by the final bne; the NEON
    // loads and stores in between do not modify them.
    "subs %[cnt], %[cnt], #1 @ loop count - 1\n"
    "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n"
    "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n"
    "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n"
    "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n"
    "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n"
    "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n"
    "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n"
    "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n"
    "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n"
    "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n"
    "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n"
    "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n"
    "bne 1b @ jump to main loop\n"
    : [doutc0r0] "+r"(dout0),
      [doutc1r0] "+r"(dout1),
      [doutc2r0] "+r"(dout2),
      [doutc3r0] "+r"(dout3),
      [doutc4r0] "+r"(dout4),
      [doutc5r0] "+r"(dout5),
      [doutc6r0] "+r"(dout6),
      [doutc7r0] "+r"(dout7),
      [ptr_din] "+r"(din),
      [cnt] "+r"(cnt)
    : [relu] "r"(is_relu)
    : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15");
[flag_act] "r"(flag_act),
[alpha] "r"(alpha)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11"
);
#endif
// clang-format on
}
/* write result to outputs
......@@ -3855,7 +3738,8 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
int channel,
int height,
int width,
bool flag_relu,
int flag_act,
float* alpha,
float* bias,
bool flag_bias,
Dtype* trash_ptr,
......@@ -3931,46 +3815,47 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
w_scale1,
w_bias0,
w_bias1,
flag_relu);
flag_act,
alpha);
}
if (we > width) {
int offset = 32 * cnt;
din_hei_ptr = ptr_din + offset;
for (int j = ws + cnt * 4; j < width; ++j) {
if (flag_bias) {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
*(doutc1_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], bias[1], flag_relu);
*(doutc2_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], bias[2], flag_relu);
*(doutc3_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], bias[3], flag_relu);
*(doutc4_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[4], scale[4], bias[4], flag_relu);
*(doutc5_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[5], scale[5], bias[5], flag_relu);
*(doutc6_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[6], scale[6], bias[6], flag_relu);
*(doutc7_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[7], scale[7], bias[7], flag_relu);
*(doutc0_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]);
*(doutc1_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[1], scale[1], bias[1], flag_act, alpha[0]);
*(doutc2_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[2], scale[2], bias[2], flag_act, alpha[0]);
*(doutc3_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[3], scale[3], bias[3], flag_act, alpha[0]);
*(doutc4_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[4], scale[4], bias[4], flag_act, alpha[0]);
*(doutc5_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[5], scale[5], bias[5], flag_act, alpha[0]);
*(doutc6_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[6], scale[6], bias[6], flag_act, alpha[0]);
*(doutc7_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[7], scale[7], bias[7], flag_act, alpha[0]);
} else {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], 0.f, flag_relu);
*(doutc1_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], 0.f, flag_relu);
*(doutc2_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], 0.f, flag_relu);
*(doutc3_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], 0.f, flag_relu);
*(doutc4_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[4], scale[4], 0.f, flag_relu);
*(doutc5_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[5], scale[5], 0.f, flag_relu);
*(doutc6_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[6], scale[6], 0.f, flag_relu);
*(doutc7_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[7], scale[7], 0.f, flag_relu);
*(doutc0_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[0], scale[0], 0.f, flag_act, alpha[0]);
*(doutc1_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[1], scale[1], 0.f, flag_act, alpha[0]);
*(doutc2_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[2], scale[2], 0.f, flag_act, alpha[0]);
*(doutc3_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[3], scale[3], 0.f, flag_act, alpha[0]);
*(doutc4_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[4], scale[4], 0.f, flag_act, alpha[0]);
*(doutc5_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[5], scale[5], 0.f, flag_act, alpha[0]);
*(doutc6_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[6], scale[6], 0.f, flag_act, alpha[0]);
*(doutc7_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[7], scale[7], 0.f, flag_act, alpha[0]);
}
din_hei_ptr += 8;
}
......
......@@ -94,7 +94,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -112,7 +113,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -178,7 +180,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......@@ -196,7 +199,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int flag_act,
float* alpha,
int num,
int chin,
int hin,
......
......@@ -790,8 +790,30 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
int pad_h = paddings[0];
int pad_w = paddings[2];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
// Translate the fused-activation settings in param.activation_param into the
// integer code (flag_act) and coefficient array (alpha) that are passed to
// the conv_depthwise_3x3 int8 kernels called below.
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
// flag_act stays 0 when no activation is fused or the type is not one of the
// three handled below.
int flag_act = 0; // relu: 1, relu6: 2, leaky relu: 3
// The activation coefficient is replicated into all four slots; the array
// stays all-zero for plain relu and for the no-activation case.
// Presumably four copies let the kernel load it as one 4-lane float vector
// -- TODO confirm against the kernel.
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
// relu6: alpha carries the clipping coefficient.
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
// leaky relu: alpha carries the leaky-relu coefficient.
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
if (stride == 1) {
conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
......@@ -799,7 +821,8 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -816,7 +839,8 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -849,8 +873,30 @@ void conv_depthwise_3x3_int8_int8(const void* din,
int pad_h = paddings[0];
int pad_w = paddings[2];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
// Translate the fused-activation settings in param.activation_param into the
// integer code (flag_act) and coefficient array (alpha) that are passed to
// the conv_depthwise_3x3 int8 kernels called below.
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
// flag_act stays 0 when no activation is fused or the type is not one of the
// three handled below.
int flag_act = 0; // relu: 1, relu6: 2, leaky relu: 3
// The activation coefficient is replicated into all four slots; the array
// stays all-zero for plain relu and for the no-activation case.
// Presumably four copies let the kernel load it as one 4-lane float vector
// -- TODO confirm against the kernel.
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
// relu6: alpha carries the clipping coefficient.
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
// leaky relu: alpha carries the leaky-relu coefficient.
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
if (stride == 1) {
conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
......@@ -858,7 +904,8 @@ void conv_depthwise_3x3_int8_int8(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -875,7 +922,8 @@ void conv_depthwise_3x3_int8_int8(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -908,8 +956,30 @@ void conv_depthwise_5x5_int8_fp32(const void* din,
int pad_h = paddings[0];
int pad_w = paddings[2];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
// Translate the fused-activation settings in param.activation_param into the
// integer code (flag_act) and coefficient array (alpha) that are passed to
// the conv_depthwise_5x5 int8 kernels called below.
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
// flag_act stays 0 when no activation is fused or the type is not one of the
// three handled below.
int flag_act = 0; // relu: 1, relu6: 2, leaky relu: 3
// The activation coefficient is replicated into all four slots; the array
// stays all-zero for plain relu and for the no-activation case.
// Presumably four copies let the kernel load it as one 4-lane float vector
// -- TODO confirm against the kernel.
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
// relu6: alpha carries the clipping coefficient.
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
// leaky relu: alpha carries the leaky-relu coefficient.
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
......@@ -917,7 +987,8 @@ void conv_depthwise_5x5_int8_fp32(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -934,7 +1005,8 @@ void conv_depthwise_5x5_int8_fp32(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -967,8 +1039,30 @@ void conv_depthwise_5x5_int8_int8(const void* din,
int pad_h = paddings[0];
int pad_w = paddings[2];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
// Translate the fused-activation settings in param.activation_param into the
// integer code (flag_act) and coefficient array (alpha) that are passed to
// the conv_depthwise_5x5 int8 kernels called below.
auto act_param = param.activation_param;
auto act_type = act_param.active_type;
// flag_act stays 0 when no activation is fused or the type is not one of the
// three handled below.
int flag_act = 0; // relu: 1, relu6: 2, leaky relu: 3
// The activation coefficient is replicated into all four slots; the array
// stays all-zero for plain relu and for the no-activation case.
// Presumably four copies let the kernel load it as one 4-lane float vector
// -- TODO confirm against the kernel.
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
// relu6: alpha carries the clipping coefficient.
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
// leaky relu: alpha carries the leaky-relu coefficient.
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
......@@ -976,7 +1070,8 @@ void conv_depthwise_5x5_int8_int8(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......@@ -993,7 +1088,8 @@ void conv_depthwise_5x5_int8_int8(const void* din,
scale,
bias,
flag_bias,
flag_relu,
flag_act,
alpha,
num,
ch_in,
h_in,
......
......@@ -534,18 +534,18 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
"fmin v17.4s, v17.4s, v1.4s\n" /* relu6 */ \
"fmin v18.4s, v18.4s, v1.4s\n" /* relu6 */ \
"fmin v19.4s, v19.4s, v1.4s\n" /* relu6 */ \
"fmin v20.4s, v20.4s, v0.4s\n" /* relu6 */ \
"fmin v21.4s, v21.4s, v0.4s\n" /* relu6 */ \
"fmin v22.4s, v22.4s, v0.4s\n" /* relu6 */ \
"fmin v23.4s, v23.4s, v0.4s\n" /* relu6 */ \
"fmin v24.4s, v24.4s, v0.4s\n" /* relu6 */ \
"fmin v25.4s, v25.4s, v0.4s\n" /* relu6 */ \
"fmin v26.4s, v26.4s, v0.4s\n" /* relu6 */ \
"fmin v27.4s, v27.4s, v0.4s\n" /* relu6 */ \
"fmin v28.4s, v28.4s, v0.4s\n" /* relu6 */ \
"fmin v29.4s, v29.4s, v0.4s\n" /* relu6 */ \
"fmin v30.4s, v30.4s, v0.4s\n" /* relu6 */ \
"fmin v31.4s, v31.4s, v0.4s\n" /* relu6 */ \
"fmin v20.4s, v20.4s, v1.4s\n" /* relu6 */ \
"fmin v21.4s, v21.4s, v1.4s\n" /* relu6 */ \
"fmin v22.4s, v22.4s, v1.4s\n" /* relu6 */ \
"fmin v23.4s, v23.4s, v1.4s\n" /* relu6 */ \
"fmin v24.4s, v24.4s, v1.4s\n" /* relu6 */ \
"fmin v25.4s, v25.4s, v1.4s\n" /* relu6 */ \
"fmin v26.4s, v26.4s, v1.4s\n" /* relu6 */ \
"fmin v27.4s, v27.4s, v1.4s\n" /* relu6 */ \
"fmin v28.4s, v28.4s, v1.4s\n" /* relu6 */ \
"fmin v29.4s, v29.4s, v1.4s\n" /* relu6 */ \
"fmin v30.4s, v30.4s, v1.4s\n" /* relu6 */ \
"fmin v31.4s, v31.4s, v1.4s\n" /* relu6 */ \
"b 9f \n" /* relu end */
#define GEMM_INT8_LEAKY_RELU \
......
......@@ -169,6 +169,12 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
flag_trans_bias_ = true;
}
//! update relu6 parameter: when relu6 is the fused activation, divide the
//! clipping coefficient by param.output_scale.
//! Presumably this expresses the clip threshold in quantized output units
//! for the int8-output kernel -- TODO confirm.
//! NOTE(review): the coefficient is overwritten in place. If `param` refers
//! to the kernel's stored parameters and PrepareForRun can run more than
//! once, the division would be applied repeatedly -- confirm.
if (param.activation_param.has_active &&
param.activation_param.active_type == lite_api::ActivationType::kRelu6) {
param.activation_param.Relu_clipped_coef =
param.activation_param.Relu_clipped_coef / param.output_scale;
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
......
......@@ -39,7 +39,8 @@ inline bool direct_conv_trans_weights(
const std::vector<float>& w_scale,
float in_scale,
float out_scale,
std::vector<float>& merge_scale) { // NOLINT
std::vector<float>& merge_scale, // NOLINT
float* relu_clipped_coef) {
constexpr int cblock = 4;
int oc = win->dims()[0];
int ic = win->dims()[1];
......@@ -64,7 +65,8 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kFloat)>(
const std::vector<float>& w_scale,
float in_scale,
float out_scale,
std::vector<float>& merge_scale) { // NOLINT
std::vector<float>& merge_scale, // NOLINT
float* relu_clipped_coef) {
int cblock = 4;
if (stride == 2) {
cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num();
......@@ -103,7 +105,8 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
const std::vector<float>& w_scale,
float in_scale,
float out_scale,
std::vector<float>& merge_scale) { // NOLINT
std::vector<float>& merge_scale, // NOLINT
float* relu_clipped_coef) {
int cblock = 4;
if (stride == 2) {
cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num();
......@@ -130,6 +133,8 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
merge_scale[i] = w_scale[i] * scale;
}
}
/// update relu_clipped_coef: rescale the caller's coefficient in place,
/// dividing it by out_scale.
/// NOTE(review): the pointer is dereferenced without a null check, and no
/// activation-type condition is visible around this statement -- confirm that
/// every caller passes a valid pointer and that the rescale is intended even
/// when relu6 is not the fused activation.
*relu_clipped_coef /= out_scale;
/// update bias
if (bin) {
bout->Resize(bin->dims());
......@@ -167,16 +172,17 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
<< "direct conv only support conv3x3s1 and conv3x3s2";
CHECK(kw == 3 && kh == 3)
<< "direct conv only support conv3x3s1 and conv3x3s2";
flag_trans_bias_ =
direct_conv_trans_weights<Ptype, OutType>(param.filter,
&weights_,
param.bias,
&bias_,
sw,
param.weight_scale,
param.input_scale,
param.output_scale,
w_scale_);
flag_trans_bias_ = direct_conv_trans_weights<Ptype, OutType>(
param.filter,
&weights_,
param.bias,
&bias_,
sw,
param.weight_scale,
param.input_scale,
param.output_scale,
w_scale_,
&param.activation_param.Relu_clipped_coef);
}
virtual void Run();
......
......@@ -56,7 +56,7 @@ DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_act, true, "do act");
DEFINE_bool(flag_bias, true, "with bias");
DEFINE_double(clipped_coef, 1.0, "clipped relu coef");
DEFINE_double(leakey_relu_alpha, 8.88, "leakey relu alpha");
DEFINE_double(leakey_relu_alpha, 2.22, "leakey relu alpha");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
......@@ -188,7 +188,14 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
}
std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
std::vector<float> scale_out(1, weight_dim.count(1, 4) / 127.f);
if (flag_act == 2) {
scale_out[0] = six / 127.f;
} else if (flag_act == 4) {
if (std::abs(alpha) > 1) {
scale_out[0] *= std::abs(alpha);
}
}
std::vector<float> scale_w(weight_dim[0], 1.f / 127);
param_int8_out.input_scale = scale_in[0];
......@@ -484,7 +491,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3});
......@@ -520,7 +527,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
......@@ -553,7 +560,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
#if 1 /// conv1x1s1
TEST(TestConv1x1s1Int8, test_conv1x1s1) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) {
for (auto& cin : {1, 3, 8, 33}) {
for (auto& cout : {1, 5, 17}) {
for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) {
......@@ -599,7 +606,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
......@@ -641,7 +648,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_act : {0, 1}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
......@@ -673,7 +680,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
}
#endif /// conv3x3s2
#if 0 /// random param conv
#if 1 /// random param conv
TEST(TestConvRandInt8, test_conv_rand) {
if (FLAGS_basic_test) {
for (auto& cin : {1, 17}) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册