diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc index ca27d181842a5f7faaf9497de1f947161279eefb..67d60b18141f64fd4e0048e1a5d1e2c5373c7484 100644 --- a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc +++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @@ -295,7 +295,8 @@ void conv_compute_6x6_3x3(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } else { for (int ci = 0; ci < oc_4; ++ci) { @@ -341,7 +342,8 @@ void conv_compute_6x6_3x3(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } } @@ -562,7 +564,8 @@ void conv_compute_2x2_3x3(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } else { for (int ci = 0; ci < oc_4; ++ci) { @@ -602,7 +605,8 @@ void conv_compute_2x2_3x3(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } } @@ -814,7 +818,8 @@ void conv_compute_2x2_3x3_small(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } else { for (int ci = 0; ci < oc_4; ++ci) { @@ -854,7 +859,8 @@ void conv_compute_2x2_3x3_small(const float* input, hout, wout, false, - zero_ptr); + zero_ptr, + nullptr); } } } diff --git a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc index b4972a1ecab151947f8aaa7d6db0f6e82a08e5e4..5cee02b639af7e04a9184af765a5e96be4cb4cdb 100644 --- a/lite/backends/arm/math/conv3x3s1_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1_direct_fp32.cc @@ -76,6 +76,7 @@ void conv_3x3s1_direct_fp32(const float* i_data, const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); auto paddings = *param.paddings; + auto act_param = param.activation_param; const int pad_h = paddings[0]; const int pad_w = paddings[2]; @@ -469,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } const float* weight_remain_ptr = weights + c_round_down * w_stride; #pragma omp parallel for num_threads(threads) @@ -780,7 +782,8 @@ void conv_3x3s1_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } } } diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc index e4c9fb99ef9a6b5d3987a1efd5a644f322ea043c..6f056677378ad0499e0f2ce8b0dd56cee5d6a6ae 100644 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @@ -32,6 +32,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1p0_bias_s(float *dout, @@ -46,6 +47,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1p1_bias(float *dout, @@ -60,6 +62,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1p1_bias_s(float *dout, @@ -74,6 +77,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx); void conv_depthwise_3x3s1_fp32(const float *din, @@ -90,6 +94,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, int pad, bool flag_bias, 
bool flag_relu, + const operators::ActivationParam act_param, ARMContext *ctx) { if (pad == 0) { if (w_in > 5) { @@ -105,6 +110,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s1p0_bias_s(dout, @@ -119,6 +125,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } } @@ -136,6 +143,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } else { conv_depthwise_3x3s1p1_bias_s(dout, @@ -150,11 +158,12 @@ void conv_depthwise_3x3s1_fp32(const float *din, w_in, h_out, w_out, + act_param, ctx); } } } - +// clang-format on #ifdef __aarch64__ #define INIT_S1 \ "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ @@ -255,14 +264,12 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ \ - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ + "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ /* r4 */ \ + "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ #define LEFT_RESULT_S1 \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ @@ -345,16 +352,15 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ #define MID_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -411,30 +417,31 @@ void conv_depthwise_3x3s1_fp32(const float *din, #define RIGHT_COMPUTE_S1 \ "3: \n" \ + "movi v20.4s, #0 \n" \ "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ "ld1 {v22.4s}, [%[doutr0]] \n" \ "ld1 {v23.4s}, [%[doutr1]] \n" \ "ld1 {v24.4s}, [%[doutr2]] \n" \ "ld1 {v25.4s}, [%[doutr3]] \n" \ \ - "bif v0.16b, %[vzero].16b, v18.16b \n" \ - "bif v1.16b, %[vzero].16b, v19.16b \n" \ - "bif v2.16b, %[vzero].16b, v18.16b \n" \ - "bif v3.16b, %[vzero].16b, v19.16b \n" \ + "bif v0.16b, v20.16b, v18.16b \n" \ + "bif v1.16b, v20.16b, v19.16b \n" \ + "bif v2.16b, v20.16b, v18.16b \n" \ + "bif v3.16b, v20.16b, v19.16b \n" \ \ - "bif v4.16b, 
%[vzero].16b, v18.16b \n" \ - "bif v5.16b, %[vzero].16b, v19.16b \n" \ - "bif v6.16b, %[vzero].16b, v18.16b \n" \ - "bif v7.16b, %[vzero].16b, v19.16b \n" \ + "bif v4.16b, v20.16b, v18.16b \n" \ + "bif v5.16b, v20.16b, v19.16b \n" \ + "bif v6.16b, v20.16b, v18.16b \n" \ + "bif v7.16b, v20.16b, v19.16b \n" \ \ "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ - "bif v8.16b, %[vzero].16b, v18.16b \n" \ - "bif v9.16b, %[vzero].16b, v19.16b \n" \ - "bif v10.16b, %[vzero].16b, v18.16b \n" \ - "bif v11.16b, %[vzero].16b, v19.16b \n" \ + "bif v8.16b, v20.16b, v18.16b \n" \ + "bif v9.16b, v20.16b, v19.16b \n" \ + "bif v10.16b, v20.16b, v18.16b \n" \ + "bif v11.16b, v20.16b, v19.16b \n" \ \ "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ \ @@ -467,15 +474,13 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ + "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ #define RIGHT_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ "bif v12.16b, v22.16b, v18.16b \n" \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -520,10 +525,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v15.4s}, [%[doutr3]], #16 \n" #define LEFT_RESULT_S1_RELU \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ \ @@ -570,14 +571,113 @@ void conv_depthwise_3x3s1_fp32(const float *din, "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "blt 3f \n" +#define LEFT_RESULT_S1_RELU6 \ + "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ 
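/* relu6 epilogue: rows 0-1 (v12, v13) have been clamped to [0, six] via fmax against vzero and fmin against vsix and stored above; rows 2-3 (v14, v15) receive the same clamp below */ \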
+ "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + \ + "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "cmp %w[cnt], #1 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + +#define LEFT_RESULT_S1_LEAKY_RELU \ + "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ + "bif v13.16b, v21.16b, v19.16b \n" /* choose*/ \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ + \ + "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ + "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ + \ + "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ + "cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ + \ + "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + \ + "cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "cmp %w[cnt], #1 \n" \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "blt 3f \n" + #define MID_RESULT_S1_RELU \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, 
[%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "movi v20.4s, #0 \n" \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -598,7 +698,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -617,7 +717,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, /* r3 */ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ \ @@ -633,20 +733,157 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "subs %w[cnt], %w[cnt], #1 \n" \ \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ \ "st1 {v15.4s}, [%[doutr3]], #16 \n" \ "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ \ "bne 1b \n" -#define RIGHT_RESULT_S1_RELU \ +#define MID_RESULT_S1_RELU6 \ + "movi v20.4s, #0 \n" \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ + /* r3 */ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, 
%[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + \ + "bne 1b \n" + +#define MID_RESULT_S1_LEAKY_RELU \ + "movi v21.4s, #0 \n" \ + "cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v13.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" \ + \ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ + \ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ + 
"st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" \ + "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ + "bne 1b \n" + +#define RIGHT_RESULT_S1_RELU \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -664,7 +901,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -680,7 +917,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ \ @@ -690,72 +927,184 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \ \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ \ "bif v15.16b, v25.16b, v18.16b \n" \ \ "st1 {v15.4s}, [%[doutr3]], #16 \n" -#define COMPUTE_S_S1 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s}, [%[din0]], #16\n" \ - "ld1 {v1.4s}, [%[din1]], #16\n" \ - "ld1 {v2.4s}, [%[din2]], #16\n" \ - "ld1 {v3.4s}, [%[din3]], #16\n" \ - \ - "bif v0.16b, %[zero].16b, %[mask].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask].16b\n" \ - "bif v2.16b, %[zero].16b, %[mask].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask].16b\n" \ - \ - "ext v4.16b, %[zero].16b, v0.16b, #12\n" \ - "ext v5.16b, %[zero].16b, v1.16b, #12\n" \ - "ext v6.16b, %[zero].16b, v2.16b, #12\n" \ - "ext v7.16b, %[zero].16b, v3.16b, #12\n" \ - \ - "ext v8.16b, v0.16b, %[zero].16b, #4\n" \ - "ext v9.16b, v1.16b, %[zero].16b, #4\n" \ - "ext v10.16b, v2.16b, %[zero].16b, #4\n" \ - "ext v11.16b, v3.16b, %[zero].16b, #4\n" \ - \ - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ - \ - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ - \ - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ - \ - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ - \ - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ - \ - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ - \ - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ - \ - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v14.4s\n" \ - "fadd v12.4s, v12.4s, v16.4s\n" \ - \ - "fadd v13.4s, v13.4s, v15.4s\n" \ - "fadd v13.4s, v13.4s, v17.4s\n" \ - \ - 
"fadd v12.4s, v12.4s, %[bias].4s\n" \ +#define RIGHT_RESULT_S1_RELU6 \ + "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "bif v12.16b, v22.16b, v18.16b \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v10.4s, v20.s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ + \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + \ + "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ + "bif v15.16b, v25.16b, v18.16b \n" \ + \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define RIGHT_RESULT_S1_LEAKY_RELU \ + "movi v1.4s, #0 \n" \ + "cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + \ + "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v12.16b, v21.16b, v20.16b \n" /* choose*/ \ + \ + "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ + "bif v12.16b, v22.16b, v18.16b \n" \ + "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + "st1 {v12.4s}, [%[doutr0]], #16 \n" \ + \ + "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v13.16b, v21.16b, v20.16b \n" \ + "fmla v15.4s , 
v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ + "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ + \ + "bif v13.16b, v23.16b, v18.16b \n" \ + \ + "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ + \ + "cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ + \ + "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ + \ + "bif v14.16b, v21.16b, v20.16b \n" \ + "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ + \ + "bif v14.16b, v24.16b, v18.16b \n" \ + \ + "cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + \ + "st1 {v14.4s}, [%[doutr2]], #16 \n" \ + "bif v15.16b, v21.16b, v20.16b \n" \ + "bif v15.16b, v25.16b, v18.16b \n" \ + "st1 {v15.4s}, [%[doutr3]], #16 \n" + +#define COMPUTE_S_S1 \ + "prfm pldl1keep, [%[din0]]\n" \ + "prfm pldl1keep, [%[din1]]\n" \ + "prfm pldl1keep, [%[din2]]\n" \ + "prfm pldl1keep, [%[din3]]\n" \ + \ + "ld1 {v0.4s}, [%[din0]], #16\n" \ + "ld1 {v1.4s}, [%[din1]], #16\n" \ + "ld1 {v2.4s}, [%[din2]], #16\n" \ + "ld1 {v3.4s}, [%[din3]], #16\n" \ + \ + "bif v0.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v1.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v2.16b, %[vzero].16b, %[mask].16b\n" \ + "bif v3.16b, %[vzero].16b, %[mask].16b\n" \ + \ + "ext v4.16b, %[vzero].16b, v0.16b, #12\n" \ + "ext v5.16b, %[vzero].16b, v1.16b, #12\n" \ + "ext v6.16b, %[vzero].16b, v2.16b, #12\n" \ + "ext v7.16b, %[vzero].16b, v3.16b, #12\n" \ + \ + "ext v8.16b, v0.16b, %[vzero].16b, #4\n" \ + "ext v9.16b, v1.16b, %[vzero].16b, #4\n" \ + "ext v10.16b, v2.16b, %[vzero].16b, #4\n" \ + "ext v11.16b, v3.16b, %[vzero].16b, #4\n" \ + \ + "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ + "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ + \ + "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ + "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ + \ + "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ + "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ + \ + "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ + "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ + \ + "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ + "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ + \ + "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ + "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ + \ + "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ + "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ + \ + "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ + "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ + \ + "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ + "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ + \ + "fadd v12.4s, v12.4s, v14.4s\n" \ + "fadd v12.4s, v12.4s, v16.4s\n" \ + \ + "fadd v13.4s, v13.4s, v15.4s\n" \ + "fadd v13.4s, v13.4s, v17.4s\n" \ + \ + "fadd v12.4s, v12.4s, %[bias].4s\n" \ "fadd v13.4s, v13.4s, %[bias].4s\n" #define RESULT_S_S1 \ @@ -765,16 +1114,42 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v12.4s}, [%[out1]]\n" \ "st1 {v13.4s}, [%[out2]]\n" -#define RESULT_S_S1_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fmax v12.4s, v12.4s, %[zero].4s\n" \ - "fmax v13.4s, v13.4s, %[zero].4s\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ +#define RESULT_S_S1_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s\n" \ + "fmax v13.4s, v13.4s, %[vzero].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" + +#define RESULT_S_S1_RELU6 \ + 
"prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fmax v12.4s, v12.4s, %[vzero].4s\n" \ + "fmax v13.4s, v13.4s, %[vzero].4s\n" \ + \ + "fmin v12.4s, v12.4s, %[vsix].4s\n" \ + "fmin v13.4s, v13.4s, %[vsix].4s\n" \ + \ + "st1 {v12.4s}, [%[out1]]\n" \ "st1 {v13.4s}, [%[out2]]\n" +#define RESULT_S_S1_LEAKY_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" \ + "bif v13.16b, v21.16b, v19.16b \n" \ + "st1 {v12.4s}, [%[out1]]\n" \ + "st1 {v13.4s}, [%[out2]]\n" #define COMPUTE_S_S1_P0 \ "prfm pldl1keep, [%[din0]]\n" \ "prfm pldl1keep, [%[din1]]\n" \ @@ -786,17 +1161,17 @@ void conv_depthwise_3x3s1_fp32(const float *din, "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ \ - "bif v0.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v0.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v1.16b, %[vzero].16b, %[mask2].16b\n" \ \ - "bif v2.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v2.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v3.16b, %[vzero].16b, %[mask2].16b\n" \ \ - "bif v4.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v5.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v4.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v5.16b, %[vzero].16b, %[mask2].16b\n" \ \ - "bif v6.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v7.16b, %[zero].16b, %[mask2].16b\n" \ + "bif v6.16b, %[vzero].16b, %[mask1].16b\n" \ + "bif v7.16b, %[vzero].16b, %[mask2].16b\n" \ \ "ext v8.16b, v0.16b, v1.16b, #4\n" \ "ext v9.16b, v0.16b, v1.16b, #8\n" \ @@ -849,7 +1224,6 @@ void conv_depthwise_3x3s1_fp32(const float *din, // "st1 {v12.4s}, [%[out1]]\n" \ // "st1 {v13.4s}, [%[out2]]\n" \ - #else #define INIT_S1 \ "pld [%[din0_ptr]] @ preload data\n" \ @@ -1129,6 +1503,66 @@ void conv_depthwise_3x3s1_fp32(const float *din, "vdup.32 q5, %[bias_val] @ and \n" \ "blt 3f @ jump to main loop start point\n" +#define LEFT_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vld1.f32 {d28-d29}, [%[six_ptr]] @ load six \n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + +#define LEFT_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + "vld1.f32 {d28-d29}, [%[scale_ptr]] @ load scale \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ + \ + "vbif q4, q6, q15 @ choose \n" \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vbif q5, q6, q7 @ choose \n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + "cmp %[cnt], #1 @ check whether has mid cols\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + "blt 3f @ jump to main loop start point\n" + #define MID_RESULT_S1_RELU \ /* r3 */ \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ @@ -1157,6 +1591,69 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "bne 1b @ jump to main loop start point\n" +#define MID_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[six_ptr]]! @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + +#define MID_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + \ + "vbif q4, q6, q15 @ choose \n" \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ + \ + "vbif q5, q6, q7 @ choose \n" \ + "vext.32 q6, q8, q9, #1 @ 1234\n" \ + "vext.32 q7, q8, q9, #2 @ 2345\n" \ + "vdup.32 q4, %[bias_val] @ and \n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" \ + \ + "subs %[cnt], #1 @ loop count minus 1\n" \ + \ + "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ + "vdup.32 q5, %[bias_val] @ and \n" \ + \ + "bne 1b @ jump to main loop start point\n" + #define RIGHT_RESULT_S1_RELU \ /* r3 */ \ "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ @@ -1178,6 +1675,58 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" +#define RIGHT_RESULT_S1_RELU6 \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \ + "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vmin.f32 q4, q4, q14 @ relu6 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + \ + "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vmin.f32 q5, q5, q14 @ relu6 \n" \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" + +#define RIGHT_RESULT_S1_LEAKY_RELU \ + /* r3 */ \ + "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ + \ + "vld1.32 {d28-d29}, [%[scale_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ + \ + "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q4, q14 \n" \ + \ + "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ + "vbif q4, q6, q15 @ choose \n" \ + \ + "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q6, q5, q14 \n" \ + \ + "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ + "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ + "vbif q5, q6, q7 @ choose \n" \ + \ + "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ + \ + "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ + "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ + \ + "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" + #define COMPUTE_S_S1 \ "pld [%[din0]]\n" \ "pld [%[din1]]\n" \ @@ -1251,6 +1800,36 @@ void conv_depthwise_3x3s1_fp32(const float *din, "vst1.32 {d28-d29}, [%[out1]]\n" \ "vst1.32 {d30-d31}, [%[out2]]\n" +#define RESULT_S_S1_RELU6 \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vld1.32 {d20-d21}, [%[six_ptr]] \n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + \ + "vmin.f32 q14, q14, q10 \n" \ + "vmin.f32 q15, q15, q10 \n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + +#define RESULT_S_S1_LEAKY_RELU \ + "pld [%[out1]]\n" \ + "pld [%[out2]]\n" \ + \ + "vld1.32 {d18-d19}, [%[scale_ptr]] \n" \ + "vcge.f32 q10, q14, %q[vzero] @ q0 > 0 \n" \ + "vcge.f32 q11, q15, %q[vzero] @ q0 > 0 \n" \ + "vmul.f32 q12, q14, q9 \n" \ + "vmul.f32 q13, q15, q9 \n" \ + \ + "vbif q14, q10, q12 \n" \ + "vbif q15, q11, q13 \n" \ + \ + "vst1.32 {d28-d29}, [%[out1]]\n" \ + "vst1.32 {d30-d31}, [%[out2]]\n" + #define COMPUTE_S_S1_P0 \ "pld [%[din0]]\n" \ "pld [%[din1]]\n" \ @@ -1333,6 +1912,413 @@ void conv_depthwise_3x3s1_fp32(const float *din, "vadd.f32 q15, q5, q9 @ q4 += q10 \n" #endif + +#ifdef __aarch64__ +void act_switch_3x3s1p1(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + const float *din_ptr4, + const float *din_ptr5, + float *doutr0, + float *doutr1, + float *doutr2, + float *doutr3, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask, + unsigned int *rmask, + float32x4_t vzero, + float *vbias, + int cnt, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vsix] "w"(vsix), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + 
"v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU + MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU + RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vscale] "w"(vscale), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 + MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } +} +#else +void act_switch_3x3s1p1(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask_ptr, + unsigned int *rmask_ptr, + float32x4_t vzero, + float bias_val, + int cnt, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 + MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + 
[din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [six_ptr] "r"(vsix), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? din : din * scale*/ + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU + MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU + RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [scale_ptr] "r"(vscale), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 + MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +} +#endif +// clang-format on /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width > 4 @@ -1349,6 +2335,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx) { //! 
pad is done implicit const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; @@ -1486,106 +2473,25 @@ void conv_depthwise_3x3s1p1_bias(float *dout, } int cnt = cnt_col; - if (flag_relu) { - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } else { - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - } + act_switch_3x3s1p1(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + doutr0, + doutr1, + doutr2, + doutr3, + wr0, + wr1, + wr2, + vmask, + rmask, + vzero, + vbias, + cnt, + act_param); dout_ptr = dout_ptr + 4 * w_out; } #else @@ -1598,7 +2504,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout, doutr0 = dout_ptr; doutr1 = dout_ptr + w_out; - // unsigned int* rst_mask = rmask; if (i == 0) { din_ptr0 = zero_ptr; @@ -1635,77 +2540,314 @@ void conv_depthwise_3x3s1p1_bias(float *dout, int cnt = cnt_col; unsigned int *rmask_ptr = rmask; unsigned int *vmask_ptr = vmask; - if (flag_relu) { - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - 
"memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } + act_switch_3x3s1p1(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + doutr0, + doutr1, + wr0, + wr1, + wr2, + vmask_ptr, + rmask_ptr, + vzero, + bias_val, + cnt, + act_param); dout_ptr += 2 * w_out; } //! end of processing mid rows #endif } } } - +void act_switch_3x3s1p1_s(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp, + float32x4_t vzero, + float32x4_t wbias, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); +#else + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [vsix] "w"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [six_ptr] "r"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [vscale] "w"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); + break; +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [scale_ptr] "r"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile(COMPUTE_S_S1 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [mask] "w"(vmask_rp), + [bias] "w"(wbias), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width <= 4 @@ -1722,6 +2864,7 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx) { //! 3x3s1 convolution, implemented by direct algorithm //! 
pad is done implicit @@ -1772,7 +2915,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, if (hs == -1) { dr0 = zero; } - switch (he - h_in) { case 2: dr2 = zero; @@ -1782,127 +2924,19 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, default: break; } -#ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); - } else { - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); - } -#else - if (flag_relu) { - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } -#endif + act_switch_3x3s1p1_s(dr0, + dr1, + dr2, + dr3, + out_buf1, + out_buf2, + wr0, + wr1, + wr2, + vmask_rp, + vzero, + wbias, + act_param); for (int w = 0; w < w_out; ++w) { *doutr0++ = out_buf1[w]; *doutr1++ = out_buf2[w]; @@ -1916,6 +2950,490 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, } // end of processing batchs } +#ifdef __aarch64__ +void act_switch_3x3s1p0(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + const float *din_ptr4, + const float *din_ptr5, + float *doutr0, + float *doutr1, + float *doutr2, + float *doutr3, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask, + unsigned int *rmask, + float32x4_t vzero, + float *vbias, + int cnt, + int remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU + "cmp 
%w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_RELU6 + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU6 "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vsix] "w"(vsix), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_LEAKY_RELU "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [vscale] "w"(vscale), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile( + INIT_S1 + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ + "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ + "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ + "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + MID_COMPUTE_S1 MID_RESULT_S1 + "cmp %w[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + "0: \n" + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr0), + [din_ptr1] "+r"(din_ptr1), + [din_ptr2] "+r"(din_ptr2), + [din_ptr3] "+r"(din_ptr3), + [din_ptr4] "+r"(din_ptr4), + [din_ptr5] "+r"(din_ptr5), + [doutr0] "+r"(doutr0), + [doutr1] "+r"(doutr1), + [doutr2] "+r"(doutr2), + [doutr3] "+r"(doutr3) + : [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [bias_val] "r"(vbias), + [vmask] "r"(vmask), + [rmask] "r"(rmask), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25"); + } +} +#else +void act_switch_3x3s1p0(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + unsigned int *vmask_ptr, + unsigned int *rmask_ptr, + float32x4_t vzero, + float bias_val, + int cnt, + int remain, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 
@ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_RELU6 + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_RELU6 "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [six_ptr] "r"(vsix), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + asm volatile(INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 + MID_RESULT_S1_LEAKY_RELU + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 + RIGHT_RESULT_S1_LEAKY_RELU + "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [scale_ptr] "r"(vscale), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { + asm volatile( + INIT_S1 + "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" + "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" + "vext.32 q6, q8, q9, #1 @ 0012\n" + "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1 + "cmp %[remain], #1 \n" + "blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 + "0: \n" + : [dout_ptr1] "+r"(doutr0), + [dout_ptr2] "+r"(doutr1), + [din0_ptr] "+r"(din_ptr0), + [din1_ptr] "+r"(din_ptr1), + [din2_ptr] "+r"(din_ptr2), + [din3_ptr] "+r"(din_ptr3), + [cnt] "+r"(cnt), + [rmask] "+r"(rmask_ptr), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias_val] "r"(bias_val), + [vzero] "w"(vzero), + [remain] "r"(remain) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } +} +#endif /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width > 4 @@ -1932,6 +3450,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx) { //! 
pad is done implicit const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; @@ -2060,15 +3579,16 @@ void conv_depthwise_3x3s1p0_bias(float *dout, } int cnt = tile_w; + /* if (flag_relu) { asm volatile( INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0) + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0) + "ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234 + "ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345 + "ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0) + "ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0) MID_COMPUTE_S1 MID_RESULT_S1_RELU "cmp %w[remain], #1 \n" "blt 0f \n" RIGHT_COMPUTE_S1 @@ -2123,12 +3643,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout, } else { asm volatile( INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0) + "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0) + "ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234 + "ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345 + "ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0) + "ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0) MID_COMPUTE_S1 MID_RESULT_S1 "cmp %w[remain], #1 \n" "blt 0f \n" RIGHT_COMPUTE_S1 @@ -2181,6 +3701,27 @@ void conv_depthwise_3x3s1p0_bias(float *dout, "v24", "v25"); } + */ + act_switch_3x3s1p0(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + din_ptr4, + din_ptr5, + doutr0, + doutr1, + doutr2, + doutr3, + wr0, + wr1, + wr2, + vmask, + rmask, + vzero, + vbias, + cnt, + remain, + act_param); dout_ptr = dout_ptr + 4 * w_out; } #else @@ -2219,6 +3760,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout, int cnt = tile_w; unsigned int *rmask_ptr = rmask; unsigned int *vmask_ptr = vmask; + /* if (flag_relu) { asm volatile(INIT_S1 "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" @@ -2301,13 +3843,328 @@ void conv_depthwise_3x3s1p0_bias(float *dout, "q13", "q14", "q15"); - } + }*/ + act_switch_3x3s1p0(din_ptr0, + din_ptr1, + din_ptr2, + din_ptr3, + doutr0, + doutr1, + wr0, + wr1, + wr2, + vmask_ptr, + rmask_ptr, + vzero, + bias_val, + cnt, + remain, + act_param); dout_ptr += 2 * w_out; } //! 
end of processing mid rows #endif } } } +void act_switch_3x3s1p0_s(const float *din_ptr0, + const float *din_ptr1, + const float *din_ptr2, + const float *din_ptr3, + float *doutr0, + float *doutr1, + float32x4_t wr0, + float32x4_t wr1, + float32x4_t wr2, + uint32x4_t vmask_rp1, + uint32x4_t vmask_rp2, + float32x4_t vzero, + float32x4_t wbias, + unsigned int *vmask_ptr, + float bias_val, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); + float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); +#else + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [six_ptr] "r"(vsix), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + break; +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [scale_ptr] "r"(vscale), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + break; +#endif + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(din_ptr0), + [din1] "+r"(din_ptr1), + [din2] "+r"(din_ptr2), + [din3] "+r"(din_ptr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(doutr0), + [out2] "r"(doutr1) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} /** * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, * width <= 4 @@ -2324,6 +4181,7 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, const int w_in, const int h_out, const int w_out, + const operators::ActivationParam act_param, ARMContext *ctx) { //! 3x3s1 convolution, implemented by direct algorithm //! 
pad is done implicit @@ -2355,15 +4213,22 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, float32x4_t wr1 = vld1q_f32(weight_ptr + 3); float32x4_t wr2 = vld1q_f32(weight_ptr + 6); -#ifdef __aarch64__ + // #ifdef __aarch64__ + // float32x4_t wbias; + // if (flag_bias) { + // wbias = vdupq_n_f32(bias[i]); + // } else { + // wbias = vdupq_n_f32(0.f); + // } + // #endif // __aarch64__ float32x4_t wbias; + float bias_val = 0.f; if (flag_bias) { wbias = vdupq_n_f32(bias[i]); + bias_val = bias[i]; } else { wbias = vdupq_n_f32(0.f); } -#endif // __aarch64__ - float out_buf1[4]; float out_buf2[4]; float trash_buf[4]; @@ -2396,135 +4261,154 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, break; } } -#ifdef __aarch64__ - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - } -#else + /* + #ifdef __aarch64__ + if (flag_relu) { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } else { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vbias] "w"(wbias), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [vzero] "w"(vzero), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); + } + #else + unsigned int *vmask_ptr = vmask; + float bias_val = flag_bias ? 
bias[i] : 0.f; + if (flag_relu) { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } else { + asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 + : [din0] "+r"(dr0), + [din1] "+r"(dr1), + [din2] "+r"(dr2), + [din3] "+r"(dr3), + [vmask] "+r"(vmask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [vzero] "w"(vzero), + [bias_val] "r"(bias_val), + [out1] "r"(out_buf1), + [out2] "r"(out_buf2) + : "cc", + "memory", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + } + #endif + */ unsigned int *vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - if (flag_relu) { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } else { - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } -#endif + act_switch_3x3s1p0_s(dr0, + dr1, + dr2, + dr3, + out_buf1, + out_buf2, + wr0, + wr1, + wr2, + vmask_rp1, + vmask_rp2, + vzero, + wbias, + vmask_ptr, + bias_val, + act_param); for (int w = 0; w < w_out; ++w) { *doutr0++ = out_buf1[w]; *doutr1++ = out_buf2[w]; diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index 08e5efecd751bcca534ba7a47035c5f70fa1f6bf..fd54e214cf27e001e21efcf255b09113bbe12d19 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -25,6 +25,785 @@ namespace paddle { namespace lite { namespace arm { namespace math { +// clang-format off +#ifdef __aarch64__ +#define COMPUTE \ + "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \ + "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \ + "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ \ + "ldp q4, q5, [%[inr0]]\n" /* load input r0*/ \ + "ldp q10, q11, [%[inr1]]\n" /* load input r1*/ \ + /* r0, r1, mul w0, get out r0, r1 */ \ + "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ \ + "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ \ + "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ \ + "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ \ + "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ \ + "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ \ + "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ \ + "fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ \ + /* r0, r1, mul w1, get out r0, r1 */ 
\ + "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ \ + "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ \ + "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ \ + "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ \ + "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ \ + "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ \ + "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ \ + /* r0, r1, mul w2, get out r0, r1 */ \ + "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ \ + "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \ + "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ \ + "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ \ + "ldp q4, q5, [%[inr2]]\n" /* load input r2*/ \ + "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ \ + "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ \ + "fmla v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ \ + "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ \ + /* r1, r2, mul w3, get out r0, r1 */ \ + "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ \ + "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ \ + "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ \ + "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ \ + "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ \ + "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ \ + "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ \ + "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ \ + /* r1, r2, mul w4, get out r0, r1 */ \ + "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ \ + "ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ \ + "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ \ + "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ \ + "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ \ + "ldp x0, x1, [%[outl]] \n" \ + "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ \ + "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ \ + "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ \ + "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ \ + /* r1, r2, mul w5, get out r0, r1 */ \ + "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ \ + "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ \ + "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ \ + "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ \ + "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ \ + "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ \ + "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ \ + "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ \ + "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ \ + "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ \ + /* r2, r3, mul w6, get out r0, r1 */ \ + "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ \ + "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ \ + "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ \ + "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ \ + "ldp x2, x3, [%[outl], #16] \n" \ + "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ \ + "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ \ + "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ \ + 
"fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ \ + /* r2, r3, mul w7, get out r0, r1 */ \ + "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ \ + "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ \ + "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ \ + "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ \ + "ldp x4, x5, [%[outl], #32] \n" \ + "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ \ + "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ \ + "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ \ + "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ \ + /* r2, r3, mul w8, get out r0, r1 */ \ + "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ \ + "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ \ + "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ \ + "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ \ + "ldp x6, x7, [%[outl], #48] \n" \ + "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ \ + "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ \ + "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ \ + "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * r3[1]*/ \ + \ + "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ \ + "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ \ + /* transpose */ \ + "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ \ + "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ \ + "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ \ + "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ \ + "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ \ + "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ \ + "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ \ + "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ \ + "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ + +#define RELU \ + "movi v0.4s, #0\n" /* for relu */ \ + "ldr x0, [%[outl], #80]\n" \ + "fmax v15.4s, v15.4s, v0.4s\n" \ + "fmax v16.4s, v16.4s, v0.4s\n" \ + "fmax v17.4s, v17.4s, v0.4s\n" \ + "fmax v18.4s, v18.4s, v0.4s\n" \ + "ld1 {v1.4s}, [x0]\n" \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" \ + "ldr x0, [%[outl]]\n" \ + +#define RELU6 \ + "fmin v15.4s, v15.4s, v1.4s\n" \ + "fmin v16.4s, v16.4s, v1.4s\n" \ + "fmin v17.4s, v17.4s, v1.4s\n" \ + "fmin v18.4s, v18.4s, v1.4s\n" \ + "fmin v19.4s, v19.4s, v1.4s\n" \ + "fmin v20.4s, v20.4s, v1.4s\n" \ + "fmin v21.4s, v21.4s, v1.4s\n" \ + "fmin v22.4s, v22.4s, v1.4s\n" + +#define LEAKY_RELU \ + "movi v0.4s, #0\n" /* for relu */ \ + "ldr x0, [%[outl], #88]\n" \ + "cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */ \ + "ld1 {v9.4s}, [x0] \n" \ + "cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v4.4s, v18.4s, v0.4s \n" /* 
vcgeq_u32 */ \ + "ldr x0, [%[outl]] \n" \ + "fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \ + "fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \ + "fmul v7.4s, v17.4s, v9.4s \n" /* mul */ \ + "fmul v8.4s, v18.4s, v9.4s \n" /* mul */ \ + "bif v15.16b, v5.16b, v1.16b \n" /* choose*/ \ + "bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \ + "bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \ + "bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \ + "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ + "cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \ + "fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \ + "fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \ + "fmul v8.4s, v22.4s, v9.4s \n" /* mul */ \ + "bif v19.16b, v5.16b, v1.16b \n" /* choose*/ \ + "bif v20.16b, v6.16b, v2.16b \n" /* choose*/ \ + "bif v21.16b, v7.16b, v3.16b \n" /* choose*/ \ + "bif v22.16b, v8.16b, v4.16b \n" /* choose*/ + +#define STORE \ + "cbnz %w[flag_mask], 1f\n" \ + "str q15, [x0]\n" /* save outc00 */ \ + "str q16, [x4]\n" /* save outc01 */ \ + "str q17, [x1]\n" /* save outc10 */ \ + "str q18, [x5]\n" /* save outc11 */ \ + "str q19, [x2]\n" /* save outc20 */ \ + "str q20, [x6]\n" /* save outc21 */ \ + "str q21, [x3]\n" /* save outc30 */ \ + "str q22, [x7]\n" /* save outc31 */ \ + "b 2f\n" \ + "1:\n" \ + "str q15, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q17, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q19, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q21, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q16, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q18, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q20, [%[out]], #16 \n" /* save remain to pre_out */ \ + "str q22, [%[out]], #16 \n" /* save remain to pre_out */ \ + "2:\n" +#else +#define COMPUTE \ + /* load weights */ \ + "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" \ + /* load r0, r1 */ \ + "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" \ + "vld1.32 {d4-d7}, [%[r0]]! @ load r0, q2, q3\n" \ + /* main loop */ \ + "0: @ main loop\n" \ + /* mul r0 with w0, w1, w2, get out r0 */ \ + "vmul.f32 q8, q5, q0 @ w0 * inr00\n" \ + "vmul.f32 q9, q5, q1 @ w0 * inr01\n" \ + "vmul.f32 q10, q5, q2 @ w0 * inr02\n" \ + "vmul.f32 q11, q5, q3 @ w0 * inr03\n" \ + "vmla.f32 q8, q6, q1 @ w1 * inr01\n" \ + "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" \ + "vmla.f32 q9, q6, q2 @ w1 * inr02\n" \ + "vmla.f32 q10, q6, q3 @ w1 * inr03\n" \ + "vmla.f32 q11, q6, q0 @ w1 * inr04\n" \ + "vmla.f32 q8, q7, q2 @ w2 * inr02\n" \ + "vmla.f32 q9, q7, q3 @ w2 * inr03\n" \ + "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" \ + "vmla.f32 q10, q7, q0 @ w2 * inr04\n" \ + "vmla.f32 q11, q7, q1 @ w2 * inr05\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" \ + "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" \ + /* mul r1 with w0-w5, get out r0, r1 */ \ + "vmul.f32 q12, q5, q2 @ w0 * inr10\n" \ + "vmul.f32 q13, q5, q3 @ w0 * inr11\n" \ + "vmul.f32 q14, q5, q0 @ w0 * inr12\n" \ + "vmul.f32 q15, q5, q1 @ w0 * inr13\n" \ + "vld1.32 {d10-d11}, [%[wc0]]! 
@ load w4 to q5\n" \ + "vmla.f32 q8, q4, q2 @ w3 * inr10\n" \ + "vmla.f32 q9, q4, q3 @ w3 * inr11\n" \ + "vmla.f32 q10, q4, q0 @ w3 * inr12\n" \ + "vmla.f32 q11, q4, q1 @ w3 * inr13\n" \ + /* mul r1 with w1, w4, get out r1, r0 */ \ + "vmla.f32 q8, q5, q3 @ w4 * inr11\n" \ + "vmla.f32 q12, q6, q3 @ w1 * inr11\n" \ + "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" \ + "vmla.f32 q9, q5, q0 @ w4 * inr12\n" \ + "vmla.f32 q13, q6, q0 @ w1 * inr12\n" \ + "vmla.f32 q10, q5, q1 @ w4 * inr13\n" \ + "vmla.f32 q14, q6, q1 @ w1 * inr13\n" \ + "vmla.f32 q11, q5, q2 @ w4 * inr14\n" \ + "vmla.f32 q15, q6, q2 @ w1 * inr14\n" \ + "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" \ + /* mul r1 with w2, w5, get out r1, r0 */ \ + "vmla.f32 q12, q7, q0 @ w2 * inr12\n" \ + "vmla.f32 q13, q7, q1 @ w2 * inr13\n" \ + "vmla.f32 q8, q6, q0 @ w5 * inr12\n" \ + "vmla.f32 q9, q6, q1 @ w5 * inr13\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" \ + "vmla.f32 q14, q7, q2 @ w2 * inr14\n" \ + "vmla.f32 q15, q7, q3 @ w2 * inr15\n" \ + "vmla.f32 q10, q6, q2 @ w5 * inr14\n" \ + "vmla.f32 q11, q6, q3 @ w5 * inr15\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" \ + "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" \ + /* mul r2 with w3-w8, get out r0, r1 */ \ + "vmla.f32 q12, q4, q0 @ w3 * inr20\n" \ + "vmla.f32 q13, q4, q1 @ w3 * inr21\n" \ + "vmla.f32 q14, q4, q2 @ w3 * inr22\n" \ + "vmla.f32 q15, q4, q3 @ w3 * inr23\n" \ + "vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" \ + "vmla.f32 q8, q7, q0 @ w6 * inr20\n" \ + "vmla.f32 q9, q7, q1 @ w6 * inr21\n" \ + "vmla.f32 q10, q7, q2 @ w6 * inr22\n" \ + "vmla.f32 q11, q7, q3 @ w6 * inr23\n" \ + /* mul r2 with w4, w7, get out r1, r0 */ \ + "vmla.f32 q8, q4, q1 @ w7 * inr21\n" \ + "vmla.f32 q12, q5, q1 @ w4 * inr21\n" \ + "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" \ + "vmla.f32 q9, q4, q2 @ w7 * inr22\n" \ + "vmla.f32 q13, q5, q2 @ w4 * inr22\n" \ + "vmla.f32 q10, q4, q3 @ w7 * inr23\n" \ + "vmla.f32 q14, q5, q3 @ w4 * inr23\n" \ + "vmla.f32 q11, q4, q0 @ w7 * inr24\n" \ + "vmla.f32 q15, q5, q0 @ w4 * inr24\n" \ + "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" \ + /* mul r1 with w5, w8, get out r1, r0 */ \ + "vmla.f32 q12, q6, q2 @ w5 * inr22\n" \ + "vmla.f32 q13, q6, q3 @ w5 * inr23\n" \ + "vmla.f32 q8, q5, q2 @ w8 * inr22\n" \ + "vmla.f32 q9, q5, q3 @ w8 * inr23\n" \ + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" \ + "ldr r4, [%[outl], #32] @ load bias addr to r4\n" \ + "vmla.f32 q14, q6, q0 @ w5 * inr24\n" \ + "vmla.f32 q15, q6, q1 @ w5 * inr25\n" \ + "vmla.f32 q10, q5, q0 @ w8 * inr24\n" \ + "vmla.f32 q11, q5, q1 @ w8 * inr25\n" \ + "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, q0, q1\n" \ + "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \ + /* mul r3 with w6, w7, w8, get out r1 */ \ + "vmla.f32 q12, q7, q2 @ w6 * inr30\n" \ + "vmla.f32 q13, q7, q3 @ w6 * inr31\n" \ + "vmla.f32 q14, q7, q0 @ w6 * inr32\n" \ + "vmla.f32 q15, q7, q1 @ w6 * inr33\n" \ + "vmla.f32 q12, q4, q3 @ w7 * inr31\n" \ + "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" \ + "vld1.32 {d12-d13}, [r4] @ load bias\n" \ + "vmla.f32 q13, q4, q0 @ w7 * inr32\n" \ + "vmla.f32 q14, q4, q1 @ w7 * inr33\n" \ + "vmla.f32 q15, q4, q2 @ w7 * inr34\n" \ + "ldr r0, [%[outl]] @ load outc00 to r0\n" \ + "vmla.f32 q12, q5, q0 @ w8 * inr32\n" \ + "vmla.f32 q13, q5, q1 @ w8 * inr33\n" \ + "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" \ + "vmla.f32 q14, q5, q2 @ w8 * inr34\n" \ + "vmla.f32 q15, q5, q3 @ w8 * inr35\n" \ + "ldr r1, [%[outl], #4] @ load outc10 to r1\n" \ + "vadd.f32 q8, q8, q6 @ r00 add bias\n" \ + "vadd.f32 q9, q9, q6 @ r01 add bias\n" \ + "vadd.f32 q10, q10, q6 @ r02 add bias\n" \ + "vadd.f32 q11, q11, q6 @ r03 add bias\n" \ + "ldr r2, [%[outl], #8] @ load outc20 to r2\n" \ + "vadd.f32 q12, q12, q6 @ r10 add bias\n" \ + "vadd.f32 q13, q13, q6 @ r11 add bias\n" \ + "vadd.f32 q14, q14, q6 @ r12 add bias\n" \ + "vadd.f32 q15, q15, q6 @ r13 add bias\n" \ + "ldr r3, [%[outl], #12] @ load outc30 to r3\n" \ + "vmov.u32 q7, #0 @ mov zero to q7\n" +#define RELU \ + "vmax.f32 q8, q8, q7 @ r00 relu\n" \ + "vmax.f32 q9, q9, q7 @ r01 relu\n" \ + "vmax.f32 q10, q10, q7 @ r02 relu\n" \ + "vmax.f32 q11, q11, q7 @ r03 relu\n" \ + "vmax.f32 q12, q12, q7 @ r10 relu\n" \ + "vmax.f32 q13, q13, q7 @ r11 relu\n" \ + "vmax.f32 q14, q14, q7 @ r12 relu\n" \ + "vmax.f32 q15, q15, q7 @ r13 relu\n" + +#define RELU6 \ + "ldr r4, [%[outl], #40] @ load six to r4\n" \ + "vld1.32 {d12-d13}, [r4] @load data \n" \ + "vmin.f32 q8, q8, q6 @ r00 relu\n" \ + "vmin.f32 q9, q9, q6 @ r01 relu\n" \ + "vmin.f32 q10, q10, q6 @ r02 relu\n" \ + "vmin.f32 q11, q11, q6 @ r03 relu\n" \ + "vmin.f32 q12, q12, q6 @ r10 relu\n" \ + "vmin.f32 q13, q13, q6 @ r11 relu\n" \ + "vmin.f32 q14, q14, q6 @ r12 relu\n" \ + "vmin.f32 q15, q15, q6 @ r13 relu\n" + +#define LEAKY_RELU \ + "ldr r4, [%[outl], #44] @ load scale to r4\n" \ + "vld1.32 {d12-d13}, [r4] @load data \n" \ + "vcge.f32 q0, q8, q7 @ q0 > 0 \n" \ + "vcge.f32 q1, q9, q7 @ q0 > 0 \n" \ + "vmul.f32 q4, q8, q6 \n" \ + "vmul.f32 q5, q9, q6 \n" \ + "vcge.f32 q2, q10, q7 @ q0 > 0 \n" \ + "vcge.f32 q3, q11, q7 @ q0 > 0 \n" \ + "vbif q8, q4, q0 @ choose \n" \ + "vbif q9, q5, q1 @ choose \n" \ + "vmul.f32 q4, q10, q6 \n" \ + "vmul.f32 q5, q11, q6 \n" \ + "vbif q10, q4, q2 @ choose \n" \ + "vbif q11, q5, q3 @ choose \n" \ + "vcge.f32 q0, q12, q7 @ q0 > 0 \n" \ + "vcge.f32 q1, q13, q7 @ q0 > 0 \n" \ + "vmul.f32 q4, q12, q6 \n" \ + "vmul.f32 q5, q13, q6 \n" \ + "vcge.f32 q2, q14, q7 @ q0 > 0 \n" \ + "vcge.f32 q3, q15, q7 @ q0 > 0 \n" \ + "vbif q12, q4, q0 @ choose \n" \ + "vbif q13, q5, q1 @ choose \n" \ + "vmul.f32 q4, q14, q6 \n" \ + "vmul.f32 q5, q15, q6 \n" \ + "vbif q14, q4, q2 @ choose \n" \ + "vbif q15, q5, q3 @ choose \n" + +#define STORE \ + "ldr r4, [%[outl], #16] @ load outc01 to r4\n" \ + "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" \ + "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" \ + "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" \ + "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" \ + "ldr r5, [%[outl], #20] @ load outc11 to r5\n" \ + "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" \ + "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: 
d0d1d2d3 \n" \ + "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" \ + "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" \ + "cmp %[flag_mask], #0 @ cmp flag mask\n" \ + "bne 2f\n" \ + "vst1.32 {d16-d17}, [r0] @ save outc00\n" \ + "vst1.32 {d18-d19}, [r1] @ save outc10\n" \ + "vst1.32 {d20-d21}, [r2] @ save outc20\n" \ + "vst1.32 {d22-d23}, [r3] @ save outc30\n" \ + "vst1.32 {d24-d25}, [r4] @ save outc01\n" \ + "vst1.32 {d26-d27}, [r5] @ save outc11\n" \ + "ldr r0, [%[outl], #24] @ load outc21 to r0\n" \ + "ldr r1, [%[outl], #28] @ load outc31 to r1\n" \ + "vst1.32 {d28-d29}, [r0] @ save outc21\n" \ + "vst1.32 {d30-d31}, [r1] @ save outc31\n" \ + "b 3f @ branch end\n" \ + "2: \n" \ + "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" \ + "vst1.32 {d30-d31}, [%[out0]]! @ save remain to pre_out\n" \ + "3: \n" +#endif +// clang-format on +void act_switch_3x3s1(const float* inr0, + const float* inr1, + const float* inr2, + const float* inr3, + float* out0, + const float* weight_c, + float flag_mask, + void* outl_ptr, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t w5, + float32x4_t w6, + float32x4_t w7, + float32x4_t w8, + float32x4_t vbias, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); 
+#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [out] "+r"(out0) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [w5] "w"(w5), + [w6] "w"(w6), + [w7] "w"(w7), + [w8] "w"(w8), + [vbias] "w"(vbias), + [outl] "r"(outl_ptr), + [flag_mask] "r"(flag_mask) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "x0", + "x1", + "x2", + "x3", + "x4", + "x5", + "x6", + "x7"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [out0] "+r"(out0), + [wc0] "+r"(weight_c) + : [flag_mask] "r"(flag_mask), [outl] "r"(outl_ptr) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "r0", + "r1", + "r2", + "r3", + "r4", + "r5"); +#endif + } +} void conv_3x3s1_depthwise_fp32(const float* i_data, float* o_data, int bs, @@ -37,6 +816,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, const float* weights, const float* bias, const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx) { int threads = ctx->threads(); @@ -78,6 +858,31 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, remain = remain > 0 ? 
remain : 0; int row_len = win_round * out_c_block; + float six_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + float scale_ptr[4] = {1.f, 1.f, 1.f, 1.f}; + float relu_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + if (act_param.has_active) { + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: + break; + case lite_api::ActivationType::kRelu6: + six_ptr[0] = act_param.Relu_clipped_coef; + six_ptr[1] = act_param.Relu_clipped_coef; + six_ptr[2] = act_param.Relu_clipped_coef; + six_ptr[3] = act_param.Relu_clipped_coef; + break; + case lite_api::ActivationType::kLeakyRelu: + scale_ptr[0] = act_param.Leaky_relu_alpha; + scale_ptr[1] = act_param.Leaky_relu_alpha; + scale_ptr[2] = act_param.Leaky_relu_alpha; + scale_ptr[3] = act_param.Leaky_relu_alpha; + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } for (int n = 0; n < bs; ++n) { const float* din_batch = i_data + n * ic * size_in_channel; float* dout_batch = o_data + n * oc * size_out_channel; @@ -147,6 +952,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, outc21 = ptr_write; outc31 = ptr_write; } + float* outl[] = {outc00, outc10, outc20, @@ -156,361 +962,54 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, outc21, outc31, reinterpret_cast(bias_local), - reinterpret_cast(flag_relu)}; + reinterpret_cast(relu_ptr), + reinterpret_cast(six_ptr), + reinterpret_cast(scale_ptr)}; void* outl_ptr = reinterpret_cast(outl); for (int w = 0; w < w_loop; ++w) { bool flag_mask = (w == w_loop - 1) && flag_remain; float* out0 = pre_out; -// clang-format off #ifdef __aarch64__ - asm volatile( - "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ - "ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ - "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ - "ldp q8, q9, [%[inr1]], #32\n" /* load input r1*/ - "ldp q4, q5, [%[inr0]]\n" /* load input r0*/ - "ldp q10, q11, [%[inr1]]\n" /* load input r1*/ - /* r0, r1, mul w0, get out r0, r1 */ - "fmul v15.4s , %[w0].4s, v0.4s\n" /* outr00 = w0 * r0, 0*/ - "fmul v16.4s , %[w0].4s, v1.4s\n" /* outr01 = w0 * r0, 1*/ - "fmul v17.4s , %[w0].4s, v2.4s\n" /* outr02 = w0 * r0, 2*/ - "fmul v18.4s , %[w0].4s, v3.4s\n" /* outr03 = w0 * r0, 3*/ - "fmul v19.4s , %[w0].4s, v6.4s\n" /* outr10 = w0 * r1, 0*/ - "fmul v20.4s , %[w0].4s, v7.4s\n" /* outr11 = w0 * r1, 1*/ - "fmul v21.4s , %[w0].4s, v8.4s\n" /* outr12 = w0 * r1, 2*/ - "fmul v22.4s , %[w0].4s, v9.4s\n" /* outr13 = w0 * r1, 3*/ - /* r0, r1, mul w1, get out r0, r1 */ - "fmla v15.4s , %[w1].4s, v1.4s\n" /* outr00 = w1 * r0[1]*/ - "ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ - "fmla v16.4s , %[w1].4s, v2.4s\n" /* outr01 = w1 * r0[2]*/ - "fmla v17.4s , %[w1].4s, v3.4s\n" /* outr02 = w1 * r0[3]*/ - "fmla v18.4s , %[w1].4s, v4.4s\n" /* outr03 = w1 * r0[4]*/ - "fmla v19.4s , %[w1].4s, v7.4s\n" /* outr10 = w1 * r1[1]*/ - "fmla v20.4s , %[w1].4s, v8.4s\n" /* outr11 = w1 * r1[2]*/ - "fmla v21.4s , %[w1].4s, v9.4s\n" /* outr12 = w1 * r1[3]*/ - "fmla v22.4s , %[w1].4s, v10.4s\n"/* outr13 = w1 * r1[4]*/ - /* r0, r1, mul w2, get out r0, r1 */ - "fmla v15.4s , %[w2].4s, v2.4s\n" /* outr00 = w2 * r0[2]*/ - "fmla v16.4s , %[w2].4s, v3.4s\n" /* outr01 = w2 * r0[3]*/ - "ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ - "fmla v17.4s , %[w2].4s, v4.4s\n" /* outr02 = w2 * r0[4]*/ - "fmla v18.4s , %[w2].4s, v5.4s\n" /* outr03 = w2 * r0[5]*/ - "ldp q4, q5, [%[inr2]]\n" /* load input r2*/ - "fmla v19.4s , %[w2].4s, v8.4s\n" /* outr10 = w2 * r1[2]*/ - "fmla v20.4s , %[w2].4s, v9.4s\n" /* outr11 = w2 * r1[3]*/ - "fmla 
v21.4s , %[w2].4s, v10.4s\n"/* outr12 = w2 * r1[4]*/ - "fmla v22.4s , %[w2].4s, v11.4s\n"/* outr13 = w2 * r1[5]*/ - /* r1, r2, mul w3, get out r0, r1 */ - "fmla v15.4s , %[w3].4s, v6.4s\n" /* outr00 = w3 * r1[0]*/ - "fmla v16.4s , %[w3].4s, v7.4s\n" /* outr01 = w3 * r1[1]*/ - "fmla v17.4s , %[w3].4s, v8.4s\n" /* outr02 = w3 * r1[2]*/ - "fmla v18.4s , %[w3].4s, v9.4s\n" /* outr03 = w3 * r1[3]*/ - "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr10 = w3 * r2[0]*/ - "fmla v20.4s , %[w3].4s, v1.4s\n" /* outr11 = w3 * r2[1]*/ - "fmla v21.4s , %[w3].4s, v2.4s\n" /* outr12 = w3 * r2[2]*/ - "fmla v22.4s , %[w3].4s, v3.4s\n" /* outr13 = w3 * r2[3]*/ - /* r1, r2, mul w4, get out r0, r1 */ - "fmla v15.4s , %[w4].4s, v7.4s\n" /* outr00 = w4 * r1[1]*/ - "ldp q6, q7, [%[inr3]], #32\n" /* load input r3*/ - "fmla v16.4s , %[w4].4s, v8.4s\n" /* outr01 = w4 * r1[2]*/ - "fmla v17.4s , %[w4].4s, v9.4s\n" /* outr02 = w4 * r1[3]*/ - "fmla v18.4s , %[w4].4s, v10.4s\n"/* outr03 = w4 * r1[4]*/ - "ldp x0, x1, [%[outl]] \n" - "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr10 = w4 * r2[1]*/ - "fmla v20.4s , %[w4].4s, v2.4s\n" /* outr11 = w4 * r2[2]*/ - "fmla v21.4s , %[w4].4s, v3.4s\n" /* outr12 = w4 * r2[3]*/ - "fmla v22.4s , %[w4].4s, v4.4s\n" /* outr13 = w4 * r2[4]*/ - /* r1, r2, mul w5, get out r0, r1 */ - "fmla v15.4s , %[w5].4s, v8.4s\n" /* outr00 = w5 * r1[2]*/ - "fmla v16.4s , %[w5].4s, v9.4s\n" /* outr01 = w5 * r1[3]*/ - "ldp q8, q9, [%[inr3]], #32\n" /* load input r3*/ - "fmla v17.4s , %[w5].4s, v10.4s\n"/* outr02 = w5 * r1[4]*/ - "fmla v18.4s , %[w5].4s, v11.4s\n"/* outr03 = w5 * r1[5]*/ - "ldp q10, q11, [%[inr3]]\n" /* load input r3*/ - "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr10 = w5 * r2[2]*/ - "fmla v20.4s , %[w5].4s, v3.4s\n" /* outr11 = w5 * r2[3]*/ - "fmla v21.4s , %[w5].4s, v4.4s\n" /* outr12 = w5 * r2[4]*/ - "fmla v22.4s , %[w5].4s, v5.4s\n" /* outr13 = w5 * r2[5]*/ - /* r2, r3, mul w6, get out r0, r1 */ - "fmla v15.4s , %[w6].4s, v0.4s\n" /* outr00 = w6 * r2[0]*/ - "fmla v16.4s , %[w6].4s, v1.4s\n" /* outr01 = w6 * r2[1]*/ - "fmla v17.4s , %[w6].4s, v2.4s\n" /* outr02 = w6 * r2[2]*/ - "fmla v18.4s , %[w6].4s, v3.4s\n" /* outr03 = w6 * r2[3]*/ - "ldp x2, x3, [%[outl], #16] \n" - "fmla v19.4s , %[w6].4s, v6.4s\n" /* outr10 = w6 * r3[0]*/ - "fmla v20.4s , %[w6].4s, v7.4s\n" /* outr11 = w6 * r3[1]*/ - "fmla v21.4s , %[w6].4s, v8.4s\n" /* outr12 = w6 * r3[2]*/ - "fmla v22.4s , %[w6].4s, v9.4s\n" /* outr13 = w6 * r3[3]*/ - /* r2, r3, mul w7, get out r0, r1 */ - "fmla v15.4s , %[w7].4s, v1.4s\n" /* outr00 = w7 * r2[1]*/ - "fmla v16.4s , %[w7].4s, v2.4s\n" /* outr01 = w7 * r2[2]*/ - "fmla v17.4s , %[w7].4s, v3.4s\n" /* outr02 = w7 * r2[3]*/ - "fmla v18.4s , %[w7].4s, v4.4s\n" /* outr03 = w7 * r2[4]*/ - "ldp x4, x5, [%[outl], #32] \n" - "fmla v19.4s , %[w7].4s, v7.4s\n" /* outr10 = w7 * r3[1]*/ - "fmla v20.4s , %[w7].4s, v8.4s\n" /* outr11 = w7 * r3[2]*/ - "fmla v21.4s , %[w7].4s, v9.4s\n" /* outr12 = w7 * r3[3]*/ - "fmla v22.4s , %[w7].4s, v10.4s\n"/* outr13 = w7 * r3[4]*/ - /* r2, r3, mul w8, get out r0, r1 */ - "fmla v15.4s , %[w8].4s, v2.4s\n" /* outr00 = w8 * r2[2]*/ - "fmla v16.4s , %[w8].4s, v3.4s\n" /* outr01 = w8 * r2[3]*/ - "fmla v17.4s , %[w8].4s, v4.4s\n" /* outr02 = w8 * r2[0]*/ - "fmla v18.4s , %[w8].4s, v5.4s\n" /* outr03 = w8 * r2[1]*/ - "ldp x6, x7, [%[outl], #48] \n" - "fmla v19.4s , %[w8].4s, v8.4s\n" /* outr10 = w8 * r3[2]*/ - "fmla v20.4s , %[w8].4s, v9.4s\n" /* outr11 = w8 * r3[3]*/ - "fmla v21.4s , %[w8].4s, v10.4s\n"/* outr12 = w8 * r3[0]*/ - "fmla v22.4s , %[w8].4s, v11.4s\n"/* outr13 = w8 * 
r3[1]*/ - - "fadd v15.4s, v15.4s, %[vbias].4s\n"/* add bias */ - "fadd v16.4s, v16.4s, %[vbias].4s\n"/* add bias */ - "fadd v17.4s, v17.4s, %[vbias].4s\n"/* add bias */ - "fadd v18.4s, v18.4s, %[vbias].4s\n"/* add bias */ - "fadd v19.4s, v19.4s, %[vbias].4s\n"/* add bias */ - "fadd v20.4s, v20.4s, %[vbias].4s\n"/* add bias */ - "fadd v21.4s, v21.4s, %[vbias].4s\n"/* add bias */ - "fadd v22.4s, v22.4s, %[vbias].4s\n"/* add bias */ - - /* transpose */ - "trn1 v0.4s, v15.4s, v16.4s\n" /* r0: a0a1c0c1*/ - "trn2 v1.4s, v15.4s, v16.4s\n" /* r0: b0b1d0d1*/ - "trn1 v2.4s, v17.4s, v18.4s\n" /* r0: a2a3c2c3*/ - "trn2 v3.4s, v17.4s, v18.4s\n" /* r0: b2b3d2d3*/ - "trn1 v4.4s, v19.4s, v20.4s\n" /* r1: a0a1c0c1*/ - "trn2 v5.4s, v19.4s, v20.4s\n" /* r1: b0b1d0d1*/ - "trn1 v6.4s, v21.4s, v22.4s\n" /* r1: a2a3c2c3*/ - "trn2 v7.4s, v21.4s, v22.4s\n" /* r1: b2b3d2d3*/ - "trn1 v15.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ - "trn2 v19.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ - "trn1 v17.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ - "trn2 v21.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ - "trn1 v16.2d, v4.2d, v6.2d\n" /* r1: a0a1a2a3*/ - "trn2 v20.2d, v4.2d, v6.2d\n" /* r1: c0c1c2c3*/ - "trn1 v18.2d, v5.2d, v7.2d\n" /* r1: b0b1b2b3*/ - "trn2 v22.2d, v5.2d, v7.2d\n" /* r1: d0d1d2d3*/ - - "cbz %w[flag_relu], 0f\n" /* skip relu*/ - "movi v0.4s, #0\n" /* for relu */ - "fmax v15.4s, v15.4s, v0.4s\n" - "fmax v16.4s, v16.4s, v0.4s\n" - "fmax v17.4s, v17.4s, v0.4s\n" - "fmax v18.4s, v18.4s, v0.4s\n" - "fmax v19.4s, v19.4s, v0.4s\n" - "fmax v20.4s, v20.4s, v0.4s\n" - "fmax v21.4s, v21.4s, v0.4s\n" - "fmax v22.4s, v22.4s, v0.4s\n" - "0:\n" - "cbnz %w[flag_mask], 1f\n" - "str q15, [x0]\n" /* save outc00 */ - "str q16, [x4]\n" /* save outc01 */ - "str q17, [x1]\n" /* save outc10 */ - "str q18, [x5]\n" /* save outc11 */ - "str q19, [x2]\n" /* save outc20 */ - "str q20, [x6]\n" /* save outc21 */ - "str q21, [x3]\n" /* save outc30 */ - "str q22, [x7]\n" /* save outc31 */ - "b 2f\n" - "1:\n" - "str q15, [%[out]], #16 \n" /* save remain to pre_out */ - "str q17, [%[out]], #16 \n" /* save remain to pre_out */ - "str q19, [%[out]], #16 \n" /* save remain to pre_out */ - "str q21, [%[out]], #16 \n" /* save remain to pre_out */ - "str q16, [%[out]], #16 \n" /* save remain to pre_out */ - "str q18, [%[out]], #16 \n" /* save remain to pre_out */ - "str q20, [%[out]], #16 \n" /* save remain to pre_out */ - "str q22, [%[out]], #16 \n" /* save remain to pre_out */ - "2:\n" - :[inr0] "+r"(inr0), [inr1] "+r"(inr1), - [inr2] "+r"(inr2), [inr3] "+r"(inr3), - [out]"+r"(out0) - :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2), - [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5), - [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8), - [vbias]"w" (vbias), [outl] "r" (outl_ptr), - [flag_mask] "r" (flag_mask), [flag_relu] "r" (flag_relu) - : "cc", "memory", - "v0","v1","v2","v3","v4","v5","v6","v7", - "v8", "v9", "v10", "v11", "v15", - "v16","v17","v18","v19","v20","v21","v22", - "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7" - ); + act_switch_3x3s1(inr0, + inr1, + inr2, + inr3, + out0, + weight_c, + flag_mask, + outl_ptr, + w0, + w1, + w2, + w3, + w4, + w5, + w6, + w7, + w8, + vbias, + act_param); #else - asm volatile( - /* load weights */ - "vld1.32 {d10-d13}, [%[wc0]]! @ load w0, w1, to q5, q6\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w2, to q7\n" - /* load r0, r1 */ - "vld1.32 {d0-d3}, [%[r0]]! @ load r0, q0, q1\n" - "vld1.32 {d4-d7}, [%[r0]]! 
@ load r0, q2, q3\n" - /* main loop */ - "0: @ main loop\n" - /* mul r0 with w0, w1, w2, get out r0 */ - "vmul.f32 q8, q5, q0 @ w0 * inr00\n" - "vmul.f32 q9, q5, q1 @ w0 * inr01\n" - "vmul.f32 q10, q5, q2 @ w0 * inr02\n" - "vmul.f32 q11, q5, q3 @ w0 * inr03\n" - "vmla.f32 q8, q6, q1 @ w1 * inr01\n" - "vld1.32 {d0-d3}, [%[r0]] @ load r0, q0, q1\n" - "vmla.f32 q9, q6, q2 @ w1 * inr02\n" - "vmla.f32 q10, q6, q3 @ w1 * inr03\n" - "vmla.f32 q11, q6, q0 @ w1 * inr04\n" - "vmla.f32 q8, q7, q2 @ w2 * inr02\n" - "vmla.f32 q9, q7, q3 @ w2 * inr03\n" - "vld1.32 {d4-d7}, [%[r1]]! @ load r0, q2, q3\n" - "vmla.f32 q10, q7, q0 @ w2 * inr04\n" - "vmla.f32 q11, q7, q1 @ w2 * inr05\n" - "vld1.32 {d0-d3}, [%[r1]]! @ load r0, q0, q1\n" - "vld1.32 {d8-d9}, [%[wc0]]! @ load w3 to q4\n" - /* mul r1 with w0-w5, get out r0, r1 */ - "vmul.f32 q12, q5, q2 @ w0 * inr10\n" - "vmul.f32 q13, q5, q3 @ w0 * inr11\n" - "vmul.f32 q14, q5, q0 @ w0 * inr12\n" - "vmul.f32 q15, q5, q1 @ w0 * inr13\n" - "vld1.32 {d10-d11}, [%[wc0]]! @ load w4 to q5\n" - "vmla.f32 q8, q4, q2 @ w3 * inr10\n" - "vmla.f32 q9, q4, q3 @ w3 * inr11\n" - "vmla.f32 q10, q4, q0 @ w3 * inr12\n" - "vmla.f32 q11, q4, q1 @ w3 * inr13\n" - /* mul r1 with w1, w4, get out r1, r0 */ - "vmla.f32 q8, q5, q3 @ w4 * inr11\n" - "vmla.f32 q12, q6, q3 @ w1 * inr11\n" - "vld1.32 {d4-d7}, [%[r1]] @ load r1, q2, q3\n" - "vmla.f32 q9, q5, q0 @ w4 * inr12\n" - "vmla.f32 q13, q6, q0 @ w1 * inr12\n" - "vmla.f32 q10, q5, q1 @ w4 * inr13\n" - "vmla.f32 q14, q6, q1 @ w1 * inr13\n" - "vmla.f32 q11, q5, q2 @ w4 * inr14\n" - "vmla.f32 q15, q6, q2 @ w1 * inr14\n" - "vld1.32 {d12-d13}, [%[wc0]]! @ load w5 to q6\n" - /* mul r1 with w2, w5, get out r1, r0 */ - "vmla.f32 q12, q7, q0 @ w2 * inr12\n" - "vmla.f32 q13, q7, q1 @ w2 * inr13\n" - "vmla.f32 q8, q6, q0 @ w5 * inr12\n" - "vmla.f32 q9, q6, q1 @ w5 * inr13\n" - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, q0, q1\n" - "vmla.f32 q14, q7, q2 @ w2 * inr14\n" - "vmla.f32 q15, q7, q3 @ w2 * inr15\n" - "vmla.f32 q10, q6, q2 @ w5 * inr14\n" - "vmla.f32 q11, q6, q3 @ w5 * inr15\n" - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, q0, q1\n" - "vld1.32 {d14-d15}, [%[wc0]]! @ load w6, to q7\n" - /* mul r2 with w3-w8, get out r0, r1 */ - "vmla.f32 q12, q4, q0 @ w3 * inr20\n" - "vmla.f32 q13, q4, q1 @ w3 * inr21\n" - "vmla.f32 q14, q4, q2 @ w3 * inr22\n" - "vmla.f32 q15, q4, q3 @ w3 * inr23\n" - "vld1.32 {d8-d9}, [%[wc0]]! @ load w7, to q4\n" - "vmla.f32 q8, q7, q0 @ w6 * inr20\n" - "vmla.f32 q9, q7, q1 @ w6 * inr21\n" - "vmla.f32 q10, q7, q2 @ w6 * inr22\n" - "vmla.f32 q11, q7, q3 @ w6 * inr23\n" - /* mul r2 with w4, w7, get out r1, r0 */ - "vmla.f32 q8, q4, q1 @ w7 * inr21\n" - "vmla.f32 q12, q5, q1 @ w4 * inr21\n" - "vld1.32 {d0-d3}, [%[r2]] @ load r2, q0, q1\n" - "vmla.f32 q9, q4, q2 @ w7 * inr22\n" - "vmla.f32 q13, q5, q2 @ w4 * inr22\n" - "vmla.f32 q10, q4, q3 @ w7 * inr23\n" - "vmla.f32 q14, q5, q3 @ w4 * inr23\n" - "vmla.f32 q11, q4, q0 @ w7 * inr24\n" - "vmla.f32 q15, q5, q0 @ w4 * inr24\n" - "vld1.32 {d10-d11}, [%[wc0]]! @ load w8 to q5\n" - /* mul r1 with w5, w8, get out r1, r0 */ - "vmla.f32 q12, q6, q2 @ w5 * inr22\n" - "vmla.f32 q13, q6, q3 @ w5 * inr23\n" - "vmla.f32 q8, q5, q2 @ w8 * inr22\n" - "vmla.f32 q9, q5, q3 @ w8 * inr23\n" - "vld1.32 {d4-d7}, [%[r3]]! @ load r3, q2, q3\n" - "ldr r4, [%[outl], #32] @ load bias addr to r4\n" - "vmla.f32 q14, q6, q0 @ w5 * inr24\n" - "vmla.f32 q15, q6, q1 @ w5 * inr25\n" - "vmla.f32 q10, q5, q0 @ w8 * inr24\n" - "vmla.f32 q11, q5, q1 @ w8 * inr25\n" - "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, q0, q1\n" - "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" - /* mul r3 with w6, w7, w8, get out r1 */ - "vmla.f32 q12, q7, q2 @ w6 * inr30\n" - "vmla.f32 q13, q7, q3 @ w6 * inr31\n" - "vmla.f32 q14, q7, q0 @ w6 * inr32\n" - "vmla.f32 q15, q7, q1 @ w6 * inr33\n" - "vmla.f32 q12, q4, q3 @ w7 * inr31\n" - "vld1.32 {d4-d7}, [%[r3]] @ load r3, q2, q3\n" - "vld1.32 {d12-d13}, [r4] @ load bias\n" - "vmla.f32 q13, q4, q0 @ w7 * inr32\n" - "vmla.f32 q14, q4, q1 @ w7 * inr33\n" - "vmla.f32 q15, q4, q2 @ w7 * inr34\n" - "ldr r0, [%[outl]] @ load outc00 to r0\n" - "vmla.f32 q12, q5, q0 @ w8 * inr32\n" - "vmla.f32 q13, q5, q1 @ w8 * inr33\n" - "ldr r5, [%[outl], #36] @ load flag_relu to r5\n" - "vmla.f32 q14, q5, q2 @ w8 * inr34\n" - "vmla.f32 q15, q5, q3 @ w8 * inr35\n" - "ldr r1, [%[outl], #4] @ load outc10 to r1\n" - "vadd.f32 q8, q8, q6 @ r00 add bias\n" - "vadd.f32 q9, q9, q6 @ r01 add bias\n" - "vadd.f32 q10, q10, q6 @ r02 add bias\n" - "vadd.f32 q11, q11, q6 @ r03 add bias\n" - "ldr r2, [%[outl], #8] @ load outc20 to r2\n" - "vadd.f32 q12, q12, q6 @ r10 add bias\n" - "vadd.f32 q13, q13, q6 @ r11 add bias\n" - "vadd.f32 q14, q14, q6 @ r12 add bias\n" - "vadd.f32 q15, q15, q6 @ r13 add bias\n" - "ldr r3, [%[outl], #12] @ load outc30 to r3\n" - "vmov.u32 q7, #0 @ mov zero to q7\n" - "cmp r5, #0 @ cmp flag relu\n" - "beq 1f @ skip relu\n" - "vmax.f32 q8, q8, q7 @ r00 relu\n" - "vmax.f32 q9, q9, q7 @ r01 relu\n" - "vmax.f32 q10, q10, q7 @ r02 relu\n" - "vmax.f32 q11, q11, q7 @ r03 relu\n" - "vmax.f32 q12, q12, q7 @ r10 relu\n" - "vmax.f32 q13, q13, q7 @ r11 relu\n" - "vmax.f32 q14, q14, q7 @ r12 relu\n" - "vmax.f32 q15, q15, q7 @ r13 relu\n" - "1:\n" - "ldr r4, [%[outl], #16] @ load outc01 to r4\n" - "vtrn.32 q8, q9 @ r0: q8 : a0a1c0c1, q9 : b0b1d0d1\n" - "vtrn.32 q10, q11 @ r0: q10: a2a3c2c3, q11: b2b3d2d3\n" - "vtrn.32 q12, q13 @ r1: q12: a0a1c0c1, q13: b0b1d0d1\n" - "vtrn.32 q14, q15 @ r1: q14: a2a3c2c3, q15: b2b3d2d3\n" - "ldr r5, [%[outl], #20] @ load outc11 to r5\n" - "vswp d17, d20 @ r0: q8 : a0a1a2a3, q10: c0c1c2c3 \n" - "vswp d19, d22 @ r0: q9 : b0b1b2b3, q11: d0d1d2d3 \n" - "vswp d25, d28 @ r1: q12: a0a1a2a3, q14: c0c1c2c3 \n" - "vswp d27, d30 @ r1: q13: b0b1b2b3, q15: d0d1d2d3 \n" - "cmp %[flag_mask], #0 @ cmp flag mask\n" - "bne 2f\n" - "vst1.32 {d16-d17}, [r0] @ save outc00\n" - "vst1.32 {d18-d19}, [r1] @ save outc10\n" - "vst1.32 {d20-d21}, [r2] @ save outc20\n" - "vst1.32 {d22-d23}, [r3] @ save outc30\n" - "vst1.32 {d24-d25}, [r4] @ save outc01\n" - "vst1.32 {d26-d27}, [r5] @ save outc11\n" - "ldr r0, [%[outl], #24] @ load outc21 to r0\n" - "ldr r1, [%[outl], #28] @ load outc31 to r1\n" - "vst1.32 {d28-d29}, [r0] @ save outc21\n" - "vst1.32 {d30-d31}, [r1] @ save outc31\n" - "b 3f @ branch end\n" - "2: \n" - "vst1.32 {d16-d17}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d18-d19}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d20-d21}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d22-d23}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d24-d25}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d26-d27}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d28-d29}, [%[out0]]! @ save remain to pre_out\n" - "vst1.32 {d30-d31}, [%[out0]]! 
@ save remain to pre_out\n" - "3: \n" - : [r0] "+r"(inr0), [r1] "+r"(inr1), - [r2] "+r"(inr2), [r3] "+r"(inr3), - [out0] "+r"(out0), [wc0] "+r"(weight_c) - : [flag_mask] "r" (flag_mask), [outl] "r" (outl_ptr) - : "cc", "memory", - "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", - "q10", "q11", "q12", "q13","q14", "q15", "r0", "r1", "r2", "r3", "r4", "r5" - ); -#endif // __arch64__ - // clang-format on + act_switch_3x3s1(inr0, + inr1, + inr2, + inr3, + out0, + weight_c, + flag_mask, + outl_ptr, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + vbias, + act_param); +#endif outl[0] += 4; outl[1] += 4; outl[2] += 4; @@ -519,6 +1018,10 @@ void conv_3x3s1_depthwise_fp32(const float* i_data, outl[5] += 4; outl[6] += 4; outl[7] += 4; + inr0 += 16; + inr1 += 16; + inr2 += 16; + inr3 += 16; if (flag_mask) { memcpy(outl[0] - 4, pre_out, remain * sizeof(float)); memcpy(outl[1] - 4, pre_out + 4, remain * sizeof(float)); diff --git a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc index 807135f57dfadf690277ab57bd5597e9470ae549..f5b196efcca3f3f35367f2fea5e8f475b7147f48 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_fp32.cc @@ -75,6 +75,7 @@ void conv_3x3s2_direct_fp32(const float* i_data, //! prepack input to tmp buffer //! write output to tmp buffer auto paddings = *param.paddings; + auto act_param = param.activation_param; const int threads = ctx->threads(); int l2_size = ctx->llc_size() / sizeof(float); const int pad_w = paddings[2]; @@ -510,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } #pragma omp parallel for num_threads(threads) @@ -839,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data, oh, ow, flag_relu, - ptr_write); + ptr_write, + &act_param); } } } diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index 455781e37e0747950e6740f6db45c1ce8c0e96c8..602239a1fe1675c6eecb5b45a8e526ada98a56bb 100644 --- a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -205,14 +205,12 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \ "fadd v16.4s, v16.4s, v11.4s \n" \ - "fadd v16.4s, v16.4s, v12.4s \n" + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" #define LEFT_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ - \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \ \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ @@ -244,53 +242,52 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "blt 1f \n" -#define MID_COMPUTE_S2 \ - "2: \n" /* r0 */ \ - "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ - "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ - "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ - \ - "ext v10.16b, v2.16b, v18.16b, #4 \n" \ - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ - "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ - "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ - "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ - \ - "ext v10.16b, v4.16b, v19.16b, #4 \n" \ - \ - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ - "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ - "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ - \ - "fmul v14.4s, v5.4s, %[w0].s[1] 
\n" \ - "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ - \ - "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ - "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ext v10.16b, v6.16b, v20.16b, #4 \n" \ - \ - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ - "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ - "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ - \ - "ext v10.16b, v8.16b, v21.16b, #4 \n" \ - \ - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ - \ - "fadd v16.4s, v16.4s, v11.4s \n" \ - "fadd v16.4s, v16.4s, v12.4s \n" +#define MID_COMPUTE_S2 \ + "2: \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, v18.16b, #4 \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, v19.16b, #4 \n" \ + \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, v20.16b, #4 \n" \ + \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, v21.16b, #4 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" #define MID_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -360,14 +357,12 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "fadd v16.4s, v16.4s, v11.4s \n" \ "fadd v16.4s, v16.4s, v12.4s \n" \ - "ld1 {v1.4s}, [%[outptr1]] \n" + "ld1 {v1.4s}, [%[outptr1]] \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" #define RIGHT_RESULT_S2 \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ "bif v16.16b, v0.16b, %[wmask].16b \n" \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -382,11 +377,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "4: \n" #define LEFT_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ - \ "fmax v16.4s, v16.4s, %[vzero].4s \n" \ \ "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ @@ -424,14 +414,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "blt 1f \n" #define MID_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ 
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ @@ -457,11 +439,6 @@ void conv_depthwise_3x3s2_fp32(const float* din, "bne 2b \n" #define RIGHT_RESULT_S2_RELU \ - /* r4 */ \ - "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ - "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ - "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ - \ "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ \ "fadd v17.4s, v17.4s, v13.4s \n" \ diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index e4279d9a728bc7af0f14a00b781db449fc426582..c4fe965d0b17fa56d76812af14b40bddbc5b313a 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -20,6 +20,7 @@ #include "lite/backends/arm/math/sgemm.h" #include "lite/backends/arm/math/type_trans.h" #include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" #include "lite/utils/cp_logging.h" namespace paddle { @@ -28,6 +29,7 @@ namespace arm { namespace math { #define LITEMAX(a, b) ((a) > (b) ? (a) : (b)) +#define LITEMIN(a, b) ((a) < (b) ? (a) : (b)) #define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) template @@ -589,7 +591,238 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din, } } } +// clang-format off +#ifdef __aarch64__ +#define NCHWC1_TRANS_FP32_COMPUTE \ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "movi v20.4s, #0 \n" /* for relu */ \ + "1: \n" /* main loop*/ + +#define NCHWC1_TRANS_FP32_RELU \ + "fmax v0.4s, v0.4s, v20.4s \n" /*relu*/ \ + "fmax v1.4s, v1.4s, v20.4s \n" /*relu*/ \ + "fmax v2.4s, v2.4s, v20.4s \n" /*relu*/ \ + "fmax v3.4s, v3.4s, v20.4s \n" /*relu*/ + +#define NCHWC1_TRANS_FP32_RELU6 \ + "fmin v0.4s, v0.4s, %[six].4s \n" /* relu6 */ \ + "fmin v1.4s, v1.4s, %[six].4s \n" /* relu6 */ \ + "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \ + "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ + +#define NCHWC1_TRANS_FP32_LEAKY_RELU \ + "cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \ + "fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \ + "fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \ + "fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */ \ + "bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \ + "bif v2.16b, v10.16b, v6.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v7.16b \n" /* choose*/ + +#define NCHWC1_TRANS_FP32_STORE \ + "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ + \ + "str q0, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q1, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "str q2, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ + "str q3, [%[doutc0r0]], #16 \n" /* store c2r0*/ \ + "ldr q2, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + "ldr q3, [%[ptr_din]], #16 \n" /* load data, c0r0, c1r0, c0r1*/ \ + \ + "bne 1b \n" /* jump to main loop*/ +#else +#define NCHWC1_TRANS_FP32_COMPUTE \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0 \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! 
@ load data, c0r0 \n" \ + "vmov.u32 q15, #0 @ dump zero\n" \ + "1: @ main loop\n" +#define NCHWC1_TRANS_FP32_RELU \ + "vmax.f32 q0, q0, q15 @ relu\n" \ + "vmax.f32 q1, q1, q15 @ relu\n" \ + "vmax.f32 q2, q2, q15 @ relu\n" \ + "vmax.f32 q3, q3, q15 @ relu\n" + +#define NCHWC1_TRANS_FP32_RELU6 \ + "vmin.f32 q0, q0, %q[six] @ relu6 \n" \ + "vmin.f32 q1, q1, %q[six] @ relu6 \n" \ + "vmin.f32 q2, q2, %q[six] @ relu6 \n" \ + "vmin.f32 q3, q3, %q[six] @ relu6 \n" + +#define NCHWC1_TRANS_FP32_LEAKY_RELU \ + "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \ + "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \ + "vcge.f32 q7, q2, q15 @ q0 > 0 \n" \ + "vcge.f32 q8, q3, q15 @ q0 > 0 \n" \ + "vmul.f32 q9, q0, %q[scale] \n" \ + "vmul.f32 q10, q1, %q[scale] \n" \ + "vmul.f32 q11, q2, %q[scale] \n" \ + "vmul.f32 q12, q3, %q[scale] \n" \ + "vbif q0, q9, q5 @ choose \n" \ + "vbif q1, q10, q6 @ choose \n" \ + "vbif q2, q11, q7 @ choose \n" \ + "vbif q3, q12, q8 @ choose \n" + +#define NCHWC1_TRANS_FP32_STORE \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \ + "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \ + "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ + \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \ + "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \ + "vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result, \n" \ + \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \ + \ + "bne 1b @ jump to main loop\n" +#endif +// clang-format on +inline void act_switch_c1_fp32(const float* din_ptr, + float* doutc0_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_RELU + NCHWC1_TRANS_FP32_RELU6 NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_LEAKY_RELU + NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", "v1", "v2", "v3", "v20"); +#else + asm volatile(NCHWC1_TRANS_FP32_COMPUTE NCHWC1_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } +} /*wirte result in outputs * input din: [n, c, h, w], output dout: [n, c, h, w] */ @@ -605,13 +838,14 @@ inline bool write_to_output_c1_fp32(const float* din, int height, int width, bool flag_relu, - float* trash_ptr) { + float* trash_ptr, + operators::ActivationParam* act_param) { if (cs > channel) { return true; } const int c1 = 1; - const int w4 = 4; + const int w4 = 16; int size_c_out = width * height; @@ -623,98 +857,53 @@ inline bool write_to_output_c1_fp32(const float* din, int w_round = we - ws; int cnt = (width - ws) / w4; - + int remain = (width - ws) % w4; for (int i = 0; i < size_h; i++) { int size_w = i * width; float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; const float* din_hei_ptr = ptr_din + i * w_round * c1; if (cnt > 0) { int cnt_loop = cnt; - if (flag_relu) { -#ifdef __aarch64__ - asm volatile( - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "movi v20.4s, #0 \n" /* for relu */ - "1: \n" /* main loop*/ - "fmax v1.4s, v0.4s, v20.4s \n" /*relu*/ - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "str q1, [%[doutc0r0]], #16 \n" /* store c0r0*/ - "bne 1b \n" /* jump to main loop*/ - : [doutc0r0] "+r"(doutc0_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v20"); -#else - asm volatile( - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" - "vmov.u32 q15, #0 @ dump zero\n" - "1: @ main loop\n" - - "vmax.f32 q1, q0, q15 @ relu\n" - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" - - "vst1.32 {d2-d3}, [%[doutc0r0]]! 
@ store result, add " - "pointer\n" - - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile( - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "1: \n" /* main loop*/ - "str q0, [%[doutc0r0]], #16 \n" /* store c2r0*/ - "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ - "ldr q0, [%[ptr_din]], #16 \n" /* load data, c0r0, c0r1, c0r2, - c0r3 */ - "bne 1b \n" /* jump to main loop*/ - - : [doutc0r0] "+r"(doutc0_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0"); -#else - asm volatile( - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data, c0r0, " - "c0r1, c0r2, c0r3\n" - "1: @ main loop\n" - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " - "pointer\n" - "subs %[cnt], %[cnt], #1 @ loop count - 1\n" - "vld1.32 {d0-d1}, [%[ptr_din]]! @ load data \n" - "bne 1b @ jump to main loop\n" - - : [doutc0r0] "+r"(doutc0_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0"); -#endif - } + act_switch_c1_fp32(din_hei_ptr, doutc0_ptr, cnt_loop, act_param); } - if (we > width) { + if (remain > 0) { int offset = i * w_round * c1 + c1 * w4 * cnt; din_hei_ptr = ptr_din + offset; - int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - din_hei_ptr++; + doutc0_ptr += w4 * cnt; + int j = w4 * cnt; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < width; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + din_hei_ptr++; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < width; ++j) { + float tmp = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp, six); + din_hei_ptr++; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + for (; j < width; ++j) { + if (din_hei_ptr[0] >= 0) { + *(doutc0_ptr++) = din_hei_ptr[0]; + } else { + *(doutc0_ptr++) = din_hei_ptr[0] * scale; + } + din_hei_ptr++; + } + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; } } else { for (; j < width; ++j) { @@ -725,6 +914,7 @@ inline bool write_to_output_c1_fp32(const float* din, } return true; } +// clang-format off #ifdef __aarch64__ #define NCHWC2_TRANS_FP32_COMPUTE \ "ldp q0, q1, [%[ptr_din]], #32 \n" /* load data, c0r0, c1r0, c0r1*/ \ @@ -740,6 +930,18 @@ inline bool write_to_output_c1_fp32(const float* din, "fmax v2.4s, v4.4s, v20.4s \n" /*relu*/ \ "fmax v3.4s, v5.4s, v20.4s \n" /*relu*/ +#define NCHWC2_TRANS_FP32_RELU6 \ + "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \ + "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ + +#define NCHWC2_TRANS_FP32_LEAKY_RELU \ + "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \ + "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \ + "bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \ + "bif v3.16b, v5.16b, v7.16b \n" /* choose*/ + #define NCHWC2_TRANS_FP32_STORE \ "subs %w[cnt], %w[cnt], #1 \n" /* loop count -1*/ \ \ @@ -749,8 +951,7 @@ inline bool write_to_output_c1_fp32(const float* din, "bne 1b \n" /* jump to main loop*/ #else #define NCHWC2_TRANS_FP32_COMPUTE \ - "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, " \ - "c1r0, c0r1, c1r1, , c0r2, c1r2, c0r3, c1r3\n" \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data, c0r0, c1r0 \n" \ "vmov.u32 q15, #0 @ dump zero\n" \ "1: @ main loop\n" \ "vtrn.32 d0, d1 @ trans data:c0r0, c0r1, " \ @@ -764,11 +965,21 @@ inline bool write_to_output_c1_fp32(const float* din, "vmax.f32 q0, q0, q15 @ relu\n" \ "vmax.f32 q1, q1, q15 @ relu\n" +#define NCHWC2_TRANS_FP32_RELU6 \ + "vmin.f32 q0, q0, %q[six] @ relu6 \n" \ + "vmin.f32 q1, q1, %q[six] @ relu6 \n" + +#define NCHWC2_TRANS_FP32_LEAKY_RELU \ + "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \ + "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \ + "vmul.f32 q9, q0, %q[scale] \n" \ + "vmul.f32 q10, q1, %q[scale] \n" \ + "vbif q0, q9, q5 @ choose \n" \ + "vbif q1, q10, q6 @ choose \n" + #define NCHWC2_TRANS_FP32_STORE \ - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \ - "pointer\n" \ - "vst1.32 {d2-d3}, [%[doutc1r0]]! @ store result, add " \ - "pointer\n" \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \ + "vst1.32 {d2-d3}, [%[doutc1r0]]! 
@ store result, add pointer\n" \ \ "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ \ @@ -776,6 +987,151 @@ inline bool write_to_output_c1_fp32(const float* din, \ "bne 1b @ jump to main loop\n" #endif +// clang-format on +inline void act_switch_c2_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU + NCHWC2_TRANS_FP32_RELU6 NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? 
din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_LEAKY_RELU + NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); +#else + asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } +} /*wirte result in outputs * input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] */ @@ -791,11 +1147,11 @@ inline bool write_to_output_c2_fp32(const float* din, int height, int width, bool flag_relu, - float* trash_ptr) { + float* trash_ptr, + operators::ActivationParam* act_param) { if (cs > channel) { return true; } - const int c2 = 2; const int w4 = 4; @@ -828,55 +1184,56 @@ inline bool write_to_output_c2_fp32(const float* din, const float* din_hei_ptr = ptr_din + i * w_round * c2; if (cnt > 0) { int cnt_loop = cnt; - if (flag_relu) { -#ifdef __aarch64__ - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU - NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v20"); -#else - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_RELU - NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } else { -#ifdef __aarch64__ - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5"); -#else - asm volatile(NCHWC2_TRANS_FP32_COMPUTE NCHWC2_TRANS_FP32_STORE - : [doutc0r0] "+r"(doutc0_ptr), - [doutc1r0] "+r"(doutc1_ptr), - [ptr_din] "+r"(din_hei_ptr), - [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q15"); -#endif - } + act_switch_c2_fp32( + din_hei_ptr, doutc0_ptr, doutc1_ptr, cnt_loop, act_param); } if (we > width) { int offset = i * w_round * c2 + c2 * w4 * cnt; din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; int j = we - w4; - if (flag_relu) { - for (; j < width; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - din_hei_ptr += 2; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + 
case lite_api::ActivationType::kRelu:
+   for (; j < width; ++j) {
+     *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
+     *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
+     din_hei_ptr += 2;
+   }
+   break;
+ case lite_api::ActivationType::kRelu6:
+   /* 0 <= din <= 6 */
+   for (; j < width; ++j) {
+     float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
+     float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
+     *(doutc0_ptr++) = LITEMIN(tmp1, six);
+     *(doutc1_ptr++) = LITEMIN(tmp2, six);
+     din_hei_ptr += 2;
+   }
+   break;
+ case lite_api::ActivationType::kLeakyRelu:
+   /*din = din >= 0 ? din : din * scale*/
+   for (; j < width; ++j) {
+     if (din_hei_ptr[0] >= 0) {
+       *(doutc0_ptr++) = din_hei_ptr[0];
+     } else {
+       *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+     }
+     if (din_hei_ptr[1] >= 0) {
+       *(doutc1_ptr++) = din_hei_ptr[1];
+     } else {
+       *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+     }
+     din_hei_ptr += 2;
+   }
+   break;
+ default:
+   LOG(FATAL) << "this act_type: "
+              << static_cast<int>(act_param->active_type)
+              << " fuse not support";
 }
 } else {
 for (; j < width; ++j) {
@@ -888,7 +1245,7 @@
 }
 return true;
 }
-
+// clang-format off
 #ifdef __aarch64__
 #define NCHWC4_TRANS_FP32_COMPUTE \
 "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \
@@ -912,6 +1269,26 @@ inline bool write_to_output_c2_fp32(const float* din,
 "fmax v18.4s, v18.4s, v20.4s \n" /*relu*/ \
 "fmax v19.4s, v19.4s, v20.4s \n" /*relu*/

+#define NCHWC4_TRANS_FP32_RELU6 \
+ "fmin v16.4s, v16.4s, %[six].4s \n" /* relu6 */ \
+ "fmin v17.4s, v17.4s, %[six].4s \n" /* relu6 */ \
+ "fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */ \
+ "fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */
+
+#define NCHWC4_TRANS_FP32_LEAKY_RELU \
+ "cmhs v8.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \
+ "cmhs v9.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \
+ "cmhs v10.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \
+ "cmhs v11.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \
+ "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \
+ "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \
+ "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \
+ "fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \
+ "bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \
+ "bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \
+ "bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \
+ "bif v19.16b, v7.16b, v11.16b \n" /* choose*/
+
 #define NCHWC4_TRANS_FP32_STORE \
 "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \
 "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \
@@ -940,6 +1317,26 @@ inline bool write_to_output_c2_fp32(const float* din,
 "vmax.f32 q2, q2, q15 @ relu\n" \
 "vmax.f32 q3, q3, q15 @ relu\n"

+#define NCHWC4_TRANS_FP32_RELU6 \
+ "vmin.f32 q0, q0, %q[six] @ relu6 \n" \
+ "vmin.f32 q1, q1, %q[six] @ relu6 \n" \
+ "vmin.f32 q2, q2, %q[six] @ relu6 \n" \
+ "vmin.f32 q3, q3, %q[six] @ relu6 \n"
+
+#define NCHWC4_TRANS_FP32_LEAKY_RELU \
+ "vcge.f32 q5, q0, q15 @ q0 > 0 \n" \
+ "vcge.f32 q6, q1, q15 @ q0 > 0 \n" \
+ "vcge.f32 q7, q2, q15 @ q0 > 0 \n" \
+ "vcge.f32 q8, q3, q15 @ q0 > 0 \n" \
+ "vmul.f32 q9, q0, %q[scale] \n" \
+ "vmul.f32 q10, q1, %q[scale] \n" \
+ "vmul.f32 q11, q2, %q[scale] \n" \
+ "vmul.f32 q12, q3, %q[scale] \n" \
+ "vbif q0, q9, q5 @ choose \n" \
+ "vbif q1, q10, q6 @ choose \n" \
+ "vbif q2, q11, q7 @ choose \n" \
+ "vbif q3, q12, q8 @ choose \n"
+
 #define NCHWC4_TRANS_FP32_STORE \
 "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add pointer\n" \
 "vst1.32 {d2-d3}, [%[doutc1r0]]! 
@ store result, add pointer\n" \ @@ -953,68 +1350,19 @@ inline bool write_to_output_c2_fp32(const float* din, \ "bne 1b @ jump to main loop\n" #endif -/*wirte result in outputs -* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] -*/ -inline bool write_to_output_c4_fp32(const float* din, - float* dout, - int cs, - int ce, - int hs, - int he, - int ws, - int we, - int channel, - int height, - int width, - bool flag_relu, - float* trash_ptr) { - const int c4 = 4; - const int w4 = 4; - const int w_round = we - ws; - const int ch_n = ce - cs; - if (ch_n != 4) { - LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is " - "more than zero"; - return false; - } - int size_c_out = width * height; - - float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; - float* doutc1r0 = doutc0r0 + size_c_out; - float* doutc2r0 = doutc1r0 + size_c_out; - float* doutc3r0 = doutc2r0 + size_c_out; - - const float* ptr_din = din; - - int size_h = (he > height ? height : he) - hs; // size_h == hei_n - - int valid_we = we > width ? width : we; - int cnt = (valid_we - ws) / w4; - int remain = valid_we - ws - cnt * w4; - - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 3: - doutc1_ptr = trash_ptr; - case 2: - doutc2_ptr = trash_ptr; - case 1: - doutc3_ptr = trash_ptr; - default: - break; - } - } - const float* din_hei_ptr = ptr_din + i * w_round * ch_n; - if (cnt > 0) { - int cnt_loop = cnt; - if (flag_relu) { +// clang-format on +inline void act_switch_c4_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + float* doutc2_ptr, + float* doutc3_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: #ifdef __aarch64__ asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU NCHWC4_TRANS_FP32_STORE @@ -1023,7 +1371,7 @@ inline bool write_to_output_c4_fp32(const float* din, [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) + [ptr_din] "+r"(din_ptr) : : "v0", "v1", @@ -1052,57 +1400,290 @@ inline bool write_to_output_c4_fp32(const float* din, [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) : : "q0", "q1", "q2", "q3", "q15"); #endif - } else { + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ #ifdef __aarch64__ - asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) : "v0", "v1", "v2", "v3", + "v4", + "v5", + "v6", + "v7", "v8", "v9", "v10", "v11", + "v12", + "v13", + "v14", "v16", "v17", "v18", - "v19"); + "v19", + "v20"); #else - asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + asm 
volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), [doutc3r0] "+r"(doutc3_ptr), - [ptr_din] "+r"(din_hei_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3"); + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_LEAKY_RELU + NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q15"); #endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v8", + "v9", + "v10", + "v11", + "v16", + "v17", + "v18", + "v19"); +#else + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q15"); +#endif + } +} +/*wirte result in outputs +* input din: [n, c / 4, h, w * 4], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c4_fp32(const float* din, + float* dout, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr, + operators::ActivationParam* act_param) { + const int c4 = 4; + const int w4 = 4; + const int w_round = we - ws; + const int ch_n = ce - cs; + + if (ch_n != 4) { + LOG(ERROR) << "write_to_output_c4_fp32 ch_n must be equal 4 and hei_n is " + "more than zero"; + return false; + } + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int valid_we = we > width ? 
width : we; + int cnt = (valid_we - ws) / w4; + int remain = valid_we - ws - cnt * w4; + + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 3: + doutc1_ptr = trash_ptr; + case 2: + doutc2_ptr = trash_ptr; + case 1: + doutc3_ptr = trash_ptr; + default: + break; } } + const float* din_hei_ptr = ptr_din + i * w_round * ch_n; + if (cnt > 0) { + int cnt_loop = cnt; + act_switch_c4_fp32(din_hei_ptr, + doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + cnt_loop, + act_param); + } if (remain > 0) { int offset = i * w_round * c4 + c4 * w4 * cnt; din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; + doutc2_ptr += w4 * cnt; + doutc3_ptr += w4 * cnt; int j = 0; - if (flag_relu) { - for (; j < remain; ++j) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); - din_hei_ptr += w4; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; j < remain; ++j) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + din_hei_ptr += 4; + } + break; + case lite_api::ActivationType::kRelu6: + /* 0 <= din <= 6 */ + for (; j < remain; ++j) { + float tmp1 = LITEMAX(din_hei_ptr[0], 0.f); + float tmp2 = LITEMAX(din_hei_ptr[1], 0.f); + float tmp3 = LITEMAX(din_hei_ptr[2], 0.f); + float tmp4 = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc0_ptr++) = LITEMIN(tmp1, six); + *(doutc1_ptr++) = LITEMIN(tmp2, six); + *(doutc2_ptr++) = LITEMIN(tmp3, six); + *(doutc3_ptr++) = LITEMIN(tmp4, six); + din_hei_ptr += 4; + } + break; + case lite_api::ActivationType::kLeakyRelu: + /*din = din >= 0 ? 
din : din * scale*/ + for (; j < remain; ++j) { + if (din_hei_ptr[0] >= 0) { + *(doutc0_ptr++) = din_hei_ptr[0]; + } else { + *(doutc0_ptr++) = din_hei_ptr[0] * scale; + } + if (din_hei_ptr[1] >= 0) { + *(doutc1_ptr++) = din_hei_ptr[1]; + } else { + *(doutc1_ptr++) = din_hei_ptr[1] * scale; + } + if (din_hei_ptr[2] >= 0) { + *(doutc2_ptr++) = din_hei_ptr[2]; + } else { + *(doutc2_ptr++) = din_hei_ptr[2] * scale; + } + if (din_hei_ptr[3] >= 0) { + *(doutc3_ptr++) = din_hei_ptr[3]; + } else { + *(doutc3_ptr++) = din_hei_ptr[3] * scale; + } + din_hei_ptr += 4; + } + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; } } else { for (; j < remain; ++j) { @@ -1110,14 +1691,14 @@ inline bool write_to_output_c4_fp32(const float* din, *(doutc1_ptr++) = din_hei_ptr[1]; *(doutc2_ptr++) = din_hei_ptr[2]; *(doutc3_ptr++) = din_hei_ptr[3]; - din_hei_ptr += w4; + din_hei_ptr += 4; } } } } return true; } - +// clang-format off #ifdef __aarch64__ #define NCHWC8_TRANS_FP32_COMPUTE \ "ldp q0, q1, [%[ptr_din]], #32 \n" /* load r00, r01 to q0, q1 */ \ @@ -1161,6 +1742,48 @@ inline bool write_to_output_c4_fp32(const float* din, "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ +#define NCHWC8_TRANS_FP32_RELU6 \ + "fmin v16.4s, v16.4s, %[six].4s \n" /*relu6*/ \ + "fmin v17.4s, v17.4s, %[six].4s \n" /*relu6*/ \ + "fmin v18.4s, v18.4s, %[six].4s \n" /*relu6*/ \ + "fmin v19.4s, v19.4s, %[six].4s \n" /*relu6*/ \ + \ + "fmin v8.4s, v8.4s, %[six].4s \n" /*relu6*/ \ + "fmin v9.4s, v9.4s, %[six].4s \n" /*relu6*/ \ + "fmin v12.4s, v12.4s, %[six].4s \n" /*relu6*/ \ + "fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/ + +#define NCHWC8_TRANS_FP32_LEAKY_RELU \ + "cmhs v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ + \ + "cmhs v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \ + "cmhs v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \ + \ + "fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \ + "fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \ + "fmul v27.4s, v18.4s, %[scale].4s \n" /* mul */ \ + "fmul v28.4s, v19.4s, %[scale].4s \n" /* mul */ \ + \ + "fmul v29.4s, v8.4s, %[scale].4s \n" /* mul */ \ + "fmul v30.4s, v9.4s, %[scale].4s \n" /* mul */ \ + "fmul v31.4s, v12.4s, %[scale].4s \n" /* mul */ \ + \ + "bif v16.16b, v25.16b, v10.16b \n" /* choose*/ \ + "bif v17.16b, v26.16b, v11.16b \n" /* choose*/ \ + "bif v18.16b, v27.16b, v14.16b \n" /* choose*/ \ + "bif v19.16b, v28.16b, v15.16b \n" /* choose*/ \ + "fmul v25.4s, v13.4s, %[scale].4s \n" /* mul */ \ + \ + "bif v8.16b, v29.16b, v21.16b \n" /* choose*/ \ + "bif v9.16b, v30.16b, v22.16b \n" /* choose*/ \ + "bif v12.16b, v31.16b, v23.16b \n" /* choose*/ \ + "bif v13.16b, v25.16b, v24.16b \n" /* choose*/ + #define NCHWC8_TRANS_FP32_STORE \ "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ "str q17, [%[doutc2r0]], #16 \n" /* store c2r0*/ \ @@ -1174,6 +1797,7 @@ inline bool write_to_output_c4_fp32(const float* din, "str q13, [%[doutc7r0]], #16 \n" /* store c3r0*/ \ \ "bne 1b \n" /* jump to main loop*/ + #else #define NCHWC8_TRANS_FP32_COMPUTE \ "vld1.32 {d0-d3}, [%[ptr_din]]! 
@load data \n" \ @@ -1203,6 +1827,48 @@ inline bool write_to_output_c4_fp32(const float* din, "vmax.f32 q6, q6, q15 @ relu\n" \ "vmax.f32 q7, q7, q15 @ relu\n" +#define NCHWC8_TRANS_FP32_RELU6 \ + "vmin.f32 q0, q0, %q[six] @ relu6\n" \ + "vmin.f32 q1, q1, %q[six] @ relu6\n" \ + "vmin.f32 q2, q2, %q[six] @ relu6\n" \ + "vmin.f32 q3, q3, %q[six] @ relu6\n" \ + \ + "vmin.f32 q4, q4, %q[six] @ relu6\n" \ + "vmin.f32 q5, q5, %q[six] @ relu6\n" \ + "vmin.f32 q6, q6, %q[six] @ relu6\n" \ + "vmin.f32 q7, q7, %q[six] @ relu6\n" + +#define NCHWC8_TRANS_FP32_LEAKY_RELU \ + "vcge.f32 q9, q0, q15 @ q0 > 0 \n" \ + "vcge.f32 q10, q1, q15 @ q0 > 0 \n" \ + "vcge.f32 q11, q2, q15 @ q0 > 0 \n" \ + "vcge.f32 q12, q3, q15 @ q0 > 0 \n" \ + "vmul.f32 q13, q0, %q[scale] \n" \ + "vmul.f32 q14, q1, %q[scale] \n" \ + "vmul.f32 q15, q2, %q[scale] \n" \ + \ + "vbif q0, q13, q9 @ choose \n" \ + "vmul.f32 q9, q3, %q[scale] \n" \ + \ + "vbif q1, q14, q10 @ choose \n" \ + "vbif q2, q15, q11 @ choose \n" \ + "vbif q3, q9, q12 @ choose \n" \ + \ + "vcge.f32 q9, q4, q15 @ q0 > 0 \n" \ + "vcge.f32 q10, q5, q15 @ q0 > 0 \n" \ + "vcge.f32 q11, q6, q15 @ q0 > 0 \n" \ + "vcge.f32 q12, q7, q15 @ q0 > 0 \n" \ + "vmul.f32 q13, q4, %q[scale] \n" \ + "vmul.f32 q14, q5, %q[scale] \n" \ + "vmul.f32 q15, q6, %q[scale] \n" \ + \ + "vbif q4, q13, q9 @ choose \n" \ + "vmul.f32 q9, q7, %q[scale] \n" \ + \ + "vbif q5, q14, q10 @ choose \n" \ + "vbif q6, q15, q11 @ choose \n" \ + "vbif q7, q9, q12 @ choose \n" + #define NCHWC8_TRANS_FP32_STORE \ "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result, add " \ @@ -1232,84 +1898,23 @@ inline bool write_to_output_c4_fp32(const float* din, "bne 1b @ jump to main loop\n" #endif -/*wirte result in outputs -* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] -*/ -inline bool write_to_output_c8_fp32(const float* din, - float* dout, - int ch_n, - int hei_n, - int cs, - int ce, - int hs, - int he, - int ws, - int we, - int channel, - int height, - int width, - bool flag_relu, - float* trash_ptr) { - if (ch_n != 8 || hei_n <= 0) { - LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero"; - return false; - } - int size_c_out = width * height; - - float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; - float* doutc1r0 = doutc0r0 + size_c_out; - float* doutc2r0 = doutc1r0 + size_c_out; - float* doutc3r0 = doutc2r0 + size_c_out; - float* doutc4r0 = doutc3r0 + size_c_out; - float* doutc5r0 = doutc4r0 + size_c_out; - float* doutc6r0 = doutc5r0 + size_c_out; - float* doutc7r0 = doutc6r0 + size_c_out; - - const float* ptr_din = din; - - int size_h = (he > height ? 
height : he) - hs; // size_h == hei_n - - int valid_w = we - ws; - int cnt = valid_w / 4; - - if (we > width) { - cnt--; - } - if (flag_relu) { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - float* doutc4_ptr = doutc4r0 + size_w; - float* doutc5_ptr = doutc5r0 + size_w; - float* doutc6_ptr = doutc6r0 + size_w; - float* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const float* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; +// clang-format on +inline void act_switch_c8_fp32(const float* din_ptr, + float* doutc0_ptr, + float* doutc1_ptr, + float* doutc2_ptr, + float* doutc3_ptr, + float* doutc4_ptr, + float* doutc5_ptr, + float* doutc6_ptr, + float* doutc7_ptr, + int cnt_loop, + const operators::ActivationParam* act_param) { + if (act_param != nullptr && act_param->has_active) { + float32x4_t six = vdupq_n_f32(act_param->Relu_clipped_coef); + float32x4_t scale = vdupq_n_f32(act_param->Leaky_relu_alpha); + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: #ifdef __aarch64__ asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU NCHWC8_TRANS_FP32_STORE @@ -1322,9 +1927,10 @@ inline bool write_to_output_c8_fp32(const float* din, [doutc6r0] "+r"(doutc6_ptr), [doutc7r0] "+r"(doutc7_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) + [ptr_din] "+r"(din_ptr) : - : "v1", + : "v0", + "v1", "v2", "v3", "v4", @@ -1338,7 +1944,6 @@ inline bool write_to_output_c8_fp32(const float* din, "v12", "v13", "v14", - "v15", "v16", "v17", "v18", @@ -1355,66 +1960,17 @@ inline bool write_to_output_c8_fp32(const float* din, [doutc5r0] "+r"(doutc5_ptr), [doutc6r0] "+r"(doutc6_ptr), [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) : - : "q0", "q1", "q2", "q3", "q4", "q15"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15"); #endif - } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; - for (; i < width; ++i) { - *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); - *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); - *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); - *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); - *(doutc4_ptr++) = LITEMAX(din_hei_ptr[4], 0.f); - *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f); - *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f); - *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f); - din_hei_ptr += 8; - } - } - } - } else { - for (int i = 0; i < size_h; i++) { - int size_w = i * width; - float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; - float* doutc1_ptr = doutc1r0 + size_w; - float* doutc2_ptr = doutc2r0 + size_w; - float* doutc3_ptr = doutc3r0 + size_w; - float* doutc4_ptr = doutc4r0 + size_w; - float* doutc5_ptr = doutc5r0 + size_w; - float* doutc6_ptr = doutc6r0 + size_w; - float* doutc7_ptr = doutc7r0 + size_w; - if (ce > channel) { - switch (ce - channel) { - case 7: - doutc1_ptr = trash_ptr; - case 6: - doutc2_ptr = trash_ptr; - case 5: - doutc3_ptr = 
trash_ptr; - case 4: - doutc4_ptr = trash_ptr; - case 3: - doutc5_ptr = trash_ptr; - case 2: - doutc6_ptr = trash_ptr; - case 1: - doutc7_ptr = trash_ptr; - default: - break; - } - } - ptr_din = din + i * valid_w * ch_n; - const float* din_hei_ptr = ptr_din; - if (cnt > 0) { - int cnt_loop = cnt; + break; + case lite_api::ActivationType::kRelu6: +/* 0 <= din <= 6 */ #ifdef __aarch64__ - asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_RELU6 + NCHWC8_TRANS_FP32_STORE : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), @@ -1424,8 +1980,8 @@ inline bool write_to_output_c8_fp32(const float* din, [doutc6r0] "+r"(doutc6_ptr), [doutc7r0] "+r"(doutc7_ptr), [cnt] "+r"(cnt_loop), - [ptr_din] "+r"(din_hei_ptr) - : + [ptr_din] "+r"(din_ptr) + : [six] "w"(six) : "v0", "v1", "v2", @@ -1441,14 +1997,29 @@ inline bool write_to_output_c8_fp32(const float* din, "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20"); #else - asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE + asm volatile(NCHWC4_TRANS_FP32_COMPUTE NCHWC4_TRANS_FP32_RELU + NCHWC4_TRANS_FP32_RELU6 NCHWC4_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : [six] "w"(six) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +/*din = din >= 0 ? din : din * scale*/ +#ifdef __aarch64__ + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU + NCHWC8_TRANS_FP32_STORE : [doutc0r0] "+r"(doutc0_ptr), [doutc1r0] "+r"(doutc1_ptr), [doutc2r0] "+r"(doutc2_ptr), @@ -1457,16 +2028,323 @@ inline bool write_to_output_c8_fp32(const float* din, [doutc5r0] "+r"(doutc5_ptr), [doutc6r0] "+r"(doutc6_ptr), [doutc7r0] "+r"(doutc7_ptr), - [ptr_din] "+r"(din_hei_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : [scale] "w"(scale) + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22", + "v23", + "v24", + "v25", + "v26", + "v27", + "v28", + "v29", + "v30", + "v31"); +#else + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_LEAKY_RELU + NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_ptr), [cnt] "+r"(cnt_loop) - : - : "q0", "q1", "q2", "q3", "q4"); + : [scale] "w"(scale) + : "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param->active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [cnt] "+r"(cnt_loop), + [ptr_din] "+r"(din_ptr) + : + : "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", 
+ "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(NCHWC8_TRANS_FP32_COMPUTE NCHWC8_TRANS_FP32_STORE + : [doutc0r0] "+r"(doutc0_ptr), + [doutc1r0] "+r"(doutc1_ptr), + [doutc2r0] "+r"(doutc2_ptr), + [doutc3r0] "+r"(doutc3_ptr), + [doutc4r0] "+r"(doutc4_ptr), + [doutc5r0] "+r"(doutc5_ptr), + [doutc6r0] "+r"(doutc6_ptr), + [doutc7r0] "+r"(doutc7_ptr), + [ptr_din] "+r"(din_ptr), + [cnt] "+r"(cnt_loop) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q15"); #endif + } +} + +/*wirte result in outputs +* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w] +*/ +inline bool write_to_output_c8_fp32(const float* din, + float* dout, + int ch_n, + int hei_n, + int cs, + int ce, + int hs, + int he, + int ws, + int we, + int channel, + int height, + int width, + bool flag_relu, + float* trash_ptr, + operators::ActivationParam* act_param) { + if (ch_n != 8 || hei_n <= 0) { + LOG(ERROR) << "ch_n must be equal 8 and hei_n is more than zero"; + return false; + } + int size_c_out = width * height; + + float* doutc0r0 = dout + cs * size_c_out + hs * width + ws; + float* doutc1r0 = doutc0r0 + size_c_out; + float* doutc2r0 = doutc1r0 + size_c_out; + float* doutc3r0 = doutc2r0 + size_c_out; + float* doutc4r0 = doutc3r0 + size_c_out; + float* doutc5r0 = doutc4r0 + size_c_out; + float* doutc6r0 = doutc5r0 + size_c_out; + float* doutc7r0 = doutc6r0 + size_c_out; + + const float* ptr_din = din; + + int size_h = (he > height ? height : he) - hs; // size_h == hei_n + + int valid_w = we - ws; + int w4 = 4; + int cnt = valid_w / 4; + + if (we > width) { + cnt--; + } + for (int i = 0; i < size_h; i++) { + int size_w = i * width; + float* doutc0_ptr = doutc0r0 + size_w; // doutc0r0 + width; + float* doutc1_ptr = doutc1r0 + size_w; + float* doutc2_ptr = doutc2r0 + size_w; + float* doutc3_ptr = doutc3r0 + size_w; + float* doutc4_ptr = doutc4r0 + size_w; + float* doutc5_ptr = doutc5r0 + size_w; + float* doutc6_ptr = doutc6r0 + size_w; + float* doutc7_ptr = doutc7r0 + size_w; + if (ce > channel) { + switch (ce - channel) { + case 7: + doutc1_ptr = trash_ptr; + case 6: + doutc2_ptr = trash_ptr; + case 5: + doutc3_ptr = trash_ptr; + case 4: + doutc4_ptr = trash_ptr; + case 3: + doutc5_ptr = trash_ptr; + case 2: + doutc6_ptr = trash_ptr; + case 1: + doutc7_ptr = trash_ptr; + default: + break; } - if (we > width) { - int offset = 32 * (valid_w / 4 - 1); - din_hei_ptr = ptr_din + offset; - int i = we - 4; + } + ptr_din = din + i * valid_w * ch_n; + const float* din_hei_ptr = ptr_din; + if (cnt > 0) { + int cnt_loop = cnt; + act_switch_c8_fp32(din_hei_ptr, + doutc0_ptr, + doutc1_ptr, + doutc2_ptr, + doutc3_ptr, + doutc4_ptr, + doutc5_ptr, + doutc6_ptr, + doutc7_ptr, + cnt_loop, + act_param); + } + if (we > width) { + int offset = 32 * (valid_w / 4 - 1); + din_hei_ptr = ptr_din + offset; + doutc0_ptr += w4 * cnt; + doutc1_ptr += w4 * cnt; + doutc2_ptr += w4 * cnt; + doutc3_ptr += w4 * cnt; + doutc4_ptr += w4 * cnt; + doutc5_ptr += w4 * cnt; + doutc6_ptr += w4 * cnt; + doutc7_ptr += w4 * cnt; + int i = we - 4; + if (act_param != nullptr && act_param->has_active) { + float six = act_param->Relu_clipped_coef; + float scale = act_param->Leaky_relu_alpha; + switch (act_param->active_type) { + case lite_api::ActivationType::kRelu: + for (; i < width; ++i) { + *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); + *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); + *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); + *(doutc3_ptr++) = LITEMAX(din_hei_ptr[3], 0.f); + *(doutc4_ptr++) = 
LITEMAX(din_hei_ptr[4], 0.f);
+              *(doutc5_ptr++) = LITEMAX(din_hei_ptr[5], 0.f);
+              *(doutc6_ptr++) = LITEMAX(din_hei_ptr[6], 0.f);
+              *(doutc7_ptr++) = LITEMAX(din_hei_ptr[7], 0.f);
+              din_hei_ptr += 8;
+            }
+            break;
+          case lite_api::ActivationType::kRelu6:
+            /* 0 <= din <= 6 */
+            for (; i < width; ++i) {
+              float tmp1 = LITEMAX(din_hei_ptr[0], 0.f);
+              float tmp2 = LITEMAX(din_hei_ptr[1], 0.f);
+              float tmp3 = LITEMAX(din_hei_ptr[2], 0.f);
+              float tmp4 = LITEMAX(din_hei_ptr[3], 0.f);
+              float tmp5 = LITEMAX(din_hei_ptr[4], 0.f);
+              float tmp6 = LITEMAX(din_hei_ptr[5], 0.f);
+              float tmp7 = LITEMAX(din_hei_ptr[6], 0.f);
+              float tmp8 = LITEMAX(din_hei_ptr[7], 0.f);
+              *(doutc0_ptr++) = LITEMIN(tmp1, six);
+              *(doutc1_ptr++) = LITEMIN(tmp2, six);
+              *(doutc2_ptr++) = LITEMIN(tmp3, six);
+              *(doutc3_ptr++) = LITEMIN(tmp4, six);
+              *(doutc4_ptr++) = LITEMIN(tmp5, six);
+              *(doutc5_ptr++) = LITEMIN(tmp6, six);
+              *(doutc6_ptr++) = LITEMIN(tmp7, six);
+              *(doutc7_ptr++) = LITEMIN(tmp8, six);
+              din_hei_ptr += 8;
+            }
+            break;
+          case lite_api::ActivationType::kLeakyRelu:
+            /*din = din >= 0 ? din : din * scale*/
+            for (; i < width; ++i) {
+              if (din_hei_ptr[0] >= 0) {
+                *(doutc0_ptr++) = din_hei_ptr[0];
+              } else {
+                *(doutc0_ptr++) = din_hei_ptr[0] * scale;
+              }
+              if (din_hei_ptr[1] >= 0) {
+                *(doutc1_ptr++) = din_hei_ptr[1];
+              } else {
+                *(doutc1_ptr++) = din_hei_ptr[1] * scale;
+              }
+              if (din_hei_ptr[2] >= 0) {
+                *(doutc2_ptr++) = din_hei_ptr[2];
+              } else {
+                *(doutc2_ptr++) = din_hei_ptr[2] * scale;
+              }
+              if (din_hei_ptr[3] >= 0) {
+                *(doutc3_ptr++) = din_hei_ptr[3];
+              } else {
+                *(doutc3_ptr++) = din_hei_ptr[3] * scale;
+              }
+              if (din_hei_ptr[4] >= 0) {
+                *(doutc4_ptr++) = din_hei_ptr[4];
+              } else {
+                *(doutc4_ptr++) = din_hei_ptr[4] * scale;
+              }
+              if (din_hei_ptr[5] >= 0) {
+                *(doutc5_ptr++) = din_hei_ptr[5];
+              } else {
+                *(doutc5_ptr++) = din_hei_ptr[5] * scale;
+              }
+              if (din_hei_ptr[6] >= 0) {
+                *(doutc6_ptr++) = din_hei_ptr[6];
+              } else {
+                *(doutc6_ptr++) = din_hei_ptr[6] * scale;
+              }
+              if (din_hei_ptr[7] >= 0) {
+                *(doutc7_ptr++) = din_hei_ptr[7];
+              } else {
+                *(doutc7_ptr++) = din_hei_ptr[7] * scale;
+              }
+              din_hei_ptr += 8;
+            }
+            break;
+          default:
+            LOG(FATAL) << "this act_type: "
+                       << static_cast<int>(act_param->active_type)
+                       << " fuse not support";
+        }
+      } else {
         for (; i < width; ++i) {
           *(doutc0_ptr++) = din_hei_ptr[0];
           *(doutc1_ptr++) = din_hei_ptr[1];
diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h
index b6c3478880d5cb59999d23ff03e2e342708ca95b..503dab29b6c4f0b9d3ff30a89060e473194216a9 100644
--- a/lite/backends/arm/math/conv_depthwise.h
+++ b/lite/backends/arm/math/conv_depthwise.h
@@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
                                const float* weights,
                                const float* bias,
                                const operators::ConvParam& param,
+                               const operators::ActivationParam act_param,
                                ARMContext* ctx);
 
 void conv_3x3s2_depthwise_fp32(const float* i_data,
@@ -67,6 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din,
                                int pad,
                                bool flag_bias,
                                bool flag_relu,
+                               const operators::ActivationParam act_param,
                                ARMContext* ctx);
 
 void conv_depthwise_3x3s2_fp32(const float* din,
diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc
index dc68e65f42a799d7fa7e8be75f5afcf3166b1df3..642d1c2c1b964b9553e522d70a086531f1706420 100644
--- a/lite/backends/arm/math/conv_impl.cc
+++ b/lite/backends/arm/math/conv_impl.cc
@@ -579,6 +579,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              ARMContext* ctx,
                              const float* scale) {
   auto paddings = *param.paddings;
+  auto act_param = param.activation_param;
   const int pad_h = paddings[0];
   const int pad_w = paddings[2];
   int stride = param.strides[1];
@@ -603,6 +604,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              pad,
                              flag_bias,
                              flag_relu,
+                             act_param,
                              ctx);
   } else {
     conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
@@ -617,6 +619,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              reinterpret_cast<const float*>(weights),
                              bias,
                              param,
+                             act_param,
                              ctx);
   }
diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc
index 383934e5d51e0756cd3fdd3269a916dcc1431037..3e86e933a1624a32eb425261477a3c81c5e06f97 100644
--- a/lite/kernels/arm/conv_compute.cc
+++ b/lite/kernels/arm/conv_compute.cc
@@ -67,7 +67,7 @@ void ConvCompute::PrepareForRun() {
     impl_ = new DepthwiseConv;
     VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
-             no_dilation) {
+             no_dilation && pads_all_equal) {
     /// winograd conv impl
     impl_ = new WinogradConv;
     VLOG(3) << "invoking winograd conv";
diff --git a/lite/operators/conv_op.cc b/lite/operators/conv_op.cc
index 6dab55ff3b6c55e7763484d78c6c36bf85017128..d9c0ecb4fd8457782ac90850b8b6a002c7dfcffe 100644
--- a/lite/operators/conv_op.cc
+++ b/lite/operators/conv_op.cc
@@ -52,6 +52,34 @@ inline int ConvOutputSize(int input_size,
   return output_size;
 }
 
+inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
+                                     std::vector<int>* dilations,
+                                     const std::vector<int>& strides,
+                                     const std::string padding_algorithm,
+                                     const lite::DDim data_dims,
+                                     const lite::DDim& ksize) {
+  // when padding_desc is "VALID" or "SAME"
+  if (padding_algorithm == "SAME") {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
+      int pad_sum = std::max(
+          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
+          (int64_t)0);
+      int pad_0 = pad_sum / 2;
+      int pad_1 = pad_sum - pad_0;
+      // pad
+      *(paddings->begin() + i * 2) = pad_0;
+      *(paddings->begin() + i * 2 + 1) = pad_1;
+      // dilation
+      *(dilations->begin() + i) = 1;
+    }
+  } else if (padding_algorithm == "VALID") {
+    for (auto& it : *paddings) {
+      it = 0;
+    }
+  }
+}
+
 bool ConvOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h
index 3ab34bc1d0bd631b0641cebd3db29cfff9316bb0..24848803fb7ea2139f87aa5b5f2119592dc00084 100644
--- a/lite/operators/conv_op.h
+++ b/lite/operators/conv_op.h
@@ -137,34 +137,6 @@ class ConvOpLite : public OpLite {
   std::string padding_algorithm_{""};
 };
 
-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
-                                     std::vector<int>* dilations,
-                                     const std::vector<int>& strides,
-                                     const std::string padding_algorithm,
-                                     const lite::DDim data_dims,
-                                     const lite::DDim& ksize) {
-  // when padding_desc is "VALID" or "SAME"
-  if (padding_algorithm == "SAME") {
-    for (size_t i = 0; i < strides.size(); ++i) {
-      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
-      int pad_sum = std::max(
-          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
-          (int64_t)0);
-      int pad_0 = pad_sum / 2;
-      int pad_1 = pad_sum - pad_0;
-      // pad
-      *(paddings->begin() + i * 2) = pad_0;
-      *(paddings->begin() + i * 2 + 1) = pad_1;
-      // dilation
-      *(dilations->begin() + i) = 1;
-    }
-  } else if (padding_algorithm == "VALID") {
-    for (auto& it : *paddings) {
-      it = 0;
-    }
-  }
-}
-
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/tests/math/conv_compute_test.cc b/lite/tests/math/conv_compute_test.cc
index bda50d35633c853ba6e8c8695d0175da38865d1c..367eb6c34761b8d0989da0d2e99aa00442d0c76b 100644
--- a/lite/tests/math/conv_compute_test.cc
+++ b/lite/tests/math/conv_compute_test.cc
@@ -59,6 +59,8 @@ DEFINE_bool(flag_bias, true, "with bias");
 typedef paddle::lite::DDim DDim;
 typedef paddle::lite::Tensor Tensor;
 typedef paddle::lite::operators::ConvParam ConvParam;
+typedef paddle::lite::operators::ActivationParam ActivationParam;
+
 using paddle::lite::profile::Timer;
 
 DDim compute_out_dim(const DDim& dim_in,
@@ -118,6 +120,13 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
   param.dilations = std::make_shared<std::vector<int>>(dilas);
   param.fuse_relu = flag_relu;
   param.groups = group;
+  if (flag_relu) {
+    ActivationParam act_param;
+    act_param.has_active = true;
+    act_param.active_type =
+        (paddle::lite_api::ActivationType)1;  // 1-relu, 2-relu6, 4-leakyrelu
+    param.activation_param = act_param;
+  }
   param.output = new Tensor;
   param.output->set_precision(PRECISION(kFloat));
@@ -243,6 +252,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
                   << pads[2] << ", " << pads[3]
                   << ", stride: " << strides[0] << ", " << strides[1]
                   << ", dila_: " << dilas[0] << ", " << dilas[1]
+                  << ", group: " << group
                   << ", bias: " << (flag_bias ? "true" : "false")
                   << ", relu: " << (flag_relu ? "true" : "false")
                   << ", threads: " << th << ", power_mode: " << cls
@@ -255,6 +265,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
                << ", pad: " << pads[0] << ", " << pads[1]
                << ", stride: " << strides[0] << ", " << strides[1]
                << ", dila_: " << dilas[0] << ", " << dilas[1]
+               << ", group: " << group
                << ", bias: " << (flag_bias ? "true" : "false")
                << ", relu: " << (flag_relu ? "true" : "false")
                << ", threads: " << th << ", power_mode: " << cls
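
For reference: the fused-activation write-back above (both the assembly paths and the scalar remainder loops in write_to_output_c8_fp32) applies one per-element rule selected by ActivationParam, namely ReLU, ReLU6 clipped at Relu_clipped_coef, or LeakyReLU scaled by Leaky_relu_alpha. The standalone C++ sketch below mirrors that scalar behaviour; the ActType enum and apply_act helper are illustrative names, not part of the patch, and are only intended as a reference when cross-checking the vectorised outputs element by element.

// Standalone sketch (assumed names, not the Paddle Lite API): scalar
// reference for the fused activation applied during the NCHWc8 write-back.
#include <algorithm>
#include <cstdio>

enum class ActType { kNone, kRelu, kRelu6, kLeakyRelu };

inline float apply_act(float v, ActType type, float six, float alpha) {
  switch (type) {
    case ActType::kRelu:
      return std::max(v, 0.f);                 // LITEMAX(v, 0.f)
    case ActType::kRelu6:
      return std::min(std::max(v, 0.f), six);  // clamp to [0, six]
    case ActType::kLeakyRelu:
      return v >= 0.f ? v : v * alpha;         // din >= 0 ? din : din * scale
    default:
      return v;                                // no fused activation
  }
}

int main() {
  const float src[4] = {-2.f, 0.5f, 3.f, 8.f};
  for (float v : src) {
    std::printf("%g -> relu=%g relu6=%g leaky=%g\n",
                v,
                apply_act(v, ActType::kRelu, 6.f, 0.1f),
                apply_act(v, ActType::kRelu6, 6.f, 0.1f),
                apply_act(v, ActType::kLeakyRelu, 6.f, 0.1f));
  }
  return 0;
}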
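The UpdatePaddingAndDilation helper moved into conv_op.cc computes "SAME" padding by rounding the output size up by the stride and splitting the required total padding across the two sides, with the second side taking the extra pixel when the total is odd; dilation is reset to 1 in that mode, and "VALID" simply zeroes all paddings. A small standalone sketch of that arithmetic follows, using plain ints instead of the lite::DDim arguments (an assumption made only for brevity):

// Standalone sketch of the SAME-padding rule, per spatial dimension:
// out = ceil(in / stride); pad_sum = max((out - 1) * stride + k - in, 0).
#include <algorithm>
#include <cstdio>

void same_padding(int in, int k, int stride, int* pad0, int* pad1) {
  int out = (in + stride - 1) / stride;
  int pad_sum = std::max((out - 1) * stride + k - in, 0);
  *pad0 = pad_sum / 2;        // begin side
  *pad1 = pad_sum - *pad0;    // end side gets the extra pixel if pad_sum is odd
}

int main() {
  int p0 = 0, p1 = 0;
  same_padding(/*in=*/7, /*k=*/3, /*stride=*/2, &p0, &p1);
  std::printf("pad0=%d pad1=%d\n", p0, p1);  // in=7, k=3, s=2 -> out=4, pads 1/1
  return 0;
}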