diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc b/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc index bc2097b9286dbce4430739a0784f2691c62d37a1..832e3182bac94638be52908afef0b9fc1b03c1f2 100644 --- a/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc +++ b/lite/backends/arm/math/conv3x3s1_depthwise_int8.cc @@ -36,7 +36,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -434,7 +435,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, @@ -450,7 +452,8 @@ template void conv_depthwise_3x3s1_int8(int8_t* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -467,7 +470,8 @@ template void conv_depthwise_3x3s1_int8(float* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, diff --git a/lite/backends/arm/math/conv3x3s1_direct_int8.cc b/lite/backends/arm/math/conv3x3s1_direct_int8.cc index 64e72bc441bb93fa955e12ff53ce17f0e37b4830..eecdb7d3a418a7a74257e8b60c01a425783e40e3 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_int8.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_int8.cc @@ -42,8 +42,30 @@ void conv_3x3s1_direct_int8(const int8_t* din, Context* ctx, const float* scale) { auto paddings = *param.paddings; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } int pad_h = paddings[0]; int pad_w = paddings[2]; @@ -442,7 +464,8 @@ void conv_3x3s1_direct_int8(const int8_t* din, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, diff --git a/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc b/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc index 2e475fc6067cf52962038fc4bf18c99909e4bafd..5ccfd18a44078ef1c7218d99d3e5ed8032e9b953 100644 --- a/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc +++ b/lite/backends/arm/math/conv3x3s2_depthwise_int8.cc @@ -36,7 +36,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -447,7 +448,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, @@ -463,7 +465,8 @@ template void conv_depthwise_3x3s2_int8(int8_t* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -480,7 +483,8 @@ template void conv_depthwise_3x3s2_int8(float* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, diff --git a/lite/backends/arm/math/conv3x3s2_direct_int8.cc b/lite/backends/arm/math/conv3x3s2_direct_int8.cc index 3d6f3dd743c3e46b6123f2c93dbfed586ad7b4c6..b36fe83563718b85c71546abe958098e1e413760 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_int8.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_int8.cc @@ -47,8 +47,30 @@ void conv_3x3s2_direct_int8(const int8_t* din, //! prepack input to tmp buffer //! write output to tmp buffer auto paddings = *param.paddings; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } int pad_h = paddings[0]; int pad_w = paddings[2]; @@ -442,7 +464,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, @@ -474,8 +497,30 @@ void conv_3x3s2_direct_int8(const int8_t* din, //! prepack input to tmp buffer //! write output to tmp buffer auto paddings = *param.paddings; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } int pad_h = paddings[0]; int pad_w = paddings[2]; const int threads = ctx->threads(); @@ -698,7 +743,8 @@ void conv_3x3s2_direct_int8(const int8_t* din, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc index ed3dad300804dc90fac874999ac5d0a420cff4a4..5a5f3f8c025a1a7951c31b90af85d65c1108087d 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc @@ -36,7 +36,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -726,7 +727,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, @@ -742,7 +744,8 @@ template void conv_depthwise_5x5s1_int8(int8_t* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -759,7 +762,8 @@ template void conv_depthwise_5x5s1_int8(float* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc index 0ac1705de76102c92c9e63d64721aa2467baaf04..f5979524540f93fc66a589a5b4d19239a3fe8b98 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_int8.cc @@ -36,7 +36,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -746,7 +747,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout, chout, hout, wout, - flag_relu, + flag_act, + alpha, bias_local, flag_bias, ptr_write, @@ -762,7 +764,8 @@ template void conv_depthwise_5x5s2_int8(int8_t* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -779,7 +782,8 @@ template void conv_depthwise_5x5s2_int8(float* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index c4fb51021e5b0288a4bc1fd476764348fdc7e450..78d4f3f74e3e8a0fb06b1fda83ad5deed281621b 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -2643,48 +2643,81 @@ inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT int cnt, float32x4_t scale, float32x4_t bias, - bool is_relu); + int flag_act, + float* alpha); #ifdef __aarch64__ -#define NCHWC4_TRANS_INT32 \ - "ldp q0, q1, [%[ptr_din]], #32\n" \ - "ldp q2, q3, [%[ptr_din]], #32\n" \ - "movi v20.4s, #0\n" \ - "1:\n" \ - "trn1 v8.4s, v0.4s, v1.4s\n" \ - "trn2 v9.4s, v0.4s, v1.4s\n" \ - "ldp q0, q1, [%[ptr_din]], #32\n" \ - "trn1 v10.4s, v2.4s, v3.4s\n" \ - "trn2 v11.4s, v2.4s, v3.4s\n" \ - "ldp q2, q3, [%[ptr_din]], #32\n" \ - "trn1 v16.2d, v8.2d, v10.2d\n" \ - "trn2 v17.2d, v8.2d, v10.2d\n" \ - "trn1 v18.2d, v9.2d, v11.2d\n" \ - "trn2 v19.2d, v9.2d, v11.2d\n" /* int32 --> fp32 */ \ - "scvtf v4.4s, v16.4s\n" \ - "scvtf v5.4s, v17.4s\n" \ - "scvtf v6.4s, v18.4s\n" \ - "scvtf v7.4s, v19.4s\n" /* add bias */ \ - "dup v16.4s, %[bias].s[0]\n" \ - "dup v17.4s, %[bias].s[2]\n" \ - "dup v18.4s, %[bias].s[1]\n" \ - "dup v19.4s, %[bias].s[3]\n" /* mul scale */ \ - "fmla v16.4s, v4.4s, %[scale].s[0]\n" \ - "fmla v17.4s, v5.4s, %[scale].s[2]\n" \ - "fmla v18.4s, v6.4s, %[scale].s[1]\n" \ - "fmla v19.4s, v7.4s, %[scale].s[3]\n" /* relu */ \ - "cbz %w[relu], 2f\n" \ - "fmax v16.4s, v16.4s, v20.4s \n" \ - "fmax v17.4s, v17.4s, v20.4s \n" \ - "fmax v18.4s, v18.4s, v20.4s \n" \ - "fmax v19.4s, v19.4s, v20.4s \n" \ - "2:\n" +#define NCHWC4_TRANS_INT32 \ + "ldp q0, q1, [%[ptr_din]], #32\n" \ + "ldp q2, q3, [%[ptr_din]], #32\n" \ + "1:\n" \ + "trn1 v8.4s, v0.4s, v1.4s\n" \ + "trn2 v9.4s, v0.4s, v1.4s\n" \ + "ldp q0, q1, [%[ptr_din]], #32\n" \ + "trn1 v10.4s, v2.4s, v3.4s\n" \ + "trn2 v11.4s, v2.4s, v3.4s\n" \ + "ldp q2, q3, [%[ptr_din]], #32\n" \ + "trn1 v16.2d, v8.2d, v10.2d\n" \ + "trn2 v17.2d, v8.2d, v10.2d\n" \ + "trn1 v18.2d, v9.2d, v11.2d\n" \ + "trn2 v19.2d, v9.2d, v11.2d\n" /* int32 --> fp32 */ \ + "scvtf v4.4s, v16.4s\n" \ + "scvtf v5.4s, v17.4s\n" \ + "scvtf v6.4s, v18.4s\n" \ + "scvtf v7.4s, v19.4s\n" /* add bias */ \ + "dup v16.4s, %[bias].s[0]\n" \ + "dup v17.4s, %[bias].s[2]\n" \ + "dup v18.4s, %[bias].s[1]\n" \ + "dup v19.4s, %[bias].s[3]\n" /* mul scale */ \ + "fmla v16.4s, v4.4s, %[scale].s[0]\n" \ + "fmla v17.4s, v5.4s, %[scale].s[2]\n" \ + "fmla v18.4s, v6.4s, %[scale].s[1]\n" \ + "fmla v19.4s, v7.4s, %[scale].s[3]\n" \ + "cmp %w[flag_act], #1\n" \ + "bne 12f \n" \ + "movi v20.4s, #0 \n" /* for relu*/ \ + "fmax v16.4s, v16.4s, v20.4s \n" \ + "fmax v17.4s, v17.4s, v20.4s \n" \ + "fmax v18.4s, v18.4s, v20.4s \n" \ + "fmax v19.4s, v19.4s, v20.4s \n" \ + "b 2f \n" /* relu end */ \ + "12: \n" /* no relu */ \ + "cmp %w[flag_act], #0 \n" /* check no act */ \ + "beq 2f \n" /* no act end */ \ + "cmp %w[flag_act], #2 \n" /* check relu6 */ \ + "bne 13f \n" /* jump no relu6*/ \ + "movi v8.4s, #0 \n" /* for relu6 */ \ + "ld1 {v9.4s}, [%[alpha]] \n" /* relu6 alpha */ \ + "fmax v16.4s, v16.4s, v8.4s \n" /* relu6 */ \ + "fmax v17.4s, v17.4s, v8.4s \n" /* relu6 */ \ + "fmax v18.4s, v18.4s, v8.4s \n" /* relu6 */ \ + "fmax v19.4s, v19.4s, v8.4s \n" /* relu6 */ \ + "fmin v16.4s, v16.4s, v9.4s \n" /* relu6 */ \ + "fmin v17.4s, v17.4s, v9.4s \n" /* relu6 */ \ + "fmin v18.4s, v18.4s, v9.4s \n" /* relu6 */ \ + "fmin v19.4s, v19.4s, v9.4s \n" /* relu6 */ \ + "b 2f \n" /* relu6 end */ \ + "13: \n" /* leakey relu */ \ + "movi v12.4s, #0 \n" /* for leakey relu */ \ + "ld1 {v13.4s}, [%[alpha]] \n" /* leakey relu alpha */ \ + "fcmge v4.4s, v16.4s, v12.4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v16.4s, v13.4s \n" /* vmulq_f32 */ \ + "fcmge v6.4s, v17.4s, v12.4s \n" /* vcgeq_f32 */ \ + "fmul v7.4s, v17.4s, v13.4s \n" /* vmulq_f32 */ \ + "fcmge v8.4s, v18.4s, v12.4s \n" /* vcgeq_f32 */ \ + "fmul v9.4s, v18.4s, v13.4s \n" /* vmulq_f32 */ \ + "fcmge v10.4s, v19.4s, v12.4s \n" /* vcgeq_f32 */ \ + "fmul v11.4s, v19.4s, v13.4s \n" /* vmulq_f32 */ \ + "bif v16.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v17.16b, v7.16b, v6.16b \n" /* choose*/ \ + "bif v18.16b, v9.16b, v8.16b \n" /* choose*/ \ + "bif v19.16b, v11.16b, v10.16b \n" /* choose*/ \ + "2: \n" /* act end */ #else #define NCHWC4_TRANS_INT32 \ "vld1.32 {d4-d7}, [%[ptr_din]]!\n" \ "vld1.32 {d8-d11}, [%[ptr_din]]!\n" \ - "vmov.u32 q15, #0\n" \ "1:\n" /* transpose */ \ "vtrn.32 q2, q3\n" \ "vtrn.32 q4, q5\n" \ @@ -2701,13 +2734,44 @@ inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT "vmla.f32 q10, q6, %e[scale][0]\n" \ "vmla.f32 q11, q7, %e[scale][1]\n" \ "vmla.f32 q12, q8, %f[scale][0]\n" \ - "vmla.f32 q13, q9, %f[scale][1]\n" /* relu */ \ - "cmp %[relu], #0\n" \ - "beq 2f\n" \ - "vmax.f32 q10, q10, q15\n" \ - "vmax.f32 q11, q11, q15\n" \ - "vmax.f32 q12, q12, q15\n" \ - "vmax.f32 q13, q13, q15\n" \ + "vmla.f32 q13, q9, %f[scale][1]\n" \ + "vmov.u32 q15, #0 \n" \ + "cmp %[flag_act], #1 \n" \ + "bne 12f \n" \ + "vmax.f32 q10, q10, q15 \n" \ + "vmax.f32 q11, q11, q15 \n" \ + "vmax.f32 q12, q12, q15 \n" \ + "vmax.f32 q13, q13, q15 \n" \ + "b 2f \n" \ + "12: \n" \ + "cmp %[flag_act], #0 \n" \ + "beq 2f \n" \ + "cmp %[flag_act], #2 \n" \ + "bne 13f \n" \ + "vld1.f32 {d14-d15}, [%[alpha]] \n" \ + "vmax.f32 q10, q10, q15 \n" \ + "vmax.f32 q11, q11, q15 \n" \ + "vmax.f32 q12, q12, q15 \n" \ + "vmax.f32 q13, q13, q15 \n" \ + "vmin.f32 q10, q10, q7 \n" \ + "vmin.f32 q11, q11, q7 \n" \ + "vmin.f32 q12, q12, q7 \n" \ + "vmin.f32 q13, q13, q7 \n" \ + "b 2f \n" \ + "13: \n" \ + "vld1.f32 {d6-d7}, [%[alpha]] \n" \ + "vcge.f32 q6, q10, q15 \n" \ + "vmul.f32 q7, q10, q3 \n" \ + "vcge.f32 q8, q11, q15 \n" \ + "vmul.f32 q9, q11, q3 \n" \ + "vbif q10, q7, q6 \n" \ + "vbif q11, q9, q8 \n" \ + "vcge.f32 q6, q12, q15 \n" \ + "vmul.f32 q7, q12, q3 \n" \ + "vcge.f32 q8, q13, q15 \n" \ + "vmul.f32 q9, q13, q3 \n" \ + "vbif q12, q7, q6 \n" \ + "vbif q13, q9, q8 \n" \ "2:\n" #endif @@ -2721,7 +2785,8 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT int cnt, float32x4_t scale, float32x4_t bias, - bool is_relu) { + int flag_act, + float* alpha) { #ifdef __aarch64__ asm volatile(NCHWC4_TRANS_INT32 "subs %w[cnt], %w[cnt], #1\n" @@ -2737,7 +2802,10 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT [doutc3r0] "+r"(dout3), [ptr_din] "+r"(din), [cnt] "+r"(cnt) - : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu) + : [scale] "w"(scale), + [bias] "w"(bias), + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) : "cc", "memory", "v0", @@ -2779,7 +2847,10 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT [doutc3r0] "+r"(dout3), [ptr_din] "+r"(din), [cnt] "+r"(cnt) - : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu) + : [scale] "w"(scale), + [bias] "w"(bias), + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) : "cc", "memory", "q2", @@ -2808,7 +2879,8 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT int cnt, float32x4_t scale, float32x4_t bias, - bool is_relu) { + int flag_act, + float* alpha) { #ifdef __aarch64__ float32x4_t vmax = vdupq_n_f32(-127.f); asm volatile(NCHWC4_TRANS_INT32 @@ -2852,7 +2924,8 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT : [scale] "w"(scale), [vmax] "w"(vmax), [bias] "w"(bias), - [relu] "r"(is_relu) + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) : "cc", "memory", "v0", @@ -2942,8 +3015,9 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT [cnt] "+r"(cnt) : [scale] "w"(scale), [bias] "w"(bias), - [relu] "r"(is_relu), - [vmax] "r"(vmax) + [vmax] "r"(vmax), + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) : "cc", "memory", "q2", @@ -2963,139 +3037,48 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT #endif } -template <> -inline void int32_nchwc4_kernel(int32_t*& dout0, // NOLINT - int32_t*& dout1, // NOLINT - int32_t*& dout2, // NOLINT - int32_t*& dout3, // NOLINT - const int32_t*& din, // NOLINT - int cnt, - float32x4_t scale, - float32x4_t bias, - bool is_relu) { -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v1.4s \n" /* trans q0, q1*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v10.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v2.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn2 v17.2d, v8.2d, v10.2d \n" /* trans q8, q10*/ - "trn1 v18.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "trn2 v19.2d, v9.2d, v11.2d \n" /* trans q9, q11*/ - "cbz %w[relu], 2f\n" - "smax v16.4s, v16.4s, v20.4s \n" /* relu */ - "smax v17.4s, v17.4s, v20.4s \n" /* relu */ - "smax v18.4s, v18.4s, v20.4s \n" /* relu */ - "smax v19.4s, v19.4s, v20.4s \n" /* relu */ - "2:\n" - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "bne 1b \n" /* jump to main loop*/ - : [doutc0r0] "+r"(dout0), - [doutc1r0] "+r"(dout1), - [doutc2r0] "+r"(dout2), - [doutc3r0] "+r"(dout3), - [ptr_din] "+r"(din), - [cnt] "+r"(cnt) - : [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q1 @ trans q0, q1 \n" - "vtrn.32 q2, q3 @ trans q2, q3 \n" - "vswp.32 d1, d4 @ swap d1, d4 \n" - "vswp.32 d3, d6 @ swap d3, d6 \n" - "cmp %[relu], #0\n" - "bne 2f\n" - "vmax.s32 q0, q0, q15 @ relu\n" - "vmax.s32 q1, q1, q15 @ relu\n" - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" - "2:\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc3r0]]! @ store result, add pointer\n" - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "bne 1b @ jump to main loop\n" - : [doutc0r0] "+r"(dout0), - [doutc1r0] "+r"(dout1), - [doutc2r0] "+r"(dout2), - [doutc3r0] "+r"(dout3), - [ptr_din] "+r"(din), - [cnt] "+r"(cnt) - : [relu] "r"(is_relu) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15"); -#endif -} - template -inline Dtype cvt_kernel(int din, float scale, float bias, bool flag_relu); +inline Dtype cvt_kernel( + int din, float scale, float bias, int flag_act, float alpha); template <> -inline float cvt_kernel(int din, float scale, float bias, bool flag_relu) { - if (flag_relu) { +inline float cvt_kernel( + int din, float scale, float bias, int flag_act, float alpha) { + if (flag_act == 1) { return LITEMAX(din * scale + bias, 0); + } else if (flag_act == 0) { + return din * scale + bias; + } else if (flag_act == 2) { + float max = LITEMAX(din * scale + bias, 0); + return LITEMIN(max, alpha); + } else { + float result = din * scale + bias; + return result > 0 ? result : alpha * result; } - return din * scale + bias; } template <> -inline int8_t cvt_kernel(int din, float scale, float bias, bool flag_relu) { - if (flag_relu) { - return saturate_cast(round(LITEMAX(din * scale + bias, 0))); - } else { +inline int8_t cvt_kernel( + int din, float scale, float bias, int flag_act, float alpha) { + if (flag_act == 1) { + auto tmp = saturate_cast(round(LITEMAX(din * scale + bias, 0))); + return tmp < -127 ? -127 : tmp; + } else if (flag_act == 0) { auto tmp = saturate_cast(round(din * scale + bias)); return tmp < -127 ? -127 : tmp; + } else if (flag_act == 2) { + float max = LITEMAX(din * scale + bias, 0); + float relu6_result = LITEMIN(max, alpha); + auto tmp = saturate_cast(round(relu6_result)); + return tmp < -127 ? -127 : tmp; + } else { + float result = din * scale + bias; + float leaky_result = result > 0 ? result : alpha * result; + auto tmp = saturate_cast(round(leaky_result)); + return tmp < -127 ? -127 : tmp; } } -template <> -inline int32_t cvt_kernel(int din, float scale, float bias, bool flag_relu) { - if (flag_relu) { - return LITEMAX(din, 0); - } - return din; -} - template inline void write_int32_nchwc4_to_nchw(const int* din, Dtype* dout, @@ -3108,7 +3091,8 @@ inline void write_int32_nchwc4_to_nchw(const int* din, int channel, int height, int width, - bool flag_relu, + int flag_act, + float* alpha, float* bias, bool flag_bias, Dtype* trash_ptr, @@ -3160,21 +3144,22 @@ inline void write_int32_nchwc4_to_nchw(const int* din, cnt, w_scale, w_bias, - flag_relu); + flag_act, + alpha); } if (we > width) { int offset = 16 * (valid_w / 4 - 1); din_hei_ptr = din + index + offset; int j = we - 4; for (; j < width; ++j) { - *(doutc0_ptr++) = - cvt_kernel(din_hei_ptr[0], scale[0], bias[0], flag_relu); - *(doutc1_ptr++) = - cvt_kernel(din_hei_ptr[1], scale[1], bias[1], flag_relu); - *(doutc2_ptr++) = - cvt_kernel(din_hei_ptr[2], scale[2], bias[2], flag_relu); - *(doutc3_ptr++) = - cvt_kernel(din_hei_ptr[3], scale[3], bias[3], flag_relu); + *(doutc0_ptr++) = cvt_kernel( + din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]); + *(doutc1_ptr++) = cvt_kernel( + din_hei_ptr[1], scale[1], bias[1], flag_act, alpha[0]); + *(doutc2_ptr++) = cvt_kernel( + din_hei_ptr[2], scale[2], bias[2], flag_act, alpha[0]); + *(doutc3_ptr++) = cvt_kernel( + din_hei_ptr[3], scale[3], bias[3], flag_act, alpha[0]); din_hei_ptr += 4; } } @@ -3196,7 +3181,8 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT float32x4_t scale1, float32x4_t bias0, float32x4_t bias1, - bool is_relu); + int flag_act, + float* alpha); // clang-format off #ifdef __aarch64__ @@ -3205,7 +3191,6 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT "ldp q2, q3, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \ "ldp q4, q5, [%[ptr_din]], #32\n" /* load r00, r01 to q0, q1 */ \ "ldp q6, q7, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \ - "movi v31.4s, #0\n" /* main loop*/ \ "1:\n" \ "trn1 v8.4s, v0.4s, v2.4s\n" /* trans q0, q1*/ \ "trn2 v9.4s, v0.4s, v2.4s\n" /* trans q0, q1*/ \ @@ -3256,17 +3241,71 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT "fmla v9.4s, v11.4s, %[scale1].s[2]\n" \ "fmla v12.4s, v14.4s, %[scale1].s[1]\n" \ "fmla v13.4s, v15.4s, %[scale1].s[3]\n" \ - /* relu */ \ - "cbz %w[relu], 2f\n" \ - "fmax v16.4s, v16.4s, v31.4s\n" /*relu*/ \ - "fmax v17.4s, v17.4s, v31.4s\n" /*relu*/ \ - "fmax v18.4s, v18.4s, v31.4s\n" /*relu*/ \ - "fmax v19.4s, v19.4s, v31.4s\n" /*relu*/ \ - "fmax v8.4s, v8.4s, v31.4s\n" /*relu*/ \ - "fmax v9.4s, v9.4s, v31.4s\n" /*relu*/ \ - "fmax v12.4s, v12.4s, v31.4s\n" /*relu*/ \ - "fmax v13.4s, v13.4s, v31.4s\n" /*relu*/ \ - "2:\n" + /* activation */ \ + "cmp %w[flag_act], #1\n" \ + "bne 12f \n" \ + "movi v31.4s, #0 \n" /* for relu*/ \ + "fmax v16.4s, v16.4s, v31.4s \n" /*relu*/ \ + "fmax v17.4s, v17.4s, v31.4s \n" /*relu*/ \ + "fmax v18.4s, v18.4s, v31.4s \n" /*relu*/ \ + "fmax v19.4s, v19.4s, v31.4s \n" /*relu*/ \ + "fmax v8.4s, v8.4s, v31.4s \n" /*relu*/ \ + "fmax v9.4s, v9.4s, v31.4s \n" /*relu*/ \ + "fmax v12.4s, v12.4s, v31.4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, v31.4s \n" /*relu*/ \ + "b 2f \n" /* relu end */ \ + "12: \n" /* no relu */ \ + "cmp %w[flag_act], #0 \n" /* check no act */ \ + "beq 2f \n" /* no act end */ \ + "cmp %w[flag_act], #2 \n" /* check relu6 */ \ + "bne 13f \n" /* jump no relu6*/ \ + "movi v20.4s, #0 \n" /* for relu6 */ \ + "ld1 {v21.4s}, [%[alpha]] \n" /* relu6 alpha */ \ + "fmax v16.4s, v16.4s, v20.4s \n" /* relu6 */ \ + "fmax v17.4s, v17.4s, v20.4s \n" /* relu6 */ \ + "fmax v18.4s, v18.4s, v20.4s \n" /* relu6 */ \ + "fmax v19.4s, v19.4s, v20.4s \n" /* relu6 */ \ + "fmax v8.4s, v8.4s, v20.4s \n" /* relu6 */ \ + "fmax v9.4s, v9.4s, v20.4s \n" /* relu6 */ \ + "fmax v12.4s, v12.4s, v20.4s \n" /* relu6 */ \ + "fmax v13.4s, v13.4s, v20.4s \n" /* relu6 */ \ + "fmin v16.4s, v16.4s, v21.4s \n" /* relu6 */ \ + "fmin v17.4s, v17.4s, v21.4s \n" /* relu6 */ \ + "fmin v18.4s, v18.4s, v21.4s \n" /* relu6 */ \ + "fmin v19.4s, v19.4s, v21.4s \n" /* relu6 */ \ + "fmin v8.4s, v8.4s, v21.4s \n" /* relu6 */ \ + "fmin v9.4s, v9.4s, v21.4s \n" /* relu6 */ \ + "fmin v12.4s, v12.4s, v21.4s \n" /* relu6 */ \ + "fmin v13.4s, v13.4s, v21.4s \n" /* relu6 */ \ + "b 2f \n" /* relu6 end */ \ + "13: \n" /* leakey relu */ \ + "movi v20.4s, #0 \n" /* for leakey relu */ \ + "ld1 {v21.4s}, [%[alpha]] \n" /* leakey relu alpha */ \ + "fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v11.4s, v16.4s, v21.4s \n" /* vmulq_f32 */ \ + "fcmge v14.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v15.4s, v17.4s, v21.4s \n" /* vmulq_f32 */ \ + "fcmge v22.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v23.4s, v18.4s, v21.4s \n" /* vmulq_f32 */ \ + "fcmge v24.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v25.4s, v19.4s, v21.4s \n" /* vmulq_f32 */ \ + "bif v16.16b, v11.16b, v10.16b \n" /* choose*/ \ + "bif v17.16b, v15.16b, v14.16b \n" /* choose*/ \ + "bif v18.16b, v23.16b, v22.16b \n" /* choose*/ \ + "bif v19.16b, v25.16b, v24.16b \n" /* choose*/ \ + "fcmge v10.4s, v8.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v11.4s, v8.4s, v21.4s \n" /* vmulq_f32 */ \ + "fcmge v14.4s, v9.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v15.4s, v9.4s, v21.4s \n" /* vmulq_f32 */ \ + "fcmge v22.4s, v12.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v23.4s, v12.4s, v21.4s \n" /* vmulq_f32 */ \ + "fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v25.4s, v13.4s, v21.4s \n" /* vmulq_f32 */ \ + "bif v8.16b, v11.16b, v10.16b \n" /* choose*/ \ + "bif v9.16b, v15.16b, v14.16b \n" /* choose*/ \ + "bif v12.16b, v23.16b, v22.16b \n" /* choose*/ \ + "bif v13.16b, v25.16b, v24.16b \n" /* choose*/ \ + "2: \n" /* act end */ #else #define INT32_NCHWC8_TO_NCHW_FP32 \ @@ -3312,18 +3351,68 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT "vswp d5, d12\n" /* q2: b0-b3, q6: d0-d3 */ \ "vswp d3, d10\n" /* q1: e0-e3, q5: g0-g3 */ \ "vswp d7, d14\n" /* q3: f0-f3, q7: h0-h3 */ \ - /* relu */ \ - "vmov.i32 q8, #0\n" \ - "cmp %[relu], #0\n" \ - "beq 2f\n" \ - "vmax.f32 q0, q0, q8\n" /*relu*/ \ - "vmax.f32 q2, q2, q8\n" /*relu*/ \ - "vmax.f32 q4, q4, q8\n" /*relu*/ \ - "vmax.f32 q6, q6, q8\n" /*relu*/ \ - "vmax.f32 q1, q1, q8\n" /*relu*/ \ - "vmax.f32 q3, q3, q8\n" /*relu*/ \ - "vmax.f32 q5, q5, q8\n" /*relu*/ \ - "vmax.f32 q7, q7, q8\n" /*relu*/ \ + /* activation */ \ + "vmov.u32 q8, #0 \n" \ + "cmp %[flag_act], #1 \n" \ + "bne 12f \n" \ + "vmax.f32 q0, q0, q8 \n" /*relu*/ \ + "vmax.f32 q2, q2, q8 \n" /*relu*/ \ + "vmax.f32 q4, q4, q8 \n" /*relu*/ \ + "vmax.f32 q6, q6, q8 \n" /*relu*/ \ + "vmax.f32 q1, q1, q8 \n" /*relu*/ \ + "vmax.f32 q3, q3, q8 \n" /*relu*/ \ + "vmax.f32 q5, q5, q8 \n" /*relu*/ \ + "vmax.f32 q7, q7, q8 \n" /*relu*/ \ + "b 2f \n" \ + "12: \n" \ + "cmp %[flag_act], #0 \n" \ + "beq 2f \n" \ + "cmp %[flag_act], #2 \n" \ + "bne 13f \n" \ + "vld1.f32 {d18-d19}, [%[alpha]] \n" \ + "vmax.f32 q0, q0, q8 \n" \ + "vmax.f32 q2, q2, q8 \n" \ + "vmax.f32 q4, q4, q8 \n" \ + "vmax.f32 q6, q6, q8 \n" \ + "vmax.f32 q1, q1, q8 \n" \ + "vmax.f32 q3, q3, q8 \n" \ + "vmax.f32 q5, q5, q8 \n" \ + "vmax.f32 q7, q7, q8 \n" \ + "vmin.f32 q0, q0, q9 \n" \ + "vmin.f32 q2, q2, q9 \n" \ + "vmin.f32 q4, q4, q9 \n" \ + "vmin.f32 q6, q6, q9 \n" \ + "vmin.f32 q1, q1, q9 \n" \ + "vmin.f32 q3, q3, q9 \n" \ + "vmin.f32 q5, q5, q9 \n" \ + "vmin.f32 q7, q7, q9 \n" \ + "b 2f \n" \ + "13: \n" \ + "vld1.f32 {d18-d19}, [%[alpha]] \n" \ + "vcge.f32 q10, q0, q8 \n" \ + "vmul.f32 q11, q0, q9 \n" \ + "vbif q0, q11, q10 \n" \ + "vcge.f32 q10, q2, q8 \n" \ + "vmul.f32 q11, q2, q9 \n" \ + "vbif q2, q11, q10 \n" \ + "vcge.f32 q10, q4, q8 \n" \ + "vmul.f32 q11, q4, q9 \n" \ + "vbif q4, q11, q10 \n" \ + "vcge.f32 q10, q6, q8 \n" \ + "vmul.f32 q11, q6, q9 \n" \ + "vbif q6, q11, q10 \n" \ + "vcge.f32 q10, q1, q8 \n" \ + "vmul.f32 q11, q1, q9 \n" \ + "vbif q1, q11, q10 \n" \ + "vcge.f32 q10, q3, q8 \n" \ + "vmul.f32 q11, q3, q9 \n" \ + "vbif q3, q11, q10 \n" \ + "vcge.f32 q10, q5, q8 \n" \ + "vmul.f32 q11, q5, q9 \n" \ + "vbif q5, q11, q10 \n" \ + "vcge.f32 q10, q7, q8 \n" \ + "vmul.f32 q11, q7, q9 \n" \ + "vbif q7, q11, q10 \n" \ "2:\n" #endif @@ -3344,7 +3433,9 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT float32x4_t scale1, float32x4_t bias0, float32x4_t bias1, - bool is_relu) { + int flag_act, + float* alpha) { +// clang-format off #ifdef __aarch64__ asm volatile(INT32_NCHWC8_TO_NCHW_FP32 "subs %w[cnt], %w[cnt], #1\n" /* loop count -1*/ @@ -3371,31 +3462,13 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT [scale1] "w"(scale1), [bias0] "w"(bias0), [bias1] "w"(bias1), - [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v31" + ); #else asm volatile(INT32_NCHWC8_TO_NCHW_FP32 "subs %[cnt], #1\n" /* loop count -1*/ @@ -3422,22 +3495,13 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT [scale1] "w"(scale1), [bias0] "w"(bias0), [bias1] "w"(bias1), - [relu] "r"(is_relu) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11" + ); #endif + // clang-format on } template <> @@ -3455,7 +3519,9 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT float32x4_t scale1, float32x4_t bias0, float32x4_t bias1, - bool is_relu) { + int flag_act, + float* alpha) { +// clang-format off #ifdef __aarch64__ float32x4_t vmax = vdupq_n_f32(-127.f); asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* fp32-int32 */ @@ -3529,34 +3595,13 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT [bias0] "w"(bias0), [bias1] "w"(bias1), [vmax] "w"(vmax), - [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v31"); + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v31" + ); #else float vmax[4] = {-127.f, -127.f, -127.f, -127.f}; asm volatile(INT32_NCHWC8_TO_NCHW_FP32 /* set +-0.5 offset */ @@ -3669,175 +3714,13 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT [bias0] "w"(bias0), [bias1] "w"(bias1), [vmax] "r"(vmax), - [relu] "r"(is_relu) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -#endif -} - -template <> -inline void int32_nchwc8_kernel(int32_t*& dout0, // NOLINT - int32_t*& dout1, // NOLINT - int32_t*& dout2, // NOLINT - int32_t*& dout3, // NOLINT - int32_t*& dout4, // NOLINT - int32_t*& dout5, // NOLINT - int32_t*& dout6, // NOLINT - int32_t*& dout7, // NOLINT - const int32_t*& din, // NOLINT - int cnt, - float32x4_t scale0, - float32x4_t scale1, - float32x4_t bias0, - float32x4_t bias1, - bool is_relu) { -#ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "trn1 v8.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn2 v9.4s, v0.4s, v2.4s \n" /* trans q0, q1*/ - "trn1 v10.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "trn2 v11.4s, v1.4s, v3.4s \n" /* trans q2, q3*/ - "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v12.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn2 v13.4s, v4.4s, v6.4s \n" /* trans q0, q1*/ - "trn1 v14.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "trn2 v15.4s, v5.4s, v7.4s \n" /* trans q2, q3*/ - "ldp q2, q3, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "trn1 v16.2d, v8.2d, v12.2d \n" /* trans q8, q10 00 01 02 03*/ - "trn2 v17.2d, v8.2d, v12.2d \n" /* trans q8, q10 20 21 22 23*/ - "trn1 v18.2d, v9.2d, v13.2d \n" /* trans q9, q11 10 11 12 13*/ - "trn2 v19.2d, v9.2d, v13.2d \n" /* trans q9, q11 30 31 32 33*/ - "ldp q4, q5, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ - "trn1 v8.2d, v10.2d, v14.2d \n" /* trans q8, q10 40 41 42 43*/ - "trn2 v9.2d, v10.2d, v14.2d \n" /* trans q8, q10 60 61 62 63*/ - "trn1 v12.2d, v11.2d, v15.2d \n" /* trans q9, q11 50 51 52 53*/ - "trn2 v13.2d, v11.2d, v15.2d \n" /* trans q9, q11 70 71 72 73*/ - "ldp q6, q7, [%[ptr_din]], #32 \n" /* load r02, r03 to q2, q3 */ - "cbz %w[relu], 2f\n" - "smax v16.4s, v16.4s, v20.4s \n" /*relu*/ - "smax v17.4s, v17.4s, v20.4s \n" /*relu*/ - "smax v18.4s, v18.4s, v20.4s \n" /*relu*/ - "smax v19.4s, v19.4s, v20.4s \n" /*relu*/ - "smax v8.4s, v8.4s, v20.4s \n" /*relu*/ - "smax v9.4s, v9.4s, v20.4s \n" /*relu*/ - "smax v12.4s, v12.4s, v20.4s \n" /*relu*/ - "smax v13.4s, v13.4s, v20.4s \n" /*relu*/ - "2:\n" - "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ - "str q18, [%[doutc1r0]], #16 \n" /* store c1r0*/ - "str q19, [%[doutc3r0]], #16 \n" /* store c3r0*/ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q8, [%[doutc4r0]], #16 \n" /* store c0r0*/ - "str q9, [%[doutc6r0]], #16 \n" /* store c2r0*/ - "str q12, [%[doutc5r0]], #16 \n" /* store c1r0*/ - "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ - "bne 1b \n" /* jump to main loop*/ - : [doutc0r0] "+r"(dout0), - [doutc1r0] "+r"(dout1), - [doutc2r0] "+r"(dout2), - [doutc3r0] "+r"(dout3), - [doutc4r0] "+r"(dout4), - [doutc5r0] "+r"(dout5), - [doutc6r0] "+r"(dout6), - [doutc7r0] "+r"(dout7), - [ptr_din] "+r"(din), - [cnt] "+r"(cnt) - : [relu] "r"(is_relu) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20"); -#else - asm volatile( - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "vmov.s32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - "vtrn.32 q0, q2 @ trans q0, q2 \n" - "vtrn.32 q4, q6 @ trans q4, q6 \n" - "vswp.32 d1, d8 @ swap d1, d8 \n" - "vswp.32 d5, d12 @ swap d5, d12\n" - "vtrn.32 q1, q3 @ trans q1, q3 \n" - "vtrn.32 q5, q7 @ trans q5, q7 \n" - "vswp.32 d3, d10 @ swap d3, d10\n" - "vswp.32 d7, d14 @ swap d7, d14\n" - "cmp %[relu], #0\n" - "bne 2f\n" - "vmax.s32 q0, q0, q15 @ relu\n" - "vmax.s32 q1, q1, q15 @ relu\n" - "vmax.s32 q2, q2, q15 @ relu\n" - "vmax.s32 q3, q3, q15 @ relu\n" - "vmax.s32 q4, q4, q15 @ relu\n" - "vmax.s32 q5, q5, q15 @ relu\n" - "vmax.s32 q6, q6, q15 @ relu\n" - "vmax.s32 q7, q7, q15 @ relu\n" - "2:\n" - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" - "vst1.32 {d2-d3}, [%[doutc4r0]]! @ store result, add pointer\n" - "vst1.32 {d4-d5}, [%[doutc1r0]]! @ store result, add pointer\n" - "vst1.32 {d6-d7}, [%[doutc5r0]]! @ store result, add pointer\n" - "vld1.32 {d0-d3}, [%[ptr_din]]! @load data \n" - "vld1.32 {d4-d7}, [%[ptr_din]]! @load data \n" - "vst1.32 {d8-d9}, [%[doutc2r0]]! @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[doutc6r0]]! @ store result, add pointer\n" - "vst1.32 {d12-d13}, [%[doutc3r0]]! @ store result, add pointer\n" - "vst1.32 {d14-d15}, [%[doutc7r0]]! @ store result, add pointer\n" - "vld1.32 {d8-d11}, [%[ptr_din]]! @load data \n" - "vld1.32 {d12-d15}, [%[ptr_din]]! @load data \n" - "bne 1b @ jump to main loop\n" - : [doutc0r0] "+r"(dout0), - [doutc1r0] "+r"(dout1), - [doutc2r0] "+r"(dout2), - [doutc3r0] "+r"(dout3), - [doutc4r0] "+r"(dout4), - [doutc5r0] "+r"(dout5), - [doutc6r0] "+r"(dout6), - [doutc7r0] "+r"(dout7), - [ptr_din] "+r"(din) - : [cnt] "r"(cnt), [relu] "r"(is_relu) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q15"); + [flag_act] "r"(flag_act), + [alpha] "r"(alpha) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11" + ); #endif + // clang-format on } /*wirte result in outputs @@ -3855,7 +3738,8 @@ inline void write_int32_nchwc8_to_nchw(const int* din, int channel, int height, int width, - bool flag_relu, + int flag_act, + float* alpha, float* bias, bool flag_bias, Dtype* trash_ptr, @@ -3931,46 +3815,47 @@ inline void write_int32_nchwc8_to_nchw(const int* din, w_scale1, w_bias0, w_bias1, - flag_relu); + flag_act, + alpha); } if (we > width) { int offset = 32 * cnt; din_hei_ptr = ptr_din + offset; for (int j = ws + cnt * 4; j < width; ++j) { if (flag_bias) { - *(doutc0_ptr++) = - cvt_kernel(din_hei_ptr[0], scale[0], bias[0], flag_relu); - *(doutc1_ptr++) = - cvt_kernel(din_hei_ptr[1], scale[1], bias[1], flag_relu); - *(doutc2_ptr++) = - cvt_kernel(din_hei_ptr[2], scale[2], bias[2], flag_relu); - *(doutc3_ptr++) = - cvt_kernel(din_hei_ptr[3], scale[3], bias[3], flag_relu); - *(doutc4_ptr++) = - cvt_kernel(din_hei_ptr[4], scale[4], bias[4], flag_relu); - *(doutc5_ptr++) = - cvt_kernel(din_hei_ptr[5], scale[5], bias[5], flag_relu); - *(doutc6_ptr++) = - cvt_kernel(din_hei_ptr[6], scale[6], bias[6], flag_relu); - *(doutc7_ptr++) = - cvt_kernel(din_hei_ptr[7], scale[7], bias[7], flag_relu); + *(doutc0_ptr++) = cvt_kernel( + din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]); + *(doutc1_ptr++) = cvt_kernel( + din_hei_ptr[1], scale[1], bias[1], flag_act, alpha[0]); + *(doutc2_ptr++) = cvt_kernel( + din_hei_ptr[2], scale[2], bias[2], flag_act, alpha[0]); + *(doutc3_ptr++) = cvt_kernel( + din_hei_ptr[3], scale[3], bias[3], flag_act, alpha[0]); + *(doutc4_ptr++) = cvt_kernel( + din_hei_ptr[4], scale[4], bias[4], flag_act, alpha[0]); + *(doutc5_ptr++) = cvt_kernel( + din_hei_ptr[5], scale[5], bias[5], flag_act, alpha[0]); + *(doutc6_ptr++) = cvt_kernel( + din_hei_ptr[6], scale[6], bias[6], flag_act, alpha[0]); + *(doutc7_ptr++) = cvt_kernel( + din_hei_ptr[7], scale[7], bias[7], flag_act, alpha[0]); } else { - *(doutc0_ptr++) = - cvt_kernel(din_hei_ptr[0], scale[0], 0.f, flag_relu); - *(doutc1_ptr++) = - cvt_kernel(din_hei_ptr[1], scale[1], 0.f, flag_relu); - *(doutc2_ptr++) = - cvt_kernel(din_hei_ptr[2], scale[2], 0.f, flag_relu); - *(doutc3_ptr++) = - cvt_kernel(din_hei_ptr[3], scale[3], 0.f, flag_relu); - *(doutc4_ptr++) = - cvt_kernel(din_hei_ptr[4], scale[4], 0.f, flag_relu); - *(doutc5_ptr++) = - cvt_kernel(din_hei_ptr[5], scale[5], 0.f, flag_relu); - *(doutc6_ptr++) = - cvt_kernel(din_hei_ptr[6], scale[6], 0.f, flag_relu); - *(doutc7_ptr++) = - cvt_kernel(din_hei_ptr[7], scale[7], 0.f, flag_relu); + *(doutc0_ptr++) = cvt_kernel( + din_hei_ptr[0], scale[0], 0.f, flag_act, alpha[0]); + *(doutc1_ptr++) = cvt_kernel( + din_hei_ptr[1], scale[1], 0.f, flag_act, alpha[0]); + *(doutc2_ptr++) = cvt_kernel( + din_hei_ptr[2], scale[2], 0.f, flag_act, alpha[0]); + *(doutc3_ptr++) = cvt_kernel( + din_hei_ptr[3], scale[3], 0.f, flag_act, alpha[0]); + *(doutc4_ptr++) = cvt_kernel( + din_hei_ptr[4], scale[4], 0.f, flag_act, alpha[0]); + *(doutc5_ptr++) = cvt_kernel( + din_hei_ptr[5], scale[5], 0.f, flag_act, alpha[0]); + *(doutc6_ptr++) = cvt_kernel( + din_hei_ptr[6], scale[6], 0.f, flag_act, alpha[0]); + *(doutc7_ptr++) = cvt_kernel( + din_hei_ptr[7], scale[7], 0.f, flag_act, alpha[0]); } din_hei_ptr += 8; } diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index 72d887ce4e630057286d98c86970def4a9efdb04..c833bc8441ee3267987be9dafad882e0b6e7fd46 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -94,7 +94,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -112,7 +113,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -178,7 +180,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, @@ -196,7 +199,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout, const float* scale, const float* bias, bool flag_bias, - bool flag_relu, + int flag_act, + float* alpha, int num, int chin, int hin, diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index b3d23a9e2f7202e792d4f3a223edc7c0726083c8..7c3f61ba914c26c9348fe328cc592ea1f6796310 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -790,8 +790,30 @@ void conv_depthwise_3x3_int8_fp32(const void* din, int pad_h = paddings[0]; int pad_w = paddings[2]; int stride = param.strides[1]; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } if (stride == 1) { conv_depthwise_3x3s1_int8(reinterpret_cast(dout), reinterpret_cast(din), @@ -799,7 +821,8 @@ void conv_depthwise_3x3_int8_fp32(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -816,7 +839,8 @@ void conv_depthwise_3x3_int8_fp32(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -849,8 +873,30 @@ void conv_depthwise_3x3_int8_int8(const void* din, int pad_h = paddings[0]; int pad_w = paddings[2]; int stride = param.strides[1]; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } if (stride == 1) { conv_depthwise_3x3s1_int8(reinterpret_cast(dout), reinterpret_cast(din), @@ -858,7 +904,8 @@ void conv_depthwise_3x3_int8_int8(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -875,7 +922,8 @@ void conv_depthwise_3x3_int8_int8(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -908,8 +956,30 @@ void conv_depthwise_5x5_int8_fp32(const void* din, int pad_h = paddings[0]; int pad_w = paddings[2]; int stride = param.strides[1]; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } if (stride == 1) { conv_depthwise_5x5s1_int8(reinterpret_cast(dout), reinterpret_cast(din), @@ -917,7 +987,8 @@ void conv_depthwise_5x5_int8_fp32(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -934,7 +1005,8 @@ void conv_depthwise_5x5_int8_fp32(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -967,8 +1039,30 @@ void conv_depthwise_5x5_int8_int8(const void* din, int pad_h = paddings[0]; int pad_w = paddings[2]; int stride = param.strides[1]; - bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; + auto act_param = param.activation_param; + auto act_type = act_param.active_type; + int flag_act = 0; // relu: 1, relu6: 2, leakey: 3 + float alpha[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + if (act_type == lite_api::ActivationType::kRelu) { + flag_act = 1; + } else if (act_type == lite_api::ActivationType::kRelu6) { + flag_act = 2; + float local_alpha = act_param.Relu_clipped_coef; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } else if (act_type == lite_api::ActivationType::kLeakyRelu) { + flag_act = 3; + float local_alpha = act_param.Leaky_relu_alpha; + alpha[0] = local_alpha; + alpha[1] = local_alpha; + alpha[2] = local_alpha; + alpha[3] = local_alpha; + } + } if (stride == 1) { conv_depthwise_5x5s1_int8(reinterpret_cast(dout), reinterpret_cast(din), @@ -976,7 +1070,8 @@ void conv_depthwise_5x5_int8_int8(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, @@ -993,7 +1088,8 @@ void conv_depthwise_5x5_int8_int8(const void* din, scale, bias, flag_bias, - flag_relu, + flag_act, + alpha, num, ch_in, h_in, diff --git a/lite/backends/arm/math/gemm_prepacked_int8.cc b/lite/backends/arm/math/gemm_prepacked_int8.cc index 0642d3c95cc21a024b13d1a62a45baae6db1936d..343e93439d2db563e5ccd4d8c6aed681601871a0 100644 --- a/lite/backends/arm/math/gemm_prepacked_int8.cc +++ b/lite/backends/arm/math/gemm_prepacked_int8.cc @@ -534,18 +534,18 @@ inline void gemm_int8_kernel(const int8_t* a_ptr, "fmin v17.4s, v17.4s, v1.4s\n" /* relu6 */ \ "fmin v18.4s, v18.4s, v1.4s\n" /* relu6 */ \ "fmin v19.4s, v19.4s, v1.4s\n" /* relu6 */ \ - "fmin v20.4s, v20.4s, v0.4s\n" /* relu6 */ \ - "fmin v21.4s, v21.4s, v0.4s\n" /* relu6 */ \ - "fmin v22.4s, v22.4s, v0.4s\n" /* relu6 */ \ - "fmin v23.4s, v23.4s, v0.4s\n" /* relu6 */ \ - "fmin v24.4s, v24.4s, v0.4s\n" /* relu6 */ \ - "fmin v25.4s, v25.4s, v0.4s\n" /* relu6 */ \ - "fmin v26.4s, v26.4s, v0.4s\n" /* relu6 */ \ - "fmin v27.4s, v27.4s, v0.4s\n" /* relu6 */ \ - "fmin v28.4s, v28.4s, v0.4s\n" /* relu6 */ \ - "fmin v29.4s, v29.4s, v0.4s\n" /* relu6 */ \ - "fmin v30.4s, v30.4s, v0.4s\n" /* relu6 */ \ - "fmin v31.4s, v31.4s, v0.4s\n" /* relu6 */ \ + "fmin v20.4s, v20.4s, v1.4s\n" /* relu6 */ \ + "fmin v21.4s, v21.4s, v1.4s\n" /* relu6 */ \ + "fmin v22.4s, v22.4s, v1.4s\n" /* relu6 */ \ + "fmin v23.4s, v23.4s, v1.4s\n" /* relu6 */ \ + "fmin v24.4s, v24.4s, v1.4s\n" /* relu6 */ \ + "fmin v25.4s, v25.4s, v1.4s\n" /* relu6 */ \ + "fmin v26.4s, v26.4s, v1.4s\n" /* relu6 */ \ + "fmin v27.4s, v27.4s, v1.4s\n" /* relu6 */ \ + "fmin v28.4s, v28.4s, v1.4s\n" /* relu6 */ \ + "fmin v29.4s, v29.4s, v1.4s\n" /* relu6 */ \ + "fmin v30.4s, v30.4s, v1.4s\n" /* relu6 */ \ + "fmin v31.4s, v31.4s, v1.4s\n" /* relu6 */ \ "b 9f \n" /* relu end */ #define GEMM_INT8_LEAKY_RELU \ diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index 907a915a37670798632ebe072c8cd0ff207c0a98..e65591b0c8de340e46d3c36b52033f6377e0d10f 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -169,6 +169,12 @@ void DepthwiseConv::PrepareForRun() { } flag_trans_bias_ = true; } + //! update relu6 parameter + if (param.activation_param.has_active && + param.activation_param.active_type == lite_api::ActivationType::kRelu6) { + param.activation_param.Relu_clipped_coef = + param.activation_param.Relu_clipped_coef / param.output_scale; + } /// select dw conv kernel if (kw == 3) { // trans weights diff --git a/lite/kernels/arm/conv_direct.h b/lite/kernels/arm/conv_direct.h index 72b5e4cf815e1f812ca695d4e59b76b28058405e..a4fac01f651e76f4aace334fb8f742e7f4926e28 100644 --- a/lite/kernels/arm/conv_direct.h +++ b/lite/kernels/arm/conv_direct.h @@ -39,7 +39,8 @@ inline bool direct_conv_trans_weights( const std::vector& w_scale, float in_scale, float out_scale, - std::vector& merge_scale) { // NOLINT + std::vector& merge_scale, // NOLINT + float* relu_clipped_coef) { constexpr int cblock = 4; int oc = win->dims()[0]; int ic = win->dims()[1]; @@ -64,7 +65,8 @@ inline bool direct_conv_trans_weights( const std::vector& w_scale, float in_scale, float out_scale, - std::vector& merge_scale) { // NOLINT + std::vector& merge_scale, // NOLINT + float* relu_clipped_coef) { int cblock = 4; if (stride == 2) { cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num(); @@ -103,7 +105,8 @@ inline bool direct_conv_trans_weights( const std::vector& w_scale, float in_scale, float out_scale, - std::vector& merge_scale) { // NOLINT + std::vector& merge_scale, // NOLINT + float* relu_clipped_coef) { int cblock = 4; if (stride == 2) { cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num(); @@ -130,6 +133,8 @@ inline bool direct_conv_trans_weights( merge_scale[i] = w_scale[i] * scale; } } + /// update relu_clipped_coef + *relu_clipped_coef /= out_scale; /// update bias if (bin) { bout->Resize(bin->dims()); @@ -167,16 +172,17 @@ class DirectConv : public KernelLite { << "direct conv only support conv3x3s1 and conv3x3s2"; CHECK(kw == 3 && kh == 3) << "direct conv only support conv3x3s1 and conv3x3s2"; - flag_trans_bias_ = - direct_conv_trans_weights(param.filter, - &weights_, - param.bias, - &bias_, - sw, - param.weight_scale, - param.input_scale, - param.output_scale, - w_scale_); + flag_trans_bias_ = direct_conv_trans_weights( + param.filter, + &weights_, + param.bias, + &bias_, + sw, + param.weight_scale, + param.input_scale, + param.output_scale, + w_scale_, + ¶m.activation_param.Relu_clipped_coef); } virtual void Run(); diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc index 339e1e7325048a28d06dcf4fee93885f0e803d7c..8dac81fe9f08f3e85fab844ce2df0965fbb52289 100644 --- a/lite/tests/math/conv_int8_compute_test.cc +++ b/lite/tests/math/conv_int8_compute_test.cc @@ -56,7 +56,7 @@ DEFINE_int32(dila_w, 1, "dilation width"); DEFINE_bool(flag_act, true, "do act"); DEFINE_bool(flag_bias, true, "with bias"); DEFINE_double(clipped_coef, 1.0, "clipped relu coef"); -DEFINE_double(leakey_relu_alpha, 8.88, "leakey relu alpha"); +DEFINE_double(leakey_relu_alpha, 2.22, "leakey relu alpha"); typedef paddle::lite::DDim DDim; typedef paddle::lite::Tensor Tensor; @@ -188,7 +188,14 @@ void test_conv_int8(const std::vector& input_dims, } std::vector scale_in{1.f / 127}; - std::vector scale_out{weight_dim.count(1, 4) / 127.f}; + std::vector scale_out(1, weight_dim.count(1, 4) / 127.f); + if (flag_act == 2) { + scale_out[0] = six / 127.f; + } else if (flag_act == 4) { + if (std::abs(alpha) > 1) { + scale_out[0] *= std::abs(alpha); + } + } std::vector scale_w(weight_dim[0], 1.f / 127); param_int8_out.input_scale = scale_in[0]; @@ -484,7 +491,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { for (auto& stride : {1, 2}) { for (auto& pad : {0, 1}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_act : {0, 1}) { + for (auto& flag_act : {0, 1, 2, 4}) { for (auto& c : {1, 3, 5, 8, 16, 32}) { std::vector dims; DDim weights_dim({c, 1, 3, 3}); @@ -520,7 +527,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { for (auto& stride : {1, 2}) { for (auto& pad : {0, 1, 2, 3, 4}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_act : {0, 1}) { + for (auto& flag_act : {0, 1, 2, 4}) { for (auto& c : {1, 5, 15, 33}) { std::vector dims; DDim weights_dim({c, 1, 5, 5}); @@ -553,7 +560,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { #if 1 /// conv1x1s1 TEST(TestConv1x1s1Int8, test_conv1x1s1) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 32}) { + for (auto& cin : {1, 3, 8, 33}) { for (auto& cout : {1, 5, 17}) { for (auto& g : {1, 2}) { for (auto& flag_bias : {false, true}) { @@ -599,7 +606,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { for (auto& pad_left : {1, 2}) { for (auto& pad_right : {1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_act : {0, 1}) { + for (auto& flag_act : {0, 1, 2, 4}) { std::vector dims; DDim weights_dim({cout, cin, 3, 3}); for (auto& batch : {1, 2}) { @@ -641,7 +648,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { for (auto& pad_left : {1, 2}) { for (auto& pad_right : {1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_act : {0, 1}) { + for (auto& flag_act : {0, 1, 2, 4}) { std::vector dims; DDim weights_dim({cout, cin, 3, 3}); for (auto& batch : {1, 2}) { @@ -673,7 +680,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { } #endif /// conv3x3s2 -#if 0 /// random param conv +#if 1 /// random param conv TEST(TestConvRandInt8, test_conv_rand) { if (FLAGS_basic_test) { for (auto& cin : {1, 17}) {