diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc index 34f1a30eaaba62f40d90fda6bf40baeb8ad2eb5b..9de59d2185debc30f8f9a002f977f29cbbf300d0 100644 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @@ -614,11 +614,11 @@ void conv_depthwise_3x3s1_fp32(const float *din, "blt 3f \n" #define LEFT_RESULT_S1_LEAKY_RELU \ - "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ + "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ \ "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ @@ -639,8 +639,8 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ - "cmhs v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ \ "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ \ @@ -657,10 +657,10 @@ void conv_depthwise_3x3s1_fp32(const float *din, "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ \ - "cmhs v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ + "fcmge v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ + "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ + "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ "cmp %w[cnt], #1 \n" \ "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ @@ -802,7 +802,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, #define MID_RESULT_S1_LEAKY_RELU \ "movi v21.4s, #0 \n" \ - "cmhs v18.4s, v12.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fcmge v18.4s, v12.4s, v21.4s \n" /* vcgeq_f32 */ \ "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -824,7 +824,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "cmhs v18.4s, v13.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fcmge v18.4s, v13.4s, v21.4s \n" /* vcgeq_f32 */ \ "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \ \ "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -846,7 +846,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "cmhs v18.4s, v14.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fcmge v18.4s, v14.4s, v21.4s \n" /* vcgeq_f32 */ \ "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ \ "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -861,7 +861,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \ \ - "cmhs v18.4s, v15.4s, v21.4s \n" /* vcgeq_u32 */ \ + "fcmge v18.4s, v15.4s, v21.4s \n" /* vcgeq_f32 */ \ "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ \ "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ @@ -980,7 +980,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, #define RIGHT_RESULT_S1_LEAKY_RELU \ "movi v1.4s, #0 \n" \ - "cmhs v20.4s, v12.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fcmge v20.4s, v12.4s, v1.4s \n" /* vcgeq_f32 */ \ "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ \ "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ @@ -999,7 +999,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ - "cmhs v20.4s, v13.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fcmge v20.4s, v13.4s, v1.4s \n" /* vcgeq_f32 */ \ "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ "st1 {v12.4s}, [%[doutr0]], #16 \n" \ \ @@ -1017,7 +1017,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ \ - "cmhs v20.4s, v14.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fcmge v20.4s, v14.4s, v1.4s \n" /* vcgeq_f32 */ \ "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \ "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ \ @@ -1028,7 +1028,7 @@ void conv_depthwise_3x3s1_fp32(const float *din, \ "bif v14.16b, v24.16b, v18.16b \n" \ \ - "cmhs v20.4s, v15.4s, v1.4s \n" /* vcgeq_u32 */ \ + "fcmge v20.4s, v15.4s, v1.4s \n" /* vcgeq_f32 */ \ "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \ \ "st1 {v14.4s}, [%[doutr2]], #16 \n" \ @@ -1128,18 +1128,18 @@ void conv_depthwise_3x3s1_fp32(const float *din, "st1 {v12.4s}, [%[out1]]\n" \ "st1 {v13.4s}, [%[out2]]\n" -#define RESULT_S_S1_LEAKY_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "cmhs v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "cmhs v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ - \ - "bif v12.16b, v20.16b, v18.16b \n" \ - "bif v13.16b, v21.16b, v19.16b \n" \ - "st1 {v12.4s}, [%[out1]]\n" \ +#define RESULT_S_S1_LEAKY_RELU \ + "prfm pldl1keep, [%[out1]]\n" \ + "prfm pldl1keep, [%[out2]]\n" \ + \ + "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ + "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ + \ + "bif v12.16b, v20.16b, v18.16b \n" \ + "bif v13.16b, v21.16b, v19.16b \n" \ + "st1 {v12.4s}, [%[out1]]\n" \ "st1 {v13.4s}, [%[out2]]\n" #define COMPUTE_S_S1_P0 \ "prfm pldl1keep, [%[din0]]\n" \ diff --git a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc index 7ca7536beb890ec419341776b9098340883753a5..55ea94949ba93396c97be5e3ea66d6e29ce95429 100644 --- a/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc @@ -179,11 +179,11 @@ namespace math { #define LEAKY_RELU \ "movi v0.4s, #0\n" /* for relu */ \ "ldr x0, [%[outl], #88]\n" \ - "cmhs v1.4s, v15.4s, v0.4s \n" /* vcgeq_u32 */ \ - "cmhs v2.4s, v16.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v1.4s, v15.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v2.4s, v16.4s, v0.4s \n" /* vcgeq_f32 */ \ "ld1 {v9.4s}, [x0] \n" \ - "cmhs v3.4s, v17.4s, v0.4s \n" /* vcgeq_u32 */ \ - "cmhs v4.4s, v18.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v3.4s, v17.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v4.4s, v18.4s, v0.4s \n" /* vcgeq_f32 */ \ "ldr x0, [%[outl]] \n" \ "fmul v5.4s, v15.4s, v9.4s \n" /* mul */ \ "fmul v6.4s, v16.4s, v9.4s \n" /* mul */ \ @@ -193,10 +193,10 @@ namespace math { "bif v16.16b, v6.16b, v2.16b \n" /* choose*/ \ "bif v17.16b, v7.16b, v3.16b \n" /* choose*/ \ "bif v18.16b, v8.16b, v4.16b \n" /* choose*/ \ - "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ - "cmhs v2.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ - "cmhs v3.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ - "cmhs v4.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v2.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v3.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fcmge v4.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \ "fmul v5.4s, v19.4s, v9.4s \n" /* mul */ \ "fmul v6.4s, v20.4s, v9.4s \n" /* mul */ \ "fmul v7.4s, v21.4s, v9.4s \n" /* mul */ \ diff --git a/lite/backends/arm/math/conv3x3s2_direct_int8.cc b/lite/backends/arm/math/conv3x3s2_direct_int8.cc index 26829544bfd34d7acfc1d49086e86c3e0edad5f1..3d6f3dd743c3e46b6123f2c93dbfed586ad7b4c6 100644 --- a/lite/backends/arm/math/conv3x3s2_direct_int8.cc +++ b/lite/backends/arm/math/conv3x3s2_direct_int8.cc @@ -50,7 +50,7 @@ void conv_3x3s2_direct_int8(const int8_t* din, bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; int pad_h = paddings[0]; - int pad_w = paddings[1]; + int pad_w = paddings[2]; const int threads = ctx->threads(); int llc_size = ctx->llc_size() / 4; @@ -477,7 +477,7 @@ void conv_3x3s2_direct_int8(const int8_t* din, bool flag_relu = param.fuse_relu; bool flag_bias = param.bias; int pad_h = paddings[0]; - int pad_w = paddings[1]; + int pad_w = paddings[2]; const int threads = ctx->threads(); //! set 1/4 l2 cache int llc_size = ctx->llc_size() / 4; diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index 3823c556f2c72096abb3e9502b26dc07a87c4523..3e5569365119b97397c6d42f48bacd2552b248e5 100644 --- a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -451,44 +451,44 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "blt 1f \n" -#define LEFT_RESULT_S2_LEAKY_RELU \ - "ld1 {v22.4s}, [%[scale_ptr]] \n" \ - "cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - \ - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ - \ - "fmul v12.4s, v16.4s, v22.4s \n" \ - "fadd v17.4s, v17.4s, v13.4s \n" \ - \ - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ - "ld1 {v15.4s}, [%[inptr0]] \n" \ - \ - "fadd v17.4s, v17.4s, v14.4s \n" \ - "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ - \ - "ld1 {v18.4s}, [%[inptr1]] \n" \ - "ld1 {v19.4s}, [%[inptr2]] \n" \ - \ - "ext v10.16b, v0.16b, v15.16b, #4 \n" \ - \ - "st1 {v16.4s}, [%[outptr0]], #16 \n" \ - "cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v12.4s, v16.4s, v22.4s \n" \ - \ - "ld1 {v20.4s}, [%[inptr3]] \n" \ - "ld1 {v21.4s}, [%[inptr4]] \n" \ - \ - "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ - "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ - \ - "cmp %w[cnt], #1 \n" \ - \ - "st1 {v17.4s}, [%[outptr1]], #16 \n" \ - "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ - \ +#define LEFT_RESULT_S2_LEAKY_RELU \ + "ld1 {v22.4s}, [%[scale_ptr]] \n" \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ "blt 1f \n" #define MID_RESULT_S2_RELU \ @@ -542,30 +542,30 @@ void conv_depthwise_3x3s2_fp32(const float* din, \ "bne 2b \n" -#define MID_RESULT_S2_LEAKY_RELU \ - "cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v12.4s, v16.4s, v22.4s \n" \ - \ - "fadd v17.4s, v17.4s, v13.4s \n" \ - \ - "ld1 {v19.4s}, [%[inptr2]] \n" \ - "ld1 {v20.4s}, [%[inptr3]] \n" \ - "ld1 {v21.4s}, [%[inptr4]] \n" \ - \ - "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ - "ext v10.16b, v0.16b, v15.16b, #4 \n" \ - "cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v12.4s, v17.4s, v22.4s \n" \ - \ - "st1 {v16.4s}, [%[outptr0]], #16 \n" \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ - "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ - "st1 {v17.4s}, [%[outptr1]], #16 \n" \ - \ - "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ - \ +#define MID_RESULT_S2_LEAKY_RELU \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ "bne 2b \n" #define RIGHT_RESULT_S2_RELU \ @@ -606,25 +606,25 @@ void conv_depthwise_3x3s2_fp32(const float* din, "st1 {v17.4s}, [%[outptr1]], #16 \n" \ "4: \n" -#define RIGHT_RESULT_S2_LEAKY_RELU \ - "cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v12.4s, v16.4s, v22.4s \n" \ - "fadd v17.4s, v17.4s, v13.4s \n" \ - \ - "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ - \ - "fadd v17.4s, v17.4s, v14.4s \n" \ - \ - "bif v16.16b, v0.16b, %[wmask].16b \n" \ - \ - "cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v12.4s, v17.4s, v22.4s \n" \ - \ - "st1 {v16.4s}, [%[outptr0]], #16 \n" \ - "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ - "bif v17.16b, v1.16b, %[wmask].16b \n" \ - \ - "st1 {v17.4s}, [%[outptr1]], #16 \n" \ +#define RIGHT_RESULT_S2_LEAKY_RELU \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ "4: \n" #define COMPUTE_S_S2 \ diff --git a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc index 8ecc21134017d6469071eb2adc4b2215877c8437..4617d40f4372f6589f20b50205fb307cdc705808 100644 --- a/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc @@ -104,13 +104,13 @@ namespace math { "fmin v22.4s, v22.4s, %[vsix].4s\n" #define LEAKY_RELU /* LeakyRelu */ \ "movi v0.4s, #0\n" /* for relu */ \ - "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ - "cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ - "cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ - "cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc index 1a2e42e0a9ca4193be84a21247112de8cdc144a1..daf3957bb1fe92cf9d979439407732bba3b0d9a4 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_fp32.cc @@ -13,9602 +13,750 @@ // limitations under the License. #include +#include "lite/backends/arm/math/conv_block_utils.h" #include "lite/backends/arm/math/conv_depthwise.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif namespace paddle { namespace lite { namespace arm { namespace math { -//! weights layout -//! *-----------------------*-----* -//! w0 <-- | W0 W1 W2 W3 | W4 | -//! *-----------------------* | -//! w1 <-- | W5 W6 W7 W8 | W9 | -//! *-----------------------* | --> w5 -//! w2 <-- | W10 W11 W12 W13 | W14 | -//! *-----------------------* | -//! w3 <-- | W15 W16 W17 W18 | W19 | -//! *-----------------------*-----* -//! w4 <-- | W20 W21 W22 W23 | W24 | --> w6[0] -//! *-----------------------*-----* - -void conv_depthwise_5x5s1_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1_small_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s1_small_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -static float* prepad_input( - const float* input, int num, int ch_in, int h_in, int w_in, int pad) { - int h_new = h_in + 2 * pad; - int w_new = w_in + 2 * pad; - float* new_input = - static_cast(malloc(h_new * w_new * ch_in * num * sizeof(float))); - float* new_input_ptr = new_input; - for (int c = 0; c < num * ch_in; ++c) { - memset(new_input_ptr, 0x00, w_new * pad * sizeof(float)); - new_input_ptr += w_new * pad; - for (int i = 0; i < h_in; ++i) { - memset(new_input_ptr, 0x00, pad * sizeof(float)); - new_input_ptr += pad; - memcpy(new_input_ptr, input, w_in * sizeof(float)); - new_input_ptr += w_in; - input += w_in; - memset(new_input_ptr, 0x00, pad * sizeof(float)); - new_input_ptr += pad; - } - memset(new_input_ptr, 0x00, w_new * pad * sizeof(float)); - new_input_ptr += w_new * pad; - } - return new_input; -} - -#ifdef __aarch64__ - -//! kernel for one out without extracting data mid -//! deal with four lines out -void compute_one_out_without_extract(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - float32x4_t w5, - float32x4_t w6, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! din0 - din7: 5 v20, v21 - //! dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // ext - "ext v22.16b, v20.16b, v21.16b, #4 \n" // 1 2 3 4 - "ext v23.16b, v20.16b, v21.16b, #8 \n" // 2 3 4 5 - "ext v24.16b, v20.16b, v21.16b, #12 \n" // 3 4 5 6 - - // in col5 - "fmla v16.4s, %[w5].4s, v20.4s \n" - "fmla v17.4s, %[w5].4s, v22.4s \n" - "fmla v18.4s, %[w5].4s, v23.4s \n" - "fmla v19.4s, %[w5].4s, v24.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // in[24] * w6[0] - "fmla v25.4s, v21.4s, %[w6].s[0]\n" - "fadd v25.4s, v25.4s, v31.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v31"); -} - -//! kernel for one out without extracting data mid -//! deal with four lines out -void compute_one_out_without_extract_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - float32x4_t w5, - float32x4_t w6, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! din0 - din7: 5 v20, v21 - //! dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // ext - "ext v22.16b, v20.16b, v21.16b, #4 \n" // 1 2 3 4 - "ext v23.16b, v20.16b, v21.16b, #8 \n" // 2 3 4 5 - "ext v24.16b, v20.16b, v21.16b, #12 \n" // 3 4 5 6 - - // in col5 - "fmla v16.4s, %[w5].4s, v20.4s \n" - "fmla v17.4s, %[w5].4s, v22.4s \n" - "fmla v18.4s, %[w5].4s, v23.4s \n" - "fmla v19.4s, %[w5].4s, v24.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "movi v30.4s, #0 \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // in[24] * w6[0] - "fmla v25.4s, v21.4s, %[w6].s[0] \n" - "fadd v25.4s, v25.4s, v31.4s \n" - "fmax v25.4s, v25.4s, v30.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v30", - "v31"); -} - -//! kernel for one out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_one_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" - "ldr q1, [%[wh]], #20 \n" - "ldr q2, [%[wh]], #20 \n" - "ldr q3, [%[wh]], #20 \n" - "ldr q4, [%[wh]], #20 \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v31"); -} - -//! kernel for one out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_one_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" - "ldr q1, [%[wh]], #20 \n" - "ldr q2, [%[wh]], #20 \n" - "ldr q3, [%[wh]], #20 \n" - "ldr q4, [%[wh]], #20 \n" - - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "movi v30.4s, #0 \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - "fmax v25.4s, v25.4s, v30.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v30", - "v31"); -} - -//! kernel for one out with extracting data post -//! deal with four lines out -void compute_one_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v31.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v31"); -} - -//! kernel for one out with extracting data post -//! deal with four lines out -void compute_one_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - asm volatile( - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - "ld1 {v31.4s}, [%[bias]] \n" - "movi v30.4s, #0 \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v31.4s \n" - "fmax v25.4s, v25.4s, v30.4s \n" - - // write output - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v25", - "v26", - "v30", - "v31"); -} - -//! kernel for two out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_two_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for two out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_two_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for two out with extracting data post -//! deal with four lines out -void compute_two_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - asm volatile( - "movi v31.4s, #0 \n" - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for two out with extracting data post -//! deal with four lines out -void compute_two_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - asm volatile( - "movi v31.4s, #0 \n" - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v8.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v8.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [dout0] "r"(dout0), - [dout1] "r"(dout1), - [dout2] "r"(dout2), - [dout3] "r"(dout3), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v31"); -} - -//! kernel for three out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_three_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "st1 {v25.s}[0], [%[dout0]], #4 \n" - "st1 {v25.s}[1], [%[dout1]], #4 \n" - "st1 {v25.s}[2], [%[dout2]], #4 \n" - "st1 {v25.s}[3], [%[dout3]], #4 \n" - - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for three out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_three_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v16-v19 - //! weights: v0-v4 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]], #20 \n" // 21, 22, 23, 24 - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - "fmax v25.4s, v25.4s, v31.4s \n" - - // zip - "zip1 v6.4s, v7.4s, v5.4s \n" - "zip2 v8.4s, v7.4s, v5.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "st1 {v25.s}[0], [%[dout0]], #4 \n" - "st1 {v25.s}[1], [%[dout1]], #4 \n" - "st1 {v25.s}[2], [%[dout2]], #4 \n" - "st1 {v25.s}[3], [%[dout3]], #4 \n" - - "str d6, [%[dout0]] \n" - "str d7, [%[dout1]] \n" - "str d8, [%[dout2]] \n" - "str d9, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for three out with extracting data post -//! deal with four lines out -void compute_three_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v6, v8, v25 - asm volatile( - "movi v31.4s, #0 \n" - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]], #8 \n" - "str d7, [%[dout1]], #8 \n" - "str d8, [%[dout2]], #8 \n" - "str d9, [%[dout3]], #8 \n" - - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for three out with extracting data post -//! deal with four lines out -void compute_three_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v6, v8, v25 - asm volatile( - "movi v31.4s, #0 \n" - - // load inputs - "ld1 {v20.4s}, [%[bias]] \n" - "ld1 {v8.4s}, [%[din0]], #16 \n" - "ld1 {v9.4s}, [%[din1]], #16 \n" - "ld1 {v10.4s}, [%[din2]], #16 \n" - "ld1 {v11.4s}, [%[din3]], #16 \n" - "ld1 {v12.4s}, [%[din4]], #16 \n" - "ld1 {v13.4s}, [%[din5]], #16 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], #16 \n" - "ld1 {v15.4s}, [%[din7]], #16 \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v5 - "faddp v5.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v5.4s, v5.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v7 - "faddp v7.4s, v16.4s, v17.4s \n" - "faddp v6.4s, v18.4s, v19.4s \n" - "faddp v7.4s, v7.4s, v6.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - "fadd v25.4s, v25.4s, v20.4s \n" - "fmax v25.4s, v25.4s, v31.4s \n" - - // zip - "zip1 v6.4s, v5.4s, v7.4s \n" - "zip2 v8.4s, v5.4s, v7.4s \n" - - // add bias - "fadd v6.4s, v6.4s, v20.4s \n" - "fadd v8.4s, v8.4s, v20.4s \n" - - // relu - "fmax v6.4s, v6.4s, v31.4s \n" - "fmax v8.4s, v8.4s, v31.4s \n" - - "ext v7.16b, v6.16b, v31.16b, #8 \n" - "ext v9.16b, v8.16b, v31.16b, #8 \n" - - // write output - "str d6, [%[dout0]], #8 \n" - "str d7, [%[dout1]], #8 \n" - "str d8, [%[dout2]], #8 \n" - "str d9, [%[dout3]], #8 \n" - - "st1 {v25.s}[0], [%[dout0]] \n" - "st1 {v25.s}[1], [%[dout1]] \n" - "st1 {v25.s}[2], [%[dout2]] \n" - "st1 {v25.s}[3], [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v25", - "v26", - "v31"); -} - -//! kernel for four out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_four_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v0-v3 - //! weights: v0-v4, v5, v6 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "mov x0, #20 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]] \n" // 21, 22, 23, 24 - "sub %[wh], %[wh], #68 \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]] \n" - "ld1 {v9.4s}, [%[din1]] \n" - "ld1 {v10.4s}, [%[din2]] \n" - "ld1 {v11.4s}, [%[din3]] \n" - "ld1 {v12.4s}, [%[din4]] \n" - "ld1 {v13.4s}, [%[din5]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]] \n" - "ld1 {v15.4s}, [%[din7]] \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load weights col5 - "ld1 {v5.s}[0], [%[wh]], x0 \n" - "ld1 {v5.s}[1], [%[wh]], x0 \n" - "ld1 {v5.s}[2], [%[wh]], x0 \n" - "ld1 {v5.s}[3], [%[wh]], x0 \n" - "ld1 {v6.s}[0], [%[wh]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // load in col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v22.16b, v20.16b, v21.16b, #4 \n" - "ext v23.16b, v20.16b, v21.16b, #8 \n" - "ext v24.16b, v20.16b, v21.16b, #12 \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v22.4s \n" - "fmul v18.4s, v5.4s, v23.4s \n" - "fmul v19.4s, v5.4s, v24.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, v6.s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v28.4s, v26.4s \n" - "zip2 v2.4s, v28.4s, v26.4s \n" - "zip1 v4.4s, v27.4s, v25.4s \n" - "zip2 v6.4s, v27.4s, v25.4s \n" - - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [%[dout0]], #8 \n" - "str d1, [%[dout1]], #8 \n" - "str d2, [%[dout2]], #8 \n" - "str d3, [%[dout3]], #8 \n" - - "str d4, [%[dout0]] \n" - "str d5, [%[dout1]] \n" - "str d6, [%[dout2]] \n" - "str d7, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "x0", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -//! kernel for four out with extracting data pre -//! deal with four lines out -//! need extra load weights -void compute_four_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - const float* weights, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v0-v3 - //! weights: v0-v4, v5, v6 - asm volatile( - // load weights - "movi v31.4s, #0 \n" - "mov x0, #20 \n" - "add %[wh], %[wh], #4 \n" - "ldr q0, [%[wh]], #20 \n" // 1, 2, 3, 4 - "ldr q1, [%[wh]], #20 \n" // 6, 7, 8, 9 - "ldr q2, [%[wh]], #20 \n" // 11, 12, 13, 14 - "ldr q3, [%[wh]], #20 \n" // 16, 17, 18, 19 - "ldr q4, [%[wh]] \n" // 21, 22, 23, 24 - "sub %[wh], %[wh], #68 \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]] \n" - "ld1 {v9.4s}, [%[din1]] \n" - "ld1 {v10.4s}, [%[din2]] \n" - "ld1 {v11.4s}, [%[din3]] \n" - "ld1 {v12.4s}, [%[din4]] \n" - "ld1 {v13.4s}, [%[din5]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]] \n" - "ld1 {v15.4s}, [%[din7]] \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load weights col5 - "ld1 {v5.s}[0], [%[wh]], x0 \n" - "ld1 {v5.s}[1], [%[wh]], x0 \n" - "ld1 {v5.s}[2], [%[wh]], x0 \n" - "ld1 {v5.s}[3], [%[wh]], x0 \n" - "ld1 {v6.s}[0], [%[wh]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 2, 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 7, 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 12, 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 17, 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 22, 23, 24 - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // load in col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext weights - "ext v0.16b, v0.16b, v31.16b, #4 \n" // 3, 4 - "ext v1.16b, v1.16b, v31.16b, #4 \n" // 8, 9 - "ext v2.16b, v2.16b, v31.16b, #4 \n" // 13, 14 - "ext v3.16b, v3.16b, v31.16b, #4 \n" // 18, 19 - "ext v4.16b, v4.16b, v31.16b, #4 \n" // 23, 24 - - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, v0.4s, v8.4s \n" - "fmul v17.4s, v0.4s, v9.4s \n" - "fmul v18.4s, v0.4s, v10.4s \n" - "fmul v19.4s, v0.4s, v11.4s \n" - - // in row1 - "fmla v16.4s, v1.4s, v9.4s \n" - "fmla v17.4s, v1.4s, v10.4s \n" - "fmla v18.4s, v1.4s, v11.4s \n" - "fmla v19.4s, v1.4s, v12.4s \n" - - // in row2 - "fmla v16.4s, v2.4s, v10.4s \n" - "fmla v17.4s, v2.4s, v11.4s \n" - "fmla v18.4s, v2.4s, v12.4s \n" - "fmla v19.4s, v2.4s, v13.4s \n" - - // in row3 - "fmla v16.4s, v3.4s, v11.4s \n" - "fmla v17.4s, v3.4s, v12.4s \n" - "fmla v18.4s, v3.4s, v13.4s \n" - "fmla v19.4s, v3.4s, v14.4s \n" - - // in row4 - "fmla v16.4s, v4.4s, v12.4s \n" - "fmla v17.4s, v4.4s, v13.4s \n" - "fmla v18.4s, v4.4s, v14.4s \n" - "fmla v19.4s, v4.4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v22.16b, v20.16b, v21.16b, #4 \n" - "ext v23.16b, v20.16b, v21.16b, #8 \n" - "ext v24.16b, v20.16b, v21.16b, #12 \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v22.4s \n" - "fmul v18.4s, v5.4s, v23.4s \n" - "fmul v19.4s, v5.4s, v24.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, v6.s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v28.4s, v26.4s \n" - "zip2 v2.4s, v28.4s, v26.4s \n" - "zip1 v4.4s, v27.4s, v25.4s \n" - "zip2 v6.4s, v27.4s, v25.4s \n" - - // add bias - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - // relu - "fmax v0.4s, v0.4s, v31.4s \n" - "fmax v2.4s, v2.4s, v31.4s \n" - "fmax v4.4s, v4.4s, v31.4s \n" - "fmax v6.4s, v6.4s, v31.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [%[dout0]], #8 \n" - "str d1, [%[dout1]], #8 \n" - "str d2, [%[dout2]], #8 \n" - "str d3, [%[dout3]], #8 \n" - - "str d4, [%[dout0]] \n" - "str d5, [%[dout1]] \n" - "str d6, [%[dout2]] \n" - "str d7, [%[dout3]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [dout2] "+r"(dout2), - [dout3] "+r"(dout3), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "x0", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -//! kernel for four out with extracting data post -//! deal with four lines out -void compute_four_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v0-v3 - const int64_t s_12 = 12; - const float* doutl[4] = {dout0, dout1, dout2, dout3}; - void* doutl_ptr = reinterpret_cast(doutl); - asm volatile( - "movi v31.4s, #0 \n" - "ldp x0, x1, [%[doutl]], #16 \n" - "ldp x2, x3, [%[doutl]] \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]], %[s_12] \n" - "ld1 {v9.4s}, [%[din1]], %[s_12] \n" - "ld1 {v10.4s}, [%[din2]], %[s_12] \n" - "ld1 {v11.4s}, [%[din3]], %[s_12] \n" - "ld1 {v12.4s}, [%[din4]], %[s_12] \n" - "ld1 {v13.4s}, [%[din5]], %[s_12] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], %[s_12] \n" - "ld1 {v15.4s}, [%[din7]], %[s_12] \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load input col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // load input col5 - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v8.16b, v20.16b, v21.16b, #4 \n" - "ext v9.16b, v20.16b, v21.16b, #8 \n" - "ext v10.16b, v20.16b, v21.16b, #12 \n" - - // ext weights col0 - "ins v5.s[0], %[w0].s[0] \n" - "ins v5.s[1], %[w1].s[0] \n" - "ins v5.s[2], %[w2].s[0] \n" - "ins v5.s[3], %[w3].s[0] \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v8.4s \n" - "fmul v18.4s, v5.4s, v9.4s \n" - "fmul v19.4s, v5.4s, v10.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, %[w4].s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v25.4s, v27.4s \n" - "zip2 v2.4s, v25.4s, v27.4s \n" - "zip1 v4.4s, v26.4s, v28.4s \n" - "zip2 v6.4s, v26.4s, v28.4s \n" - - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [x0], #8 \n" - "str d1, [x1], #8 \n" - "str d2, [x2], #8 \n" - "str d3, [x3], #8 \n" - - "str d4, [x0] \n" - "str d5, [x1] \n" - "str d6, [x2] \n" - "str d7, [x3] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [doutl] "+r"(doutl_ptr) - : [s_12] "r"(s_12), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "x0", - "x1", - "x2", - "x3", - "v0", - "v1", - "v2", - "v3", - "v5", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -//! kernel for four out with extracting data post -//! deal with four lines out -void compute_four_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - const float* din6, - const float* din7, - float* dout0, - float* dout1, - float* dout2, - float* dout3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - const float* bias) { - //! din0 - din7: 0-4 v8-v15 - //! dout0 - dout3: v0-v3 - const int64_t s_12 = 12; - const float* doutl[4] = {dout0, dout1, dout2, dout3}; - void* doutl_ptr = reinterpret_cast(doutl); - asm volatile( - "movi v31.4s, #0 \n" - "ldp x0, x1, [%[doutl]], #16 \n" - "ldp x2, x3, [%[doutl]] \n" - - // load inputs - "ld1 {v8.4s}, [%[din0]], %[s_12] \n" - "ld1 {v9.4s}, [%[din1]], %[s_12] \n" - "ld1 {v10.4s}, [%[din2]], %[s_12] \n" - "ld1 {v11.4s}, [%[din3]], %[s_12] \n" - "ld1 {v12.4s}, [%[din4]], %[s_12] \n" - "ld1 {v13.4s}, [%[din5]], %[s_12] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - "ld1 {v14.4s}, [%[din6]], %[s_12] \n" - "ld1 {v15.4s}, [%[din7]], %[s_12] \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v25 - "faddp v25.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v25.4s, v25.4s, v26.4s \n" - - // load input col5 - "ld1 {v20.s}[0], [%[din0]] \n" - "ld1 {v20.s}[1], [%[din1]] \n" - "ld1 {v20.s}[2], [%[din2]] \n" - "ld1 {v20.s}[3], [%[din3]] \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // load input col5 - "ld1 {v21.s}[0], [%[din4]] \n" - "ld1 {v21.s}[1], [%[din5]] \n" - "ld1 {v21.s}[2], [%[din6]] \n" - "ld1 {v21.s}[3], [%[din7]] \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v27 - "faddp v27.4s, v16.4s, v17.4s \n" - "faddp v26.4s, v18.4s, v19.4s \n" - "faddp v27.4s, v27.4s, v26.4s \n" - - // ext input - "ext v8.16b, v8.16b, v31.16b, #4 \n" - "ext v9.16b, v9.16b, v31.16b, #4 \n" - "ext v10.16b, v10.16b, v31.16b, #4 \n" - "ext v11.16b, v11.16b, v31.16b, #4 \n" - "ext v12.16b, v12.16b, v31.16b, #4 \n" - "ext v13.16b, v13.16b, v31.16b, #4 \n" - "ext v14.16b, v14.16b, v31.16b, #4 \n" - "ext v15.16b, v15.16b, v31.16b, #4 \n" - - // in row0 - "fmul v16.4s, %[w0].4s, v8.4s \n" - "fmul v17.4s, %[w0].4s, v9.4s \n" - "fmul v18.4s, %[w0].4s, v10.4s \n" - "fmul v19.4s, %[w0].4s, v11.4s \n" - - // in row1 - "fmla v16.4s, %[w1].4s, v9.4s \n" - "fmla v17.4s, %[w1].4s, v10.4s \n" - "fmla v18.4s, %[w1].4s, v11.4s \n" - "fmla v19.4s, %[w1].4s, v12.4s \n" - - // in row2 - "fmla v16.4s, %[w2].4s, v10.4s \n" - "fmla v17.4s, %[w2].4s, v11.4s \n" - "fmla v18.4s, %[w2].4s, v12.4s \n" - "fmla v19.4s, %[w2].4s, v13.4s \n" - - // in row3 - "fmla v16.4s, %[w3].4s, v11.4s \n" - "fmla v17.4s, %[w3].4s, v12.4s \n" - "fmla v18.4s, %[w3].4s, v13.4s \n" - "fmla v19.4s, %[w3].4s, v14.4s \n" - - // in row4 - "fmla v16.4s, %[w4].4s, v12.4s \n" - "fmla v17.4s, %[w4].4s, v13.4s \n" - "fmla v18.4s, %[w4].4s, v14.4s \n" - "fmla v19.4s, %[w4].4s, v15.4s \n" - - // add to out register v26 - "faddp v26.4s, v16.4s, v17.4s \n" - "faddp v28.4s, v18.4s, v19.4s \n" - "faddp v26.4s, v26.4s, v28.4s \n" - - // ext input col5 - "ext v8.16b, v20.16b, v21.16b, #4 \n" - "ext v9.16b, v20.16b, v21.16b, #8 \n" - "ext v10.16b, v20.16b, v21.16b, #12 \n" - - // ext weights col0 - "ins v5.s[0], %[w0].s[0] \n" - "ins v5.s[1], %[w1].s[0] \n" - "ins v5.s[2], %[w2].s[0] \n" - "ins v5.s[3], %[w3].s[0] \n" - - // in col5 - "fmul v16.4s, v5.4s, v20.4s \n" - "fmul v17.4s, v5.4s, v8.4s \n" - "fmul v18.4s, v5.4s, v9.4s \n" - "fmul v19.4s, v5.4s, v10.4s \n" - - // add to out register v28 - "faddp v28.4s, v16.4s, v17.4s \n" - "faddp v29.4s, v18.4s, v19.4s \n" - "faddp v28.4s, v28.4s, v29.4s \n" - "fmla v28.4s, v21.4s, %[w4].s[0] \n" - - "ld1 {v8.4s}, [%[bias]] \n" - - // zip - "zip1 v0.4s, v25.4s, v27.4s \n" - "zip2 v2.4s, v25.4s, v27.4s \n" - "zip1 v4.4s, v26.4s, v28.4s \n" - "zip2 v6.4s, v26.4s, v28.4s \n" - - // add bias - "fadd v0.4s, v0.4s, v8.4s \n" - "fadd v2.4s, v2.4s, v8.4s \n" - "fadd v4.4s, v4.4s, v8.4s \n" - "fadd v6.4s, v6.4s, v8.4s \n" - - // relu - "fmax v0.4s, v0.4s, v31.4s \n" - "fmax v2.4s, v2.4s, v31.4s \n" - "fmax v4.4s, v4.4s, v31.4s \n" - "fmax v6.4s, v6.4s, v31.4s \n" - - "ext v1.16b, v0.16b, v31.16b, #8 \n" - "ext v3.16b, v2.16b, v31.16b, #8 \n" - "ext v5.16b, v4.16b, v31.16b, #8 \n" - "ext v7.16b, v6.16b, v31.16b, #8 \n" - - // write output - "str d0, [x0], #8 \n" - "str d1, [x1], #8 \n" - "str d2, [x2], #8 \n" - "str d3, [x3], #8 \n" - - "str d4, [x0] \n" - "str d5, [x1] \n" - "str d6, [x2] \n" - "str d7, [x3] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [din6] "+r"(din6), - [din7] "+r"(din7), - [doutl] "+r"(doutl_ptr) - : [s_12] "r"(s_12), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [bias] "r"(bias) - : "memory", - "x0", - "x1", - "x2", - "x3", - "v0", - "v1", - "v2", - "v3", - "v5", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v25", - "v26", - "v27", - "v28", - "v29", - "v31"); -} - -void conv_depthwise_5x5s1_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[8]; - const float* dinl[8]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 8; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 8 - pad_new > h_in) { - switch (h + 8 - pad_new - h_in) { - case 7: - din_list[1] = zero_ptr; - case 6: - din_list[2] = zero_ptr; - case 5: - din_list[3] = zero_ptr; - case 4: - din_list[4] = zero_ptr; - case 3: - din_list[5] = zero_ptr; - case 2: - din_list[6] = zero_ptr; - case 1: - din_list[7] = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - - //! every h loop, deal with 8 line input - dinl[0] = din_list[0]; - dinl[1] = din_list[1]; - dinl[2] = din_list[2]; - dinl[3] = din_list[3]; - dinl[4] = din_list[4]; - dinl[5] = din_list[5]; - dinl[6] = din_list[6]; - dinl[7] = din_list[7]; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - *dout_ptr2++ = bias_c; - *dout_ptr3++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_pre(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - //! mid loop - if (mid_cnt > 0) { - void* dinl_ptr = reinterpret_cast(dinl); - int mid_loop = mid_cnt; - asm volatile( - //! din: v7-v14 - //! dout: v15-v18 - "mov x0, #0 \n" - "mov x1, #4 \n" - "ldp x2, x3, [%[dinl]], #16 \n" - "ldp x4, x5, [%[dinl]], #16 \n" - "ldp x6, x7, [%[dinl]], #16 \n" - "ldp x8, x9, [%[dinl]], #16 \n" - - "ld1 {v7.4s} , [x2], x1 \n" - "ld1 {v8.4s} , [x3], x1 \n" - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - //! load bias - "ld1 {v19.4s}, [%[bias]] \n" - - "1: \n" - //! add bias to output - "mov v15.16b, v19.16b \n" - "mov v16.16b, v19.16b \n" - "mov v17.16b, v19.16b \n" - "mov v18.16b, v19.16b \n" - - //! loop cnt is even, prefetch 64 Byte to l1 cache - "cmp x0, #1 \n" - "bne 2f \n" - "mov x0, #0 \n" - "prfm pldl1keep, [x2] \n" - "prfm pldl1keep, [x3] \n" - "prfm pldl1keep, [x4] \n" - "prfm pldl1keep, [x5] \n" - "prfm pldl1keep, [x6] \n" - "prfm pldl1keep, [x7] \n" - "prfm pldl1keep, [x8] \n" - "prfm pldl1keep, [x9] \n" - - "2: \n" - // weights col 0 - "fmla v15.4s, v7.4s , %[w0].s[0] \n" - "fmla v16.4s, v8.4s , %[w0].s[0] \n" - "fmla v17.4s, v9.4s , %[w0].s[0] \n" - "fmla v18.4s, v10.4s, %[w0].s[0] \n" - - "fmla v15.4s, v8.4s , %[w1].s[0] \n" - "fmla v16.4s, v9.4s , %[w1].s[0] \n" - "fmla v17.4s, v10.4s, %[w1].s[0] \n" - "fmla v18.4s, v11.4s, %[w1].s[0] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[0] \n" - "fmla v16.4s, v10.4s, %[w2].s[0] \n" - "fmla v17.4s, v11.4s, %[w2].s[0] \n" - "fmla v18.4s, v12.4s, %[w2].s[0] \n" - - "fmla v15.4s, v10.4s, %[w3].s[0] \n" - "fmla v16.4s, v11.4s, %[w3].s[0] \n" - "fmla v17.4s, v12.4s, %[w3].s[0] \n" - "fmla v18.4s, v13.4s, %[w3].s[0] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[0] \n" - "fmla v16.4s, v12.4s, %[w4].s[0] \n" - "fmla v17.4s, v13.4s, %[w4].s[0] \n" - "fmla v18.4s, v14.4s, %[w4].s[0] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 1 - "fmla v15.4s, v7.4s , %[w0].s[1] \n" - "fmla v16.4s, v8.4s , %[w0].s[1] \n" - "fmla v17.4s, v9.4s , %[w0].s[1] \n" - "fmla v18.4s, v10.4s, %[w0].s[1] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[1] \n" - "fmla v16.4s, v9.4s , %[w1].s[1] \n" - "fmla v17.4s, v10.4s, %[w1].s[1] \n" - "fmla v18.4s, v11.4s, %[w1].s[1] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[1] \n" - "fmla v16.4s, v10.4s, %[w2].s[1] \n" - "fmla v17.4s, v11.4s, %[w2].s[1] \n" - "fmla v18.4s, v12.4s, %[w2].s[1] \n" - - "fmla v15.4s, v10.4s, %[w3].s[1] \n" - "fmla v16.4s, v11.4s, %[w3].s[1] \n" - "fmla v17.4s, v12.4s, %[w3].s[1] \n" - "fmla v18.4s, v13.4s, %[w3].s[1] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[1] \n" - "fmla v16.4s, v12.4s, %[w4].s[1] \n" - "fmla v17.4s, v13.4s, %[w4].s[1] \n" - "fmla v18.4s, v14.4s, %[w4].s[1] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 2 - "fmla v15.4s, v7.4s , %[w0].s[2] \n" - "fmla v16.4s, v8.4s , %[w0].s[2] \n" - "fmla v17.4s, v9.4s , %[w0].s[2] \n" - "fmla v18.4s, v10.4s, %[w0].s[2] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[2] \n" - "fmla v16.4s, v9.4s , %[w1].s[2] \n" - "fmla v17.4s, v10.4s, %[w1].s[2] \n" - "fmla v18.4s, v11.4s, %[w1].s[2] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[2] \n" - "fmla v16.4s, v10.4s, %[w2].s[2] \n" - "fmla v17.4s, v11.4s, %[w2].s[2] \n" - "fmla v18.4s, v12.4s, %[w2].s[2] \n" - - "fmla v15.4s, v10.4s, %[w3].s[2] \n" - "fmla v16.4s, v11.4s, %[w3].s[2] \n" - "fmla v17.4s, v12.4s, %[w3].s[2] \n" - "fmla v18.4s, v13.4s, %[w3].s[2] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[2] \n" - "fmla v16.4s, v12.4s, %[w4].s[2] \n" - "fmla v17.4s, v13.4s, %[w4].s[2] \n" - "fmla v18.4s, v14.4s, %[w4].s[2] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 3 - "fmla v15.4s, v7.4s , %[w0].s[3] \n" - "fmla v16.4s, v8.4s , %[w0].s[3] \n" - "fmla v17.4s, v9.4s , %[w0].s[3] \n" - "fmla v18.4s, v10.4s, %[w0].s[3] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[3] \n" - "fmla v16.4s, v9.4s , %[w1].s[3] \n" - "fmla v17.4s, v10.4s, %[w1].s[3] \n" - "fmla v18.4s, v11.4s, %[w1].s[3] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[3] \n" - "fmla v16.4s, v10.4s, %[w2].s[3] \n" - "fmla v17.4s, v11.4s, %[w2].s[3] \n" - "fmla v18.4s, v12.4s, %[w2].s[3] \n" - - "fmla v15.4s, v10.4s, %[w3].s[3] \n" - "fmla v16.4s, v11.4s, %[w3].s[3] \n" - "fmla v17.4s, v12.4s, %[w3].s[3] \n" - "fmla v18.4s, v13.4s, %[w3].s[3] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[3] \n" - "fmla v16.4s, v12.4s, %[w4].s[3] \n" - "fmla v17.4s, v13.4s, %[w4].s[3] \n" - "fmla v18.4s, v14.4s, %[w4].s[3] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 4 - "fmla v15.4s, v7.4s, %[w5].s[0] \n" - "fmla v16.4s, v8.4s, %[w5].s[0] \n" - "fmla v17.4s, v9.4s, %[w5].s[0] \n" - "fmla v18.4s, v10.4s, %[w5].s[0] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s, %[w5].s[1] \n" - "fmla v16.4s, v9.4s, %[w5].s[1] \n" - "fmla v17.4s, v10.4s, %[w5].s[1] \n" - "fmla v18.4s, v11.4s, %[w5].s[1] \n" - - "fmla v15.4s, v9.4s , %[w5].s[2] \n" - "fmla v16.4s, v10.4s, %[w5].s[2] \n" - "fmla v17.4s, v11.4s, %[w5].s[2] \n" - "fmla v18.4s, v12.4s, %[w5].s[2] \n" - - "fmla v15.4s, v10.4s, %[w5].s[3] \n" - "fmla v16.4s, v11.4s, %[w5].s[3] \n" - "fmla v17.4s, v12.4s, %[w5].s[3] \n" - "fmla v18.4s, v13.4s, %[w5].s[3] \n" - - "fmla v15.4s, v11.4s, %[w6].s[0] \n" - "fmla v16.4s, v12.4s, %[w6].s[0] \n" - "fmla v17.4s, v13.4s, %[w6].s[0] \n" - "fmla v18.4s, v14.4s, %[w6].s[0] \n" - - "st1 {v15.4s}, [%[dout0]], #16 \n" - "st1 {v16.4s}, [%[dout1]], #16 \n" - "st1 {v17.4s}, [%[dout2]], #16 \n" - "st1 {v18.4s}, [%[dout3]], #16 \n" - - "subs %w[cnt], %w[cnt], #1 \n" - "add x0, x0, #1 \n" - "bne 1b \n" - - : [dout0] "+r"(dout_ptr0), - [dout1] "+r"(dout_ptr1), - [dout2] "+r"(dout_ptr2), - [dout3] "+r"(dout_ptr3), - [cnt] "+r"(mid_loop), - [dinl] "+r"(dinl_ptr) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(vbias) - : "cc", - "memory", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7", - "x8", - "x9", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19"); - } - dinl[0] += 4 * mid_cnt; - dinl[1] += 4 * mid_cnt; - dinl[2] += 4 * mid_cnt; - dinl[3] += 4 * mid_cnt; - dinl[4] += 4 * mid_cnt; - dinl[5] += 4 * mid_cnt; - dinl[6] += 4 * mid_cnt; - dinl[7] += 4 * mid_cnt; - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - dinl[0]++; - dinl[1]++; - dinl[2]++; - dinl[3]++; - dinl[4]++; - dinl[5]++; - dinl[6]++; - dinl[7]++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - //! deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_post(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din_list[0] = din_list[4]; - din_list[1] = din_list[5]; - din_list[2] = din_list[6]; - din_list[3] = din_list[7]; - din_list[4] = din_list[3] + w_in; - din_list[5] = din_list[4] + w_in; - din_list[6] = din_list[5] + w_in; - din_list[7] = din_list[6] + w_in; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[8]; - const float* dinl[8]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 8; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 8 - pad_new > h_in) { - switch (h + 8 - pad_new - h_in) { - case 7: - din_list[1] = zero_ptr; - case 6: - din_list[2] = zero_ptr; - case 5: - din_list[3] = zero_ptr; - case 4: - din_list[4] = zero_ptr; - case 3: - din_list[5] = zero_ptr; - case 2: - din_list[6] = zero_ptr; - case 1: - din_list[7] = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - - //! every h loop, deal with 8 line input - dinl[0] = din_list[0]; - dinl[1] = din_list[1]; - dinl[2] = din_list[2]; - dinl[3] = din_list[3]; - dinl[4] = din_list[4]; - dinl[5] = din_list[5]; - dinl[6] = din_list[6]; - dinl[7] = din_list[7]; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - *dout_ptr2++ = bias_relu; - *dout_ptr3++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_pre_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - weights_ptr, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - //! mid loop - if (mid_cnt > 0) { - void* dinl_ptr = reinterpret_cast(dinl); - int mid_loop = mid_cnt; - asm volatile( - //! din: v7-v14 - //! dout: v15-v18 - "mov x0, #0 \n" - "mov x1, #4 \n" - "movi v31.4s, #0 \n" - "ldp x2, x3, [%[dinl]], #16 \n" - "ldp x4, x5, [%[dinl]], #16 \n" - "ldp x6, x7, [%[dinl]], #16 \n" - "ldp x8, x9, [%[dinl]], #16 \n" - - "ld1 {v7.4s} , [x2], x1 \n" - "ld1 {v8.4s} , [x3], x1 \n" - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - //! load bias - "ld1 {v19.4s}, [%[bias]] \n" - - "1: \n" - //! add bias to output - "mov v15.16b, v19.16b \n" - "mov v16.16b, v19.16b \n" - "mov v17.16b, v19.16b \n" - "mov v18.16b, v19.16b \n" - - //! loop cnt is even, prefetch 64 Byte to l1 cache - "cmp x0, #1 \n" - "bne 2f \n" - "mov x0, #0 \n" - "prfm pldl1keep, [x2] \n" - "prfm pldl1keep, [x3] \n" - "prfm pldl1keep, [x4] \n" - "prfm pldl1keep, [x5] \n" - "prfm pldl1keep, [x6] \n" - "prfm pldl1keep, [x7] \n" - "prfm pldl1keep, [x8] \n" - "prfm pldl1keep, [x9] \n" - - "2: \n" - // weights col 0 - "fmla v15.4s, v7.4s , %[w0].s[0] \n" - "fmla v16.4s, v8.4s , %[w0].s[0] \n" - "fmla v17.4s, v9.4s , %[w0].s[0] \n" - "fmla v18.4s, v10.4s, %[w0].s[0] \n" - - "fmla v15.4s, v8.4s , %[w1].s[0] \n" - "fmla v16.4s, v9.4s , %[w1].s[0] \n" - "fmla v17.4s, v10.4s, %[w1].s[0] \n" - "fmla v18.4s, v11.4s, %[w1].s[0] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[0] \n" - "fmla v16.4s, v10.4s, %[w2].s[0] \n" - "fmla v17.4s, v11.4s, %[w2].s[0] \n" - "fmla v18.4s, v12.4s, %[w2].s[0] \n" - - "fmla v15.4s, v10.4s, %[w3].s[0] \n" - "fmla v16.4s, v11.4s, %[w3].s[0] \n" - "fmla v17.4s, v12.4s, %[w3].s[0] \n" - "fmla v18.4s, v13.4s, %[w3].s[0] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[0] \n" - "fmla v16.4s, v12.4s, %[w4].s[0] \n" - "fmla v17.4s, v13.4s, %[w4].s[0] \n" - "fmla v18.4s, v14.4s, %[w4].s[0] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 1 - "fmla v15.4s, v7.4s , %[w0].s[1] \n" - "fmla v16.4s, v8.4s , %[w0].s[1] \n" - "fmla v17.4s, v9.4s , %[w0].s[1] \n" - "fmla v18.4s, v10.4s, %[w0].s[1] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[1] \n" - "fmla v16.4s, v9.4s , %[w1].s[1] \n" - "fmla v17.4s, v10.4s, %[w1].s[1] \n" - "fmla v18.4s, v11.4s, %[w1].s[1] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[1] \n" - "fmla v16.4s, v10.4s, %[w2].s[1] \n" - "fmla v17.4s, v11.4s, %[w2].s[1] \n" - "fmla v18.4s, v12.4s, %[w2].s[1] \n" - - "fmla v15.4s, v10.4s, %[w3].s[1] \n" - "fmla v16.4s, v11.4s, %[w3].s[1] \n" - "fmla v17.4s, v12.4s, %[w3].s[1] \n" - "fmla v18.4s, v13.4s, %[w3].s[1] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[1] \n" - "fmla v16.4s, v12.4s, %[w4].s[1] \n" - "fmla v17.4s, v13.4s, %[w4].s[1] \n" - "fmla v18.4s, v14.4s, %[w4].s[1] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 2 - "fmla v15.4s, v7.4s , %[w0].s[2] \n" - "fmla v16.4s, v8.4s , %[w0].s[2] \n" - "fmla v17.4s, v9.4s , %[w0].s[2] \n" - "fmla v18.4s, v10.4s, %[w0].s[2] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[2] \n" - "fmla v16.4s, v9.4s , %[w1].s[2] \n" - "fmla v17.4s, v10.4s, %[w1].s[2] \n" - "fmla v18.4s, v11.4s, %[w1].s[2] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[2] \n" - "fmla v16.4s, v10.4s, %[w2].s[2] \n" - "fmla v17.4s, v11.4s, %[w2].s[2] \n" - "fmla v18.4s, v12.4s, %[w2].s[2] \n" - - "fmla v15.4s, v10.4s, %[w3].s[2] \n" - "fmla v16.4s, v11.4s, %[w3].s[2] \n" - "fmla v17.4s, v12.4s, %[w3].s[2] \n" - "fmla v18.4s, v13.4s, %[w3].s[2] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[2] \n" - "fmla v16.4s, v12.4s, %[w4].s[2] \n" - "fmla v17.4s, v13.4s, %[w4].s[2] \n" - "fmla v18.4s, v14.4s, %[w4].s[2] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 3 - "fmla v15.4s, v7.4s , %[w0].s[3] \n" - "fmla v16.4s, v8.4s , %[w0].s[3] \n" - "fmla v17.4s, v9.4s , %[w0].s[3] \n" - "fmla v18.4s, v10.4s, %[w0].s[3] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s , %[w1].s[3] \n" - "fmla v16.4s, v9.4s , %[w1].s[3] \n" - "fmla v17.4s, v10.4s, %[w1].s[3] \n" - "fmla v18.4s, v11.4s, %[w1].s[3] \n" - - "ld1 {v7.4s}, [x2], x1 \n" - "ld1 {v8.4s}, [x3], x1 \n" - - "fmla v15.4s, v9.4s , %[w2].s[3] \n" - "fmla v16.4s, v10.4s, %[w2].s[3] \n" - "fmla v17.4s, v11.4s, %[w2].s[3] \n" - "fmla v18.4s, v12.4s, %[w2].s[3] \n" - - "fmla v15.4s, v10.4s, %[w3].s[3] \n" - "fmla v16.4s, v11.4s, %[w3].s[3] \n" - "fmla v17.4s, v12.4s, %[w3].s[3] \n" - "fmla v18.4s, v13.4s, %[w3].s[3] \n" - - "ld1 {v9.4s} , [x4], x1 \n" - "ld1 {v10.4s}, [x5], x1 \n" - - "fmla v15.4s, v11.4s, %[w4].s[3] \n" - "fmla v16.4s, v12.4s, %[w4].s[3] \n" - "fmla v17.4s, v13.4s, %[w4].s[3] \n" - "fmla v18.4s, v14.4s, %[w4].s[3] \n" - - "ld1 {v11.4s}, [x6], x1 \n" - "ld1 {v12.4s}, [x7], x1 \n" - - // weights col 4 - "fmla v15.4s, v7.4s, %[w5].s[0] \n" - "fmla v16.4s, v8.4s, %[w5].s[0] \n" - "fmla v17.4s, v9.4s, %[w5].s[0] \n" - "fmla v18.4s, v10.4s, %[w5].s[0] \n" - - "ld1 {v13.4s}, [x8], x1 \n" - "ld1 {v14.4s}, [x9], x1 \n" - - "fmla v15.4s, v8.4s, %[w5].s[1] \n" - "fmla v16.4s, v9.4s, %[w5].s[1] \n" - "fmla v17.4s, v10.4s, %[w5].s[1] \n" - "fmla v18.4s, v11.4s, %[w5].s[1] \n" - - "fmla v15.4s, v9.4s , %[w5].s[2] \n" - "fmla v16.4s, v10.4s, %[w5].s[2] \n" - "fmla v17.4s, v11.4s, %[w5].s[2] \n" - "fmla v18.4s, v12.4s, %[w5].s[2] \n" - - "fmla v15.4s, v10.4s, %[w5].s[3] \n" - "fmla v16.4s, v11.4s, %[w5].s[3] \n" - "fmla v17.4s, v12.4s, %[w5].s[3] \n" - "fmla v18.4s, v13.4s, %[w5].s[3] \n" - - "fmla v15.4s, v11.4s, %[w6].s[0] \n" - "fmla v16.4s, v12.4s, %[w6].s[0] \n" - "fmla v17.4s, v13.4s, %[w6].s[0] \n" - "fmla v18.4s, v14.4s, %[w6].s[0] \n" - - "fmax v15.4s, v15.4s, v31.4s \n" - "fmax v16.4s, v16.4s, v31.4s \n" - "fmax v17.4s, v17.4s, v31.4s \n" - "fmax v18.4s, v18.4s, v31.4s \n" - - "st1 {v15.4s}, [%[dout0]], #16 \n" - "st1 {v16.4s}, [%[dout1]], #16 \n" - "st1 {v17.4s}, [%[dout2]], #16 \n" - "st1 {v18.4s}, [%[dout3]], #16 \n" - - "subs %w[cnt], %w[cnt], #1 \n" - "add x0, x0, #1 \n" - "bne 1b \n" - - : [dout0] "+r"(dout_ptr0), - [dout1] "+r"(dout_ptr1), - [dout2] "+r"(dout_ptr2), - [dout3] "+r"(dout_ptr3), - [cnt] "+r"(mid_loop), - [dinl] "+r"(dinl_ptr) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [w6] "w"(w6), - [bias] "r"(vbias) - : "cc", - "memory", - "x0", - "x1", - "x2", - "x3", - "x4", - "x5", - "x6", - "x7", - "x8", - "x9", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v31"); - } - dinl[0] += 4 * mid_cnt; - dinl[1] += 4 * mid_cnt; - dinl[2] += 4 * mid_cnt; - dinl[3] += 4 * mid_cnt; - dinl[4] += 4 * mid_cnt; - dinl[5] += 4 * mid_cnt; - dinl[6] += 4 * mid_cnt; - dinl[7] += 4 * mid_cnt; - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - dinl[0]++; - dinl[1]++; - dinl[2]++; - dinl[3]++; - dinl[4]++; - dinl[5]++; - dinl[6]++; - dinl[7]++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - //! deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - break; - case 3: - compute_three_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - dout_ptr2 += 3; - dout_ptr3 += 3; - break; - case 2: - compute_two_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - dout_ptr2 += 2; - dout_ptr3 += 2; - break; - case 1: - compute_one_out_extract_post_relu(dinl[0], - dinl[1], - dinl[2], - dinl[3], - dinl[4], - dinl[5], - dinl[6], - dinl[7], - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - dout_ptr2 += 1; - dout_ptr3 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din_list[0] = din_list[4]; - din_list[1] = din_list[5]; - din_list[2] = din_list[6]; - din_list[3] = din_list[7]; - din_list[4] = din_list[3] + w_in; - din_list[5] = din_list[4] + w_in; - din_list[6] = din_list[5] + w_in; - din_list[7] = din_list[6] + w_in; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_small_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - //! every h loop, deal with 8 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - const float* din6 = din5 + w_in_new; - const float* din7 = din6 + w_in_new; - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 8 > h_in_new) { - switch (h + 8 - h_in_new) { - case 7: - din1 = zero_ptr; - case 6: - din2 = zero_ptr; - case 5: - din3 = zero_ptr; - case 4: - din4 = zero_ptr; - case 3: - din5 = zero_ptr; - case 2: - din6 = zero_ptr; - case 1: - din7 = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - const float* din_ptr7 = din7; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - *dout_ptr2++ = bias_c; - *dout_ptr3++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - din_ptr6, - din_ptr7, - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - din_ptr6++; - din_ptr7++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din7; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - din6 = din5 + w_in_new; - din7 = din6 + w_in_new; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } - free(din_new); -} - -void conv_depthwise_5x5s1_small_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - - //! every h loop, deal with 8 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - const float* din6 = din5 + w_in_new; - const float* din7 = din6 + w_in_new; - //! every h loop, deal with 4 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - float* dout2 = dout1 + w_out; - float* dout3 = dout2 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - float32x4_t w5; - float32x4_t w6; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 5); - float32x4_t w2 = vld1q_f32(weights_c + 10); - float32x4_t w3 = vld1q_f32(weights_c + 15); - float32x4_t w4 = vld1q_f32(weights_c + 20); - w5 = vsetq_lane_f32(weights_c[4], w5, 0); - w5 = vsetq_lane_f32(weights_c[9], w5, 1); - w5 = vsetq_lane_f32(weights_c[14], w5, 2); - w5 = vsetq_lane_f32(weights_c[19], w5, 3); - w6 = vsetq_lane_f32(weights_c[24], w6, 0); - - //! h loop - for (int h = 0; h < h_out_new; h += 4) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 8 > h_in_new) { - switch (h + 8 - h_in_new) { - case 7: - din1 = zero_ptr; - case 6: - din2 = zero_ptr; - case 5: - din3 = zero_ptr; - case 4: - din4 = zero_ptr; - case 3: - din5 = zero_ptr; - case 2: - din6 = zero_ptr; - case 1: - din7 = zero_ptr; - default: - break; - } - } - if (h + 4 > h_out_new) { - switch (h + 4 - h_out_new) { - case 3: - dout1 = write_ptr; - case 2: - dout2 = write_ptr; - case 1: - dout3 = write_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - const float* din_ptr7 = din7; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - float* dout_ptr2 = dout2; - float* dout_ptr3 = dout3; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - vst1q_f32(dout_ptr2, vbias_c); - vst1q_f32(dout_ptr3, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - dout_ptr2 += 4; - dout_ptr3 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - *dout_ptr2++ = bias_relu; - *dout_ptr3++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - dout_ptr2 += pad_0; - dout_ptr3 += pad_0; - } - - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - din_ptr6, - din_ptr7, - dout_ptr0, - dout_ptr1, - dout_ptr2, - dout_ptr3, - w0, - w1, - w2, - w3, - w4, - w5, - w6, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - din_ptr6++; - din_ptr7++; - - dout_ptr0++; - dout_ptr1++; - dout_ptr2++; - dout_ptr3++; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - memcpy(dout_ptr2, dout2, pad_0 * sizeof(float)); - memcpy(dout_ptr3, dout3, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr2, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr3, 0x00, pad_0 * sizeof(float)); - } - - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din7; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - din6 = din5 + w_in_new; - din7 = din6 + w_in_new; - - dout0 = dout3 + w_out; - dout1 = dout0 + w_out; - dout2 = dout1 + w_out; - dout3 = dout2 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } - free(din_new); -} - -#else - -//! kernel for one out without extracting data mid -//! deal with two lines out -void compute_one_out_without_extract(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d6[0]}, [%[din0]] \n" - "vld1.32 {d6[1]}, [%[din1]] \n" - "vld1.32 {d7[0]}, [%[din2]] \n" - "vld1.32 {d7[1]}, [%[din3]] \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d8[0]}, [%[din4]] \n" - "vld1.32 {d8[1]}, [%[din5]] \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights col4 - "sub %[wh], #64 \n" - "vld1.32 {d4[0]}, [%[wh]], r0 \n" - "vld1.32 {d4[1]}, [%[wh]], r0 \n" - "vld1.32 {d5[0]}, [%[wh]], r0 \n" - "vld1.32 {d5[1]}, [%[wh]], r0 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vext.32 q5, q3, q4, #1 \n" - - "vmla.f32 q9, q2, q3 \n" - "vmla.f32 q10, q2, q5 \n" - - "vld1.32 {d4[0]}, [%[wh]] \n" - "vld1.32 {d6}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - "vmla.f32 d18, d8, d4[0] \n" - - // add bias - "vadd.f32 d18, d18, d6 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -} - -//! kernel for one out without extracting data mid -//! deal with two lines out -void compute_one_out_without_extract_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vmov.i32 q15, #0x0 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d6[0]}, [%[din0]] \n" - "vld1.32 {d6[1]}, [%[din1]] \n" - "vld1.32 {d7[0]}, [%[din2]] \n" - "vld1.32 {d7[1]}, [%[din3]] \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d8[0]}, [%[din4]] \n" - "vld1.32 {d8[1]}, [%[din5]] \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights col4 - "sub %[wh], #64 \n" - "vld1.32 {d4[0]}, [%[wh]], r0 \n" - "vld1.32 {d4[1]}, [%[wh]], r0 \n" - "vld1.32 {d5[0]}, [%[wh]], r0 \n" - "vld1.32 {d5[1]}, [%[wh]], r0 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vext.32 q5, q3, q4, #1 \n" - - "vmla.f32 q9, q2, q3 \n" - "vmla.f32 q10, q2, q5 \n" - - "vld1.32 {d4[0]}, [%[wh]] \n" - "vld1.32 {d6}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - "vmla.f32 d18, d8, d4[0] \n" - - // add bias - "vadd.f32 d18, d18, d6 \n" - - // relu - "vmax.f32 d18, d18, d30 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -} - -//! kernel for one out without extracting data pre -//! deal with two lines out -void compute_one_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #4 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - // load bias - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -} - -//! kernel for one out without extracting data pre -//! deal with two lines out -void compute_one_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #4 \n" - "vmov.i32 q15, #0x0 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - // load bias - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - // relu - "vmax.f32 d18, d18, d30 \n" - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -} - -//! kernel for one out with extracting data post -//! deal with two lines out -void compute_one_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -} - -//! kernel for one out with extracting data post -//! deal with two lines out -void compute_one_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vmov.i32 q15, #0x0 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q0, q4 \n" - "vmla.f32 q10, q0, q5 \n" - - "vld1.32 {d0-d1}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q1, q5 \n" - "vmla.f32 q10, q1, q6 \n" - - // weights r4 - "vmla.f32 q9, q0, q6 \n" - "vmla.f32 q10, q0, q7 \n" - - "vld1.32 {d0}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d0 \n" - - // relu - "vmax.f32 d18, d18, d30 \n" - - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q15"); -} - -//! kernel for two out with extracting data pre -//! deal with two lines out -void compute_two_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #0 \n" - "add %[wh], #8 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d22, d23 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for two out with extracting data pre -//! deal with two lines out -void compute_two_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #0 \n" - "add %[wh], #8 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d22, d23 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q9 \n" - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for two out with extracting data post -//! deal with two lines out -void compute_two_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - //! out zero - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - "vpadd.f32 d22, d22, d23 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - "vext.32 q8, q8, q15, #1 \n" - - //! out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for two out with extracting data post -//! deal with two lines out -void compute_two_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]]! \n" - "vld1.32 {d6-d7}, [%[din1]]! \n" - "vld1.32 {d8-d9}, [%[din2]]! \n" - "vld1.32 {d10-d11}, [%[din3]]! \n" - "vld1.32 {d12-d13}, [%[din4]]! \n" - "vld1.32 {d14-d15}, [%[din5]]! \n" - - //! out zero - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d22, d18, d19 \n" - "vpadd.f32 d23, d20, d21 \n" - "vpadd.f32 d22, d22, d23 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - "vext.32 q8, q8, q15, #1 \n" - - //! out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q9 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [dout0] "r"(dout0), [dout1] "r"(dout1), [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data pre -//! deal with two lines out -void compute_three_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #12 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out zero - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - "vpadd.f32 d22, d18, d20 \n" - - //! out one - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d30 \n" - - // store result - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data pre -//! deal with two lines out -void compute_three_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #12 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out zero - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - "vpadd.f32 d22, d18, d20 \n" - - //! out one - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q8, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q8 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d18, d18, d19 \n" - - // add bias - "vadd.f32 d18, d18, d30 \n" - - // relu - "vmax.f32 d18, d18, d16 \n" - - // store result - "vst1.32 {d18[0]}, [%[dout0]] \n" - "vst1.32 {d18[1]}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_three_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out zero && two - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - //! out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - "vadd.f32 d16, d16, d30 \n" - - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - "vst1.32 {d16[0]}, [%[dout0]]! \n" - "vst1.32 {d16[1]}, [%[dout1]]! \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_three_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out zero && two - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - //! out one - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - // load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q9, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - "vadd.f32 d16, d16, d30 \n" - - // relu - "vmax.f32 q11, q11, q9 \n" - "vmax.f32 d16, d16, d18 \n" - - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - "vst1.32 {d16[0]}, [%[dout0]]! \n" - "vst1.32 {d16[1]}, [%[dout1]]! \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for four out with extracting data pre -//! deal with two lines out -void compute_four_out_extract_pre(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #16 \n" - - //! out zero - // load input - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - // load weights - "vld1.32 d0[0], [%[wh]], r0 \n" - "vld1.32 d0[1], [%[wh]], r0 \n" - "vld1.32 d1[0], [%[wh]], r0 \n" - "vld1.32 d1[1], [%[wh]], r0 \n" - "vld1.32 d2[0], [%[wh]]\n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmla.f32 d22, d6, d2[0] \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out one - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - - "vpadd.f32 d23, d18, d20 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d18, d19 \n" - - //! out three - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for four out with extracting data pre -//! deal with two lines out -void compute_four_out_extract_pre_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "add %[wh], #16 \n" - - //! out zero - // load input - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - // load weights - "vld1.32 d0[0], [%[wh]], r0 \n" - "vld1.32 d0[1], [%[wh]], r0 \n" - "vld1.32 d1[0], [%[wh]], r0 \n" - "vld1.32 d1[1], [%[wh]], r0 \n" - "vld1.32 d2[0], [%[wh]]\n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d22, d18, d19 \n" - - "vmla.f32 d22, d6, d2[0] \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0}, [%[wh]], r0 \n" - "vld1.32 {d2}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]] \n" - "vld1.32 {d6-d7}, [%[din1]] \n" - "vld1.32 {d8-d9}, [%[din2]] \n" - "vld1.32 {d10-d11}, [%[din3]] \n" - "vld1.32 {d12-d13}, [%[din4]] \n" - "vld1.32 {d14-d15}, [%[din5]] \n" - - //! out one - // weights r0 - "vmul.f32 d18, d0, d4 \n" - "vmul.f32 d20, d0, d6 \n" - - "vld1.32 {d24}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 d18, d2, d6 \n" - "vmla.f32 d20, d2, d8 \n" - - "vld1.32 {d26}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 d18, d24, d8 \n" - "vmla.f32 d20, d24, d10 \n" - - "vld1.32 {d28}, [%[wh]] \n" - - // weights r3 - "vmla.f32 d18, d26, d10 \n" - "vmla.f32 d20, d26, d12 \n" - - // weights r4 - "vmla.f32 d18, d28, d12 \n" - "vmla.f32 d20, d28, d14 \n" - - "vpadd.f32 d23, d18, d20 \n" - "vmov.i32 q8, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q8 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out two - "mov r1, #0 \n" - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vmov.32 d1[1], r1 \n" - "vmov.32 d3[1], r1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - "vmov.32 d25[1], r1 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vmov.32 d27[1], r1 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - "vld1.32 {d28-d29}, [%[wh]]\n" - "vmov.32 d29[1], r1 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "sub %[wh], #84 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - "vld1.32 {d28-d29}, [%[wh]]\n" - - "vpadd.f32 d22, d18, d19 \n" - - //! out three - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q8 \n" - - // store result - "vst1.32 {d22}, [%[dout0]] \n" - "vst1.32 {d23}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), [dout1] "+r"(dout1), [wh] "+r"(weights) - : [din0] "r"(din0), - [din1] "r"(din1), - [din2] "r"(din2), - [din3] "r"(din3), - [din4] "r"(din4), - [din5] "r"(din5), - [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_four_out_extract_post(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #12 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]], r1 \n" - "vld1.32 {d6-d7}, [%[din1]], r1 \n" - "vld1.32 {d8-d9}, [%[din2]], r1 \n" - "vld1.32 {d10-d11}, [%[din3]], r1 \n" - "vld1.32 {d12-d13}, [%[din4]], r1 \n" - "vld1.32 {d14-d15}, [%[din5]], r1 \n" - - //! out zero && two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - //! out one - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out three - "sub %[wh], #80 \n" - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - "vld1.32 {d0[0]}, [%[wh]], r0 \n" - "vld1.32 {d0[1]}, [%[wh]], r0 \n" - "vld1.32 {d1[0]}, [%[wh]], r0 \n" - "vld1.32 {d1[1]}, [%[wh]], r0 \n" - "vld1.32 {d2[0]}, [%[wh]] \n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d20, d20, d21 \n" - "vpadd.f32 d17, d18, d20 \n" - - "vmla.f32 d17, d6, d2[0] \n" - - // trn out neon register - "vtrn.32 d16, d17 \n" - - // add bias - "vadd.f32 q8, q8, q15 \n" - - // store result - "vst1.32 {d16}, [%[dout0]] \n" - "vst1.32 {d17}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -//! kernel for three out with extracting data post -//! deal with two lines out -void compute_four_out_extract_post_relu(const float* din0, - const float* din1, - const float* din2, - const float* din3, - const float* din4, - const float* din5, - float* dout0, - float* dout1, - const float* weights, - const float* bias) { - asm volatile( - "mov r0, #20 \n" - "mov r1, #12 \n" - "vld1.32 {d0-d1}, [%[wh]], r0 \n" - "vld1.32 {d2-d3}, [%[wh]], r0 \n" - - "vld1.32 {d4-d5}, [%[din0]], r1 \n" - "vld1.32 {d6-d7}, [%[din1]], r1 \n" - "vld1.32 {d8-d9}, [%[din2]], r1 \n" - "vld1.32 {d10-d11}, [%[din3]], r1 \n" - "vld1.32 {d12-d13}, [%[din4]], r1 \n" - "vld1.32 {d14-d15}, [%[din5]], r1 \n" - - //! out zero && two - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - "vmul.f32 d16, d0, d5 \n" - "vmul.f32 d17, d0, d7 \n" - - "vld1.32 {d24-d25}, [%[wh]], r0 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - "vmla.f32 d16, d2, d7 \n" - "vmla.f32 d17, d2, d9 \n" - - "vld1.32 {d26-d27}, [%[wh]], r0 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - "vmla.f32 d16, d24, d9 \n" - "vmla.f32 d17, d24, d11 \n" - - "vld1.32 {d28-d29}, [%[wh]] \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - "vmla.f32 d16, d26, d11 \n" - "vmla.f32 d17, d26, d13 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - "vmla.f32 d16, d28, d13 \n" - "vmla.f32 d17, d28, d15 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d16, d16, d17 \n" - "vpadd.f32 d22, d18, d19 \n" - - //! out one - "vmov.f32 q15, #0.0 \n" - "vext.32 q2, q2, q15, #1 \n" - "vext.32 q3, q3, q15, #1 \n" - "vext.32 q4, q4, q15, #1 \n" - "vext.32 q5, q5, q15, #1 \n" - "vext.32 q6, q6, q15, #1 \n" - "vext.32 q7, q7, q15, #1 \n" - - // weights r0 - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q3 \n" - - // weights r1 - "vmla.f32 q9, q1, q3 \n" - "vmla.f32 q10, q1, q4 \n" - - // weights r2 - "vmla.f32 q9, q12, q4 \n" - "vmla.f32 q10, q12, q5 \n" - - // weights r3 - "vmla.f32 q9, q13, q5 \n" - "vmla.f32 q10, q13, q6 \n" - - // weights r4 - "vmla.f32 q9, q14, q6 \n" - "vmla.f32 q10, q14, q7 \n" - - "vld1.32 {d30-d31}, [%[bias]] \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d19, d20, d21 \n" - "vpadd.f32 d23, d18, d19 \n" - "vmov.i32 q5, #0x0 \n" - - // trn out neon register - "vtrn.32 d22, d23 \n" - - // add bias - "vadd.f32 q11, q11, q15 \n" - - // relu - "vmax.f32 q11, q11, q5 \n" - - // store result - "vst1.32 {d22}, [%[dout0]]! \n" - "vst1.32 {d23}, [%[dout1]]! \n" - - //! out three - "sub %[wh], #80 \n" - "vld1.32 {d4[0]}, [%[din0]] \n" - "vld1.32 {d4[1]}, [%[din1]] \n" - "vld1.32 {d5[0]}, [%[din2]] \n" - "vld1.32 {d5[1]}, [%[din3]] \n" - "vld1.32 {d6[0]}, [%[din4]] \n" - "vld1.32 {d6[1]}, [%[din5]] \n" - - "vext.32 q4, q2, q3, #1 \n" - - "vld1.32 {d0[0]}, [%[wh]], r0 \n" - "vld1.32 {d0[1]}, [%[wh]], r0 \n" - "vld1.32 {d1[0]}, [%[wh]], r0 \n" - "vld1.32 {d1[1]}, [%[wh]], r0 \n" - "vld1.32 {d2[0]}, [%[wh]] \n" - - "vmul.f32 q9, q0, q2 \n" - "vmul.f32 q10, q0, q4 \n" - - "vpadd.f32 d18, d18, d19 \n" - "vpadd.f32 d20, d20, d21 \n" - "vpadd.f32 d17, d18, d20 \n" - - "vmla.f32 d17, d6, d2[0] \n" - - // trn out neon register - "vtrn.32 d16, d17 \n" - - // add bias - "vadd.f32 q8, q8, q15 \n" - - // relu - "vmax.f32 q8, q8, q5 \n" - - // store result - "vst1.32 {d16}, [%[dout0]] \n" - "vst1.32 {d17}, [%[dout1]] \n" - - : [dout0] "+r"(dout0), - [dout1] "+r"(dout1), - [din0] "+r"(din0), - [din1] "+r"(din1), - [din2] "+r"(din2), - [din3] "+r"(din3), - [din4] "+r"(din4), - [din5] "+r"(din5), - [wh] "+r"(weights) - : [bias] "r"(bias) - : "memory", - "r0", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -} - -void conv_depthwise_5x5s1_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, +#define ROUNDUP(a, b) ((((a) + (b)-1) / (b)) * (b)) +#ifdef __aarch64__ +void conv_depthwise_5x5s1_fp32(float* dout, + const float* din, const float* weights, const float* bias, - int pad, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + const operators::ConvParam& param, ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[6]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 6; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - //! every h loop, deal with 6 line input - const float* din0 = din_list[0]; - const float* din1 = din_list[1]; - const float* din2 = din_list[2]; - const float* din3 = din_list[3]; - const float* din4 = din_list[4]; - const float* din5 = din_list[5]; - - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 6 - pad_new > h_in) { - switch (h + 6 - pad_new - h_in) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_pre(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - //! mid loop - if (mid_cnt > 0) { - int mid_loop = mid_cnt; - const float* weights_ptr = weights_c; - asm volatile( - //! din: q7-q12 - //! dout: q13, q14 - "mov r1, #20 \n" - //! load weights - "vld1.32 {d0-d1}, [%[wh]], r1 \n" - "vld1.32 {d2-d3}, [%[wh]], r1 \n" - "vld1.32 {d4-d5}, [%[wh]], r1 \n" - "vld1.32 {d6-d7}, [%[wh]], r1 \n" - "vld1.32 {d8-d9}, [%[wh]] \n" - - "sub %[wh], #64 \n" - "vld1.32 {d10[0]}, [%[wh]], r1 \n" - "vld1.32 {d10[1]}, [%[wh]], r1 \n" - "vld1.32 {d11[0]}, [%[wh]], r1 \n" - "vld1.32 {d11[1]}, [%[wh]], r1 \n" - "vld1.32 {d12[0]}, [%[wh]] \n" - - //! load input - "mov r1, #4 \n" - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - //! load bias - "vld1.32 {d30-d31}, [%[bias]] \n" - - "1: \n" - //! add bias to output - "vmov.32 q13, q15 \n" - "vmov.32 q14, q15 \n" - - "pld [%[din0]] \n" - "pld [%[din1]] \n" - "pld [%[din2]] \n" - "pld [%[din3]] \n" - "pld [%[din4]] \n" - "pld [%[din5]] \n" - - // weights col 0 - "vmla.f32 q13, q7, d0[0] \n" - "vmla.f32 q14, q8, d0[0] \n" - - "vmla.f32 q13, q8, d2[0] \n" - "vmla.f32 q14, q9, d2[0] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d4[0] \n" - "vmla.f32 q14, q10, d4[0] \n" - - "vmla.f32 q13, q10, d6[0] \n" - "vmla.f32 q14, q11, d6[0] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d8[0] \n" - "vmla.f32 q14, q12, d8[0] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 1 - "vmla.f32 q13, q7, d0[1] \n" - "vmla.f32 q14, q8, d0[1] \n" - - "vmla.f32 q13, q8, d2[1] \n" - "vmla.f32 q14, q9, d2[1] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d4[1] \n" - "vmla.f32 q14, q10, d4[1] \n" - - "vmla.f32 q13, q10, d6[1] \n" - "vmla.f32 q14, q11, d6[1] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d8[1] \n" - "vmla.f32 q14, q12, d8[1] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 2 - "vmla.f32 q13, q7, d1[0] \n" - "vmla.f32 q14, q8, d1[0] \n" - - "vmla.f32 q13, q8, d3[0] \n" - "vmla.f32 q14, q9, d3[0] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d5[0] \n" - "vmla.f32 q14, q10, d5[0] \n" - - "vmla.f32 q13, q10, d7[0] \n" - "vmla.f32 q14, q11, d7[0] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d9[0] \n" - "vmla.f32 q14, q12, d9[0] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 3 - "vmla.f32 q13, q7, d1[1] \n" - "vmla.f32 q14, q8, d1[1] \n" - - "vmla.f32 q13, q8, d3[1] \n" - "vmla.f32 q14, q9, d3[1] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d5[1] \n" - "vmla.f32 q14, q10, d5[1] \n" - - "vmla.f32 q13, q10, d7[1] \n" - "vmla.f32 q14, q11, d7[1] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d9[1] \n" - "vmla.f32 q14, q12, d9[1] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 4 - "vmla.f32 q13, q7, d10[0] \n" - "vmla.f32 q14, q8, d10[0] \n" - - "vmla.f32 q13, q8, d10[1] \n" - "vmla.f32 q14, q9, d10[1] \n" - - "vmla.f32 q13, q9, d11[0] \n" - "vmla.f32 q14, q10, d11[0] \n" - - "vmla.f32 q13, q10, d11[1] \n" - "vmla.f32 q14, q11, d11[1] \n" - - "vmla.f32 q13, q11, d12[0] \n" - "vmla.f32 q14, q12, d12[0] \n" - - // store reslult - "vst1.32 {d26-d27}, [%[out0]]! \n" - "vst1.32 {d28-d29}, [%[out1]]! \n" - - "subs %[cnt], #1 \n" - "bne 1b \n" - - "sub %[din0], r1 \n" - "sub %[din1], r1 \n" - "sub %[din2], r1 \n" - "sub %[din3], r1 \n" - "sub %[din4], r1 \n" - "sub %[din5], r1 \n" - - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3), - [din4] "+r"(din_ptr4), - [din5] "+r"(din_ptr5), - [out0] "+r"(dout_ptr0), - [out1] "+r"(dout_ptr1), - [wh] "+r"(weights_ptr), - [cnt] "+r"(mid_loop) - : [bias] "r"(vbias) - : "cc", - "memory", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - //! deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_post(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in; - din5 = din4 + w_in; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_out_new = h_out - 2 * pad_0; - int mid_out = w_out - 2 * pad; - int mid_cnt = mid_out >> 2; - int mid_remain = mid_out - (mid_cnt << 2); - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + auto act_param = param.activation_param; + const int hout_c_block = 4; + const int hout_r_kernel = 2; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 4; + + //! get h block + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(float) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(float) + //! win_round = wout_round + 4 + //! hin_r_block = hout_r_block + 4 + int hout_r_block = (llc_size - 16 * win_round * hout_c_block * threads) / + (win_round * hout_c_block * threads * 4 + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 4; + + float* tmp_work_space = ctx->workspace_data(); + float ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + float* tmp_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + + int out_row_stride = hout_c_block * wout_round; for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - const float* din_list[6]; - //! set din ptr with zero buffer - for (int i = 0; i < pad_new; ++i) { - din_list[i] = zero_ptr; - } - //! set din ptr with input data - for (int i = pad_new; i < 6; ++i) { - din_list[i] = din_ch; - din_ch += w_in; - } - //! every h loop, deal with 6 line input - const float* din0 = din_list[0]; - const float* din1 = din_list[1]; - const float* din2 = din_list[2]; - const float* din3 = din_list[3]; - const float* din4 = din_list[4]; - const float* din5 = din_list[5]; - - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - //! load weights to neon register - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 7 > h_in - 1 - if (h + 6 - pad_new > h_in) { - switch (h + 6 - pad_new - h_in) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; + const float* din_batch = din + n * chin * size_in_channel; + float* dout_batch = dout + n * chout * size_out_channel; + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel + 4; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size); + float* pre_out = pre_din + pre_in_size; +#else + float pre_din = tmp_din; + float* pre_out = pre_din + pre_in_size; +#endif + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin, ptr_zero); + const float* block_inr0 = pre_din; + const float* block_inr1 = block_inr0 + in_len; + const float* block_inr2 = block_inr1 + in_len; + const float* block_inr3 = block_inr2 + in_len; + const float* block_inr4 = block_inr3 + in_len; + const float* block_inr5 = block_inr4 + in_len; + + const float* weight_c = weights + c * w_stride; + float bias_local[4] = {0, 0, 0, 0}; if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - - //! deal with w_out pad_new column pre - switch (pad_new) { - case 4: - compute_four_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_pre_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - //! mid loop - if (mid_cnt > 0) { - int mid_loop = mid_cnt; - const float* weights_ptr = weights_c; + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + const float* inr5 = block_inr5; + + float* ptr_out0 = pre_out + hk * out_row_stride; + float* ptr_out1 = ptr_out0 + out_row_stride; + // clang-format off + auto wptr = weight_c; asm volatile( - //! din: q7-q12 - //! dout: q13, q14 - "mov r1, #20 \n" - "vmov.i32 q15, #0x0 \n" - //! load weights - "vld1.32 {d0-d1}, [%[wh]], r1 \n" - "vld1.32 {d2-d3}, [%[wh]], r1 \n" - "vld1.32 {d4-d5}, [%[wh]], r1 \n" - "vld1.32 {d6-d7}, [%[wh]], r1 \n" - "vld1.32 {d8-d9}, [%[wh]] \n" - - "sub %[wh], #64 \n" - "vld1.32 {d10[0]}, [%[wh]], r1 \n" - "vld1.32 {d10[1]}, [%[wh]], r1 \n" - "vld1.32 {d11[0]}, [%[wh]], r1 \n" - "vld1.32 {d11[1]}, [%[wh]], r1 \n" - "vld1.32 {d12[0]}, [%[wh]] \n" - - //! load input - "mov r1, #4 \n" - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - "1: \n" - - //! load bias to output - "vld1.32 {d26-d27}, [%[bias]] \n" - "vld1.32 {d28-d29}, [%[bias]] \n" - - "pld [%[din0]] \n" - "pld [%[din1]] \n" - "pld [%[din2]] \n" - "pld [%[din3]] \n" - "pld [%[din4]] \n" - "pld [%[din5]] \n" - - // weights col 0 - "vmla.f32 q13, q7, d0[0] \n" - "vmla.f32 q14, q8, d0[0] \n" - - "vmla.f32 q13, q8, d2[0] \n" - "vmla.f32 q14, q9, d2[0] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d4[0] \n" - "vmla.f32 q14, q10, d4[0] \n" - - "vmla.f32 q13, q10, d6[0] \n" - "vmla.f32 q14, q11, d6[0] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d8[0] \n" - "vmla.f32 q14, q12, d8[0] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 1 - "vmla.f32 q13, q7, d0[1] \n" - "vmla.f32 q14, q8, d0[1] \n" - - "vmla.f32 q13, q8, d2[1] \n" - "vmla.f32 q14, q9, d2[1] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d4[1] \n" - "vmla.f32 q14, q10, d4[1] \n" - - "vmla.f32 q13, q10, d6[1] \n" - "vmla.f32 q14, q11, d6[1] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d8[1] \n" - "vmla.f32 q14, q12, d8[1] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 2 - "vmla.f32 q13, q7, d1[0] \n" - "vmla.f32 q14, q8, d1[0] \n" - - "vmla.f32 q13, q8, d3[0] \n" - "vmla.f32 q14, q9, d3[0] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d5[0] \n" - "vmla.f32 q14, q10, d5[0] \n" - - "vmla.f32 q13, q10, d7[0] \n" - "vmla.f32 q14, q11, d7[0] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d9[0] \n" - "vmla.f32 q14, q12, d9[0] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 3 - "vmla.f32 q13, q7, d1[1] \n" - "vmla.f32 q14, q8, d1[1] \n" - - "vmla.f32 q13, q8, d3[1] \n" - "vmla.f32 q14, q9, d3[1] \n" - - "vld1.32 {d14-d15}, [%[din0]], r1 \n" - "vld1.32 {d16-d17}, [%[din1]], r1 \n" - - "vmla.f32 q13, q9, d5[1] \n" - "vmla.f32 q14, q10, d5[1] \n" - - "vmla.f32 q13, q10, d7[1] \n" - "vmla.f32 q14, q11, d7[1] \n" - - "vld1.32 {d18-d19}, [%[din2]], r1 \n" - "vld1.32 {d20-d21}, [%[din3]], r1 \n" - - "vmla.f32 q13, q11, d9[1] \n" - "vmla.f32 q14, q12, d9[1] \n" - - "vld1.32 {d22-d23}, [%[din4]], r1 \n" - "vld1.32 {d24-d25}, [%[din5]], r1 \n" - - // weights col 4 - "vmla.f32 q13, q7, d10[0] \n" - "vmla.f32 q14, q8, d10[0] \n" - - "vmla.f32 q13, q8, d10[1] \n" - "vmla.f32 q14, q9, d10[1] \n" - - "vmla.f32 q13, q9, d11[0] \n" - "vmla.f32 q14, q10, d11[0] \n" - - "vmla.f32 q13, q10, d11[1] \n" - "vmla.f32 q14, q11, d11[1] \n" - - "vmla.f32 q13, q11, d12[0] \n" - "vmla.f32 q14, q12, d12[0] \n" - - // relu - "vmax.f32 q13, q13, q15 \n" - "vmax.f32 q14, q14, q15 \n" - - // store result - "vst1.32 {d26-d27}, [%[out0]]! \n" - "vst1.32 {d28-d29}, [%[out1]]! \n" - - "subs %[cnt], #1 \n" - "bne 1b \n" - - "sub %[din0], r1 \n" - "sub %[din1], r1 \n" - "sub %[din2], r1 \n" - "sub %[din3], r1 \n" - "sub %[din4], r1 \n" - "sub %[din5], r1 \n" - - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3), - [din4] "+r"(din_ptr4), - [din5] "+r"(din_ptr5), - [out0] "+r"(dout_ptr0), - [out1] "+r"(dout_ptr1), - [wh] "+r"(weights_ptr), - [cnt] "+r"(mid_loop) - : [bias] "r"(vbias) - : "cc", - "memory", - "r1", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - } - //! deal with mid remain - for (int i = 0; i < mid_remain; ++i) { - compute_one_out_without_extract_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - //! deal with w_out pad_new column post - switch (pad_new) { - case 4: - compute_four_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 4; - dout_ptr1 += 4; - break; - case 3: - compute_three_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 3; - dout_ptr1 += 3; - break; - case 2: - compute_two_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 2; - dout_ptr1 += 2; - break; - case 1: - compute_one_out_extract_post_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - dout_ptr0 += 1; - dout_ptr1 += 1; - break; - } - - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in; - din5 = din4 + w_in; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } -} - -void conv_depthwise_5x5s1_small_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_c); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_c; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - //! every h loop, deal with 6 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 6 > h_in - 1 - if (h + 6 > h_in_new) { - switch (h + 6 - h_in_new) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_c; - *dout_ptr1++ = bias_c; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); - } - } - } - free(din_new); -} - -void conv_depthwise_5x5s1_small_relu_impl(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int pad, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - int pad_new = pad > 4 ? 4 : pad; - int pad_0 = pad - pad_new; - int h_in_new = h_in + 2 * pad_new; - int w_in_new = w_in + 2 * pad_new; - int h_out_new = h_out - 2 * pad_0; - int w_out_new = w_out - 2 * pad_0; - float zero_ptr[w_in_new + w_out]; // NOLINT - memset(zero_ptr, 0, w_in_new * sizeof(float)); - float* write_ptr = zero_ptr + w_in_new; - int pad_cnt = pad_0 >> 2; - int pad_remain = pad_0 - (pad_cnt << 2); - int bias_cnt = (w_out * pad_0) >> 2; - int bias_remain = (w_out * pad_0) - (bias_cnt << 2); - int in_spatial_size = w_in_new * h_in_new; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - float* din_new = prepad_input(din, num, ch_in, h_in, w_in, pad_new); - for (int n = 0; n < num; ++n) { - const float* din_batch = din_new + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - float bias_c = flag_bias ? bias[c] : 0.f; - float bias_relu = bias_c > 0.f ? bias_c : 0.f; - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - float32x4_t vbias_c = vdupq_n_f32(bias_relu); - if (flag_bias) { - //! deal with h_out pad_0 line with bias - for (int i = 0; i < bias_cnt; ++i) { - vst1q_f32(dout_ch, vbias_c); - dout_ch += 4; - } - for (int i = 0; i < bias_remain; ++i) { - *dout_ch++ = bias_relu; - } - } else { - //! deal with h_out pad_0 line without bias - for (int i = 0; i < pad_0; ++i) { - memset(dout_ch, 0x00, w_out * sizeof(float)); - dout_ch += w_out; - } - } - //! every h loop, deal with 6 line input - const float* din0 = din_ch; - const float* din1 = din0 + w_in_new; - const float* din2 = din1 + w_in_new; - const float* din3 = din2 + w_in_new; - const float* din4 = din3 + w_in_new; - const float* din5 = din4 + w_in_new; - //! every h loop, deal with 2 line output - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - - //! h loop - for (int h = 0; h < h_out_new; h += 2) { - //! (h - pad_new) + 6 > h_in - 1 - if (h + 6 > h_in_new) { - switch (h + 6 - h_in_new) { - case 5: - din1 = zero_ptr; - case 4: - din2 = zero_ptr; - case 3: - din3 = zero_ptr; - case 2: - din4 = zero_ptr; - case 1: - din5 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out_new) { - dout1 = write_ptr; - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - if (flag_bias) { - //! deal with w_out pad_0 column pre with bias - for (int i = 0; i < pad_cnt; i++) { - vst1q_f32(dout_ptr0, vbias_c); - vst1q_f32(dout_ptr1, vbias_c); - dout_ptr0 += 4; - dout_ptr1 += 4; - } - for (int i = 0; i < pad_remain; ++i) { - *dout_ptr0++ = bias_relu; - *dout_ptr1++ = bias_relu; - } - } else { - //! deal with w_out pad_0 column pre without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - dout_ptr0 += pad_0; - dout_ptr1 += pad_0; - } - //! mid loop - for (int i = 0; i < w_out_new; ++i) { - compute_one_out_without_extract_relu(din_ptr0, - din_ptr1, - din_ptr2, - din_ptr3, - din_ptr4, - din_ptr5, - dout_ptr0, - dout_ptr1, - weights_c, - vbias); - din_ptr0++; - din_ptr1++; - din_ptr2++; - din_ptr3++; - din_ptr4++; - din_ptr5++; - - dout_ptr0++; - dout_ptr1++; - } - if (flag_bias) { - //! deal with w_out pad_0 column post with bias - memcpy(dout_ptr0, dout0, pad_0 * sizeof(float)); - memcpy(dout_ptr1, dout1, pad_0 * sizeof(float)); - } else { - //! deal with w_out pad_0 column post without bias - memset(dout_ptr0, 0x00, pad_0 * sizeof(float)); - memset(dout_ptr1, 0x00, pad_0 * sizeof(float)); - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din5; - din4 = din3 + w_in_new; - din5 = din4 + w_in_new; - - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - float* dout_pad_end = dout_ch + h_out_new * w_out; - if (flag_bias) { - //! deal with h_out pad_0 line with bias - memcpy(reinterpret_cast(dout_pad_end), - dout_ch - pad_0 * w_out, - pad_0 * w_out * sizeof(float)); - } else { - //! deal with h_out pad_0 line without bias - memset(reinterpret_cast(dout_pad_end), - 0x00, - pad_0 * w_out * sizeof(float)); + "ldr q24, [%[bias]] \n" /* load bias to out00 */ + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[wc]], #64 \n" /* load w0-w3 */ + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[inr0]], #64 \n" /* load inr0, 0-3 */ + "1:\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[inr1]], #64 \n" /* load inr1, 0-3 */ + "mov v25.16b, v24.16b \n" /* mov bias to out01 */ + "mov v26.16b, v24.16b \n" /* mov bias to out02 */ + "mov v27.16b, v24.16b \n" /* mov bias to out03 */ + "mov v28.16b, v24.16b \n" /* mov bias to out10 */ + "mov v29.16b, v24.16b \n" /* mov bias to out11 */ + "mov v30.16b, v24.16b \n" /* mov bias to out12 */ + "mov v31.16b, v24.16b \n" /* mov bias to out13 */ + // out row0 + "fmla v24.4s, v8.4s, v0.4s \n" /* out00 = w0 * inr00 */ + "fmla v25.4s, v9.4s, v0.4s \n" /* out01 = w0 * inr01 */ + "ldp q12, q13, [%[inr0]] \n" /* load inr0, 4-5 */ + "fmla v26.4s, v10.4s, v0.4s \n" /* out02 = w0 * inr02 */ + "fmla v27.4s, v11.4s, v0.4s \n" /* out03 = w0 * inr03 */ + "fmla v28.4s, v16.4s, v0.4s \n" /* out10 = w0 * inr10 */ + "fmla v29.4s, v17.4s, v0.4s \n" /* out11 = w0 * inr11 */ + "ldp q20, q21, [%[inr1]] \n" /* load inr1, 4-5 */ + "fmla v30.4s, v18.4s, v0.4s \n" /* out12 = w0 * inr12 */ + "fmla v31.4s, v19.4s, v0.4s \n" /* out13 = w0 * inr13 */ + "fmla v24.4s, v9.4s, v1.4s \n" /* out00 = w1 * inr01 */ + "fmla v25.4s, v10.4s, v1.4s \n" /* out01 = w1 * inr02 */ + "fmla v26.4s, v11.4s, v1.4s \n" /* out02 = w1 * inr03 */ + "fmla v27.4s, v12.4s, v1.4s \n" /* out03 = w1 * inr04 */ + "ldp q14, q15, [%[inr0], #32] \n" /* load inr0, 6-7 */ + "fmla v28.4s, v17.4s, v1.4s \n" /* out10 = w1 * inr11 */ + "fmla v29.4s, v18.4s, v1.4s \n" /* out11 = w1 * inr12 */ + "fmla v30.4s, v19.4s, v1.4s \n" /* out12 = w1 * inr13 */ + "fmla v31.4s, v20.4s, v1.4s \n" /* out13 = w1 * inr14 */ + "fmla v24.4s, v10.4s, v2.4s \n" /* out00 = w2 * inr02 */ + "fmla v25.4s, v11.4s, v2.4s \n" /* out01 = w2 * inr03 */ + "fmla v26.4s, v12.4s, v2.4s \n" /* out02 = w2 * inr04 */ + "fmla v27.4s, v13.4s, v2.4s \n" /* out03 = w2 * inr05 */ + "ldp q22, q23, [%[inr1], #32] \n" /* load inr1, 6-7 */ + "fmla v28.4s, v18.4s, v2.4s \n" /* out10 = w2 * inr12 */ + "fmla v29.4s, v19.4s, v2.4s \n" /* out11 = w2 * inr13 */ + "fmla v30.4s, v20.4s, v2.4s \n" /* out12 = w2 * inr14 */ + "fmla v31.4s, v21.4s, v2.4s \n" /* out13 = w2 * inr15 */ + "ldp q4, q5, [%[wc]], #32 \n" /* load w4-w5 */ + "fmla v24.4s, v11.4s, v3.4s \n" /* out00 = w3 * inr03 */ + "fmla v25.4s, v12.4s, v3.4s \n" /* out01 = w3 * inr04 */ + "fmla v26.4s, v13.4s, v3.4s \n" /* out02 = w3 * inr05 */ + "fmla v27.4s, v14.4s, v3.4s \n" /* out03 = w3 * inr06 */ + "ldp q6, q7, [%[wc]], #32 \n" /* load w6-w7 */ + "fmla v28.4s, v19.4s, v3.4s \n" /* out10 = w3 * inr13 */ + "fmla v29.4s, v20.4s, v3.4s \n" /* out11 = w3 * inr14 */ + "fmla v30.4s, v21.4s, v3.4s \n" /* out12 = w3 * inr15 */ + "fmla v31.4s, v22.4s, v3.4s \n" /* out13 = w3 * inr16 */ + "fmla v24.4s, v12.4s, v4.4s \n" /* out00 = w4 * inr04 */ + "fmla v25.4s, v13.4s, v4.4s \n" /* out01 = w4 * inr05 */ + "fmla v26.4s, v14.4s, v4.4s \n" /* out02 = w4 * inr06 */ + "fmla v27.4s, v15.4s, v4.4s \n" /* out03 = w4 * inr07 */ + "ldp q8, q9, [%[inr2]], #32 \n" /* load inr2, 0-1 */ + "fmla v28.4s, v20.4s, v4.4s \n" /* out10 = w4 * inr14 */ + "fmla v29.4s, v21.4s, v4.4s \n" /* out11 = w4 * inr15 */ + "fmla v30.4s, v22.4s, v4.4s \n" /* out12 = w4 * inr16 */ + "fmla v31.4s, v23.4s, v4.4s \n" /* out13 = w4 * inr17 */ + "ldp q10, q11, [%[inr2]], #32\n" /* load inr2, 2-3 */ + // out row1 + "fmla v24.4s, v16.4s, v5.4s \n" /* out00 = w5 * inr10 */ + "fmla v25.4s, v17.4s, v5.4s \n" /* out01 = w5 * inr11 */ + "fmla v26.4s, v18.4s, v5.4s \n" /* out02 = w5 * inr12 */ + "fmla v27.4s, v19.4s, v5.4s \n" /* out03 = w5 * inr13 */ + "ldp q12, q13, [%[inr2]] \n" /* load inr2, 4-5 */ + "fmla v28.4s, v8.4s, v5.4s \n" /* out10 = w5 * inr20 */ + "fmla v29.4s, v9.4s, v5.4s \n" /* out11 = w5 * inr21 */ + "fmla v30.4s, v10.4s, v5.4s \n" /* out12 = w5 * inr22 */ + "fmla v31.4s, v11.4s, v5.4s \n" /* out13 = w5 * inr23 */ + "fmla v24.4s, v17.4s, v6.4s \n" /* out00 = w6 * inr11 */ + "fmla v25.4s, v18.4s, v6.4s \n" /* out01 = w6 * inr12 */ + "fmla v26.4s, v19.4s, v6.4s \n" /* out02 = w6 * inr13 */ + "fmla v27.4s, v20.4s, v6.4s \n" /* out03 = w6 * inr14 */ + "ldp q14, q15, [%[inr2], #32]\n" /* load inr2, 6-7 */ + "fmla v28.4s, v9.4s, v6.4s \n" /* out10 = w6 * inr21 */ + "fmla v29.4s, v10.4s, v6.4s \n" /* out11 = w6 * inr22 */ + "fmla v30.4s, v11.4s, v6.4s \n" /* out12 = w6 * inr23 */ + "fmla v31.4s, v12.4s, v6.4s \n" /* out13 = w6 * inr24 */ + "fmla v24.4s, v18.4s, v7.4s \n" /* out00 = w7 * inr12 */ + "fmla v25.4s, v19.4s, v7.4s \n" /* out01 = w7 * inr13 */ + "fmla v26.4s, v20.4s, v7.4s \n" /* out02 = w7 * inr14 */ + "fmla v27.4s, v21.4s, v7.4s \n" /* out03 = w7 * inr15 */ + "ldp q0, q1, [%[wc]], #32 \n" /* load w8-w9 */ + "fmla v28.4s, v10.4s, v7.4s \n" /* out10 = w7 * inr22 */ + "fmla v29.4s, v11.4s, v7.4s \n" /* out11 = w7 * inr23 */ + "fmla v30.4s, v12.4s, v7.4s \n" /* out12 = w7 * inr24 */ + "fmla v31.4s, v13.4s, v7.4s \n" /* out13 = w7 * inr25 */ + "fmla v24.4s, v19.4s, v0.4s \n" /* out00 = w8 * inr13 */ + "fmla v25.4s, v20.4s, v0.4s \n" /* out01 = w8 * inr14 */ + "fmla v26.4s, v21.4s, v0.4s \n" /* out02 = w8 * inr15 */ + "fmla v27.4s, v22.4s, v0.4s \n" /* out03 = w8 * inr16 */ + "ldp q2, q3, [%[wc]], #32 \n" /* load w10-w11 */ + "fmla v28.4s, v11.4s, v0.4s \n" /* out10 = w8 * inr23 */ + "fmla v29.4s, v12.4s, v0.4s \n" /* out11 = w8 * inr24 */ + "fmla v30.4s, v13.4s, v0.4s \n" /* out12 = w8 * inr25 */ + "fmla v31.4s, v14.4s, v0.4s \n" /* out13 = w8 * inr26 */ + "ldp q16, q17, [%[inr3]], #32\n" /* load inr3, 0-1 */ + "fmla v24.4s, v20.4s, v1.4s \n" /* out00 = w9 * inr14 */ + "fmla v25.4s, v21.4s, v1.4s \n" /* out01 = w9 * inr15 */ + "fmla v26.4s, v22.4s, v1.4s \n" /* out02 = w9 * inr16 */ + "fmla v27.4s, v23.4s, v1.4s \n" /* out03 = w9 * inr17 */ + "ldp q18, q19, [%[inr3]], #32\n" /* load inr3, 2-3 */ + "fmla v28.4s, v12.4s, v1.4s \n" /* out10 = w9 * inr24 */ + "fmla v29.4s, v13.4s, v1.4s \n" /* out11 = w9 * inr25 */ + "fmla v30.4s, v14.4s, v1.4s \n" /* out12 = w9 * inr26 */ + "fmla v31.4s, v15.4s, v1.4s \n" /* out13 = w9 * inr27 */ + // out row2 + "fmla v24.4s, v8.4s, v2.4s \n" /* out00 = w10 * inr20 */ + "fmla v25.4s, v9.4s, v2.4s \n" /* out01 = w10 * inr21 */ + "fmla v26.4s, v10.4s, v2.4s \n" /* out02 = w10 * inr22 */ + "fmla v27.4s, v11.4s, v2.4s \n" /* out03 = w10 * inr23 */ + "ldp q4, q5, [%[wc]], #32 \n" /* load w12-w13 */ + "fmla v28.4s, v16.4s, v2.4s \n" /* out10 = w10 * inr30 */ + "fmla v29.4s, v17.4s, v2.4s \n" /* out11 = w10 * inr31 */ + "fmla v30.4s, v18.4s, v2.4s \n" /* out12 = w10 * inr32 */ + "fmla v31.4s, v19.4s, v2.4s \n" /* out13 = w10 * inr33 */ + "ldp q20, q21, [%[inr3]] \n" /* load inr3, 4-5 */ + "fmla v24.4s, v9.4s, v3.4s \n" /* out00 = w11 * inr21 */ + "fmla v25.4s, v10.4s, v3.4s \n" /* out01 = w11 * inr22 */ + "fmla v26.4s, v11.4s, v3.4s \n" /* out02 = w11 * inr23 */ + "fmla v27.4s, v12.4s, v3.4s \n" /* out03 = w11 * inr24 */ + "ldp q22, q23, [%[inr3], #32]\n" /* load inr3, 6-7 */ + "fmla v28.4s, v17.4s, v3.4s \n" /* out10 = w11 * inr31 */ + "fmla v29.4s, v18.4s, v3.4s \n" /* out11 = w11 * inr32 */ + "fmla v30.4s, v19.4s, v3.4s \n" /* out12 = w11 * inr33 */ + "fmla v31.4s, v20.4s, v3.4s \n" /* out13 = w11 * inr34 */ + "fmla v24.4s, v10.4s, v4.4s \n" /* out00 = w12 * inr22 */ + "fmla v25.4s, v11.4s, v4.4s \n" /* out01 = w12 * inr23 */ + "fmla v26.4s, v12.4s, v4.4s \n" /* out02 = w12 * inr24 */ + "fmla v27.4s, v13.4s, v4.4s \n" /* out03 = w12 * inr25 */ + "ldp q6, q7, [%[wc]], #32 \n" /* load w14-w15 */ + "fmla v28.4s, v18.4s, v4.4s \n" /* out10 = w12 * inr32 */ + "fmla v29.4s, v19.4s, v4.4s \n" /* out11 = w12 * inr33 */ + "fmla v30.4s, v20.4s, v4.4s \n" /* out12 = w12 * inr34 */ + "fmla v31.4s, v21.4s, v4.4s \n" /* out13 = w12 * inr35 */ + "fmla v24.4s, v11.4s, v5.4s \n" /* out00 = w13 * inr23 */ + "fmla v25.4s, v12.4s, v5.4s \n" /* out01 = w13 * inr24 */ + "fmla v26.4s, v13.4s, v5.4s \n" /* out02 = w13 * inr25 */ + "fmla v27.4s, v14.4s, v5.4s \n" /* out03 = w13 * inr26 */ + "ldp q8, q9, [%[inr4]], #32 \n" /* load inr4, 0-1 */ + "fmla v28.4s, v19.4s, v5.4s \n" /* out10 = w13 * inr33 */ + "fmla v29.4s, v20.4s, v5.4s \n" /* out11 = w13 * inr34 */ + "fmla v30.4s, v21.4s, v5.4s \n" /* out12 = w13 * inr35 */ + "fmla v31.4s, v22.4s, v5.4s \n" /* out13 = w13 * inr36 */ + "fmla v24.4s, v12.4s, v6.4s \n" /* out00 = w14 * inr24 */ + "fmla v25.4s, v13.4s, v6.4s \n" /* out01 = w14 * inr25 */ + "fmla v26.4s, v14.4s, v6.4s \n" /* out02 = w14 * inr26 */ + "fmla v27.4s, v15.4s, v6.4s \n" /* out03 = w14 * inr27 */ + "ldp q10, q11, [%[inr4]], #32\n" /* load inr4, 2-3 */ + "fmla v28.4s, v20.4s, v6.4s \n" /* out10 = w14 * inr34 */ + "fmla v29.4s, v21.4s, v6.4s \n" /* out11 = w14 * inr35 */ + "fmla v30.4s, v22.4s, v6.4s \n" /* out12 = w14 * inr36 */ + "fmla v31.4s, v23.4s, v6.4s \n" /* out13 = w14 * inr37 */ + "ldp q0, q1, [%[wc]], #32 \n" /* load w16-w17 */ + // out row3 + "fmla v24.4s, v16.4s, v7.4s \n" /* out00 = w15 * inr30 */ + "fmla v25.4s, v17.4s, v7.4s \n" /* out01 = w15 * inr31 */ + "fmla v26.4s, v18.4s, v7.4s \n" /* out02 = w15 * inr32 */ + "fmla v27.4s, v19.4s, v7.4s \n" /* out03 = w15 * inr33 */ + "ldp q12, q13, [%[inr4]] \n" /* load inr4, 4-5 */ + "fmla v28.4s, v8.4s, v7.4s \n" /* out10 = w15 * inr40 */ + "fmla v29.4s, v9.4s, v7.4s \n" /* out11 = w15 * inr41 */ + "fmla v30.4s, v10.4s, v7.4s \n" /* out12 = w15 * inr42 */ + "fmla v31.4s, v11.4s, v7.4s \n" /* out13 = w15 * inr42 */ + "ldp q2, q3, [%[wc]], #32 \n" /* load w18-w19 */ + "fmla v24.4s, v17.4s, v0.4s \n" /* out00 = w16 * inr31 */ + "fmla v25.4s, v18.4s, v0.4s \n" /* out01 = w16 * inr32 */ + "fmla v26.4s, v19.4s, v0.4s \n" /* out02 = w16 * inr33 */ + "fmla v27.4s, v20.4s, v0.4s \n" /* out03 = w16 * inr34 */ + "ldp q14, q15, [%[inr4], #32]\n" /* load inr4, 6-7 */ + "fmla v28.4s, v9.4s, v0.4s \n" /* out10 = w16 * inr41 */ + "fmla v29.4s, v10.4s, v0.4s \n" /* out11 = w16 * inr42 */ + "fmla v30.4s, v11.4s, v0.4s \n" /* out12 = w16 * inr43 */ + "fmla v31.4s, v12.4s, v0.4s \n" /* out13 = w16 * inr44 */ + "fmla v24.4s, v18.4s, v1.4s \n" /* out00 = w17 * inr32 */ + "fmla v25.4s, v19.4s, v1.4s \n" /* out01 = w17 * inr33 */ + "fmla v26.4s, v20.4s, v1.4s \n" /* out02 = w17 * inr34 */ + "fmla v27.4s, v21.4s, v1.4s \n" /* out03 = w17 * inr35 */ + "ldp q4, q5, [%[wc]], #32 \n" /* load w20-w21 */ + "fmla v28.4s, v10.4s, v1.4s \n" /* out10 = w17 * inr42 */ + "fmla v29.4s, v11.4s, v1.4s \n" /* out11 = w17 * inr43 */ + "fmla v30.4s, v12.4s, v1.4s \n" /* out12 = w17 * inr44 */ + "fmla v31.4s, v13.4s, v1.4s \n" /* out13 = w17 * inr45 */ + "fmla v24.4s, v19.4s, v2.4s \n" /* out00 = w18 * inr33 */ + "fmla v25.4s, v20.4s, v2.4s \n" /* out01 = w18 * inr34 */ + "fmla v26.4s, v21.4s, v2.4s \n" /* out02 = w18 * inr35 */ + "fmla v27.4s, v22.4s, v2.4s \n" /* out03 = w18 * inr36 */ + "ldp q16, q17, [%[inr5]], #32\n" /* load inr5, 0-1 */ + "fmla v28.4s, v11.4s, v2.4s \n" /* out10 = w18 * inr43 */ + "fmla v29.4s, v12.4s, v2.4s \n" /* out11 = w18 * inr44 */ + "fmla v30.4s, v13.4s, v2.4s \n" /* out12 = w18 * inr45 */ + "fmla v31.4s, v14.4s, v2.4s \n" /* out13 = w18 * inr46 */ + "fmla v24.4s, v20.4s, v3.4s \n" /* out00 = w19 * inr34 */ + "fmla v25.4s, v21.4s, v3.4s \n" /* out01 = w19 * inr35 */ + "fmla v26.4s, v22.4s, v3.4s \n" /* out02 = w19 * inr36 */ + "fmla v27.4s, v23.4s, v3.4s \n" /* out03 = w19 * inr37 */ + "ldp q18, q19, [%[inr5]], #32\n" /* load inr5, 2-3 */ + "fmla v28.4s, v12.4s, v3.4s \n" /* out10 = w19 * inr44 */ + "fmla v29.4s, v13.4s, v3.4s \n" /* out11 = w19 * inr45 */ + "fmla v30.4s, v14.4s, v3.4s \n" /* out12 = w19 * inr46 */ + "fmla v31.4s, v15.4s, v3.4s \n" /* out13 = w19 * inr47 */ + // out row4 + "fmla v24.4s, v8.4s, v4.4s \n" /* out00 = w20 * inr40 */ + "fmla v25.4s, v9.4s, v4.4s \n" /* out01 = w20 * inr41 */ + "fmla v26.4s, v10.4s, v4.4s \n" /* out02 = w20 * inr42 */ + "fmla v27.4s, v11.4s, v4.4s \n" /* out03 = w20 * inr43 */ + "ldp q20, q21, [%[inr5]] \n" /* load inr5, 4-5 */ + "fmla v28.4s, v16.4s, v4.4s \n" /* out10 = w20 * inr50 */ + "fmla v29.4s, v17.4s, v4.4s \n" /* out11 = w20 * inr51 */ + "fmla v30.4s, v18.4s, v4.4s \n" /* out12 = w20 * inr52 */ + "fmla v31.4s, v19.4s, v4.4s \n" /* out13 = w20 * inr53 */ + "ldp q6, q7, [%[wc]], #32 \n" /* load w22-w23 */ + "fmla v24.4s, v9.4s, v5.4s \n" /* out00 = w21 * inr41 */ + "fmla v25.4s, v10.4s, v5.4s \n" /* out01 = w21 * inr42 */ + "fmla v26.4s, v11.4s, v5.4s \n" /* out02 = w21 * inr43 */ + "fmla v27.4s, v12.4s, v5.4s \n" /* out03 = w21 * inr44 */ + "ldp q22, q23, [%[inr5], #32]\n" /* load inr5, 6-7 */ + "fmla v28.4s, v17.4s, v5.4s \n" /* out10 = w21 * inr51 */ + "fmla v29.4s, v18.4s, v5.4s \n" /* out11 = w21 * inr52 */ + "fmla v30.4s, v19.4s, v5.4s \n" /* out12 = w21 * inr53 */ + "fmla v31.4s, v20.4s, v5.4s \n" /* out13 = w21 * inr54 */ + "ldp q8, q9, [%[inr0]], #32 \n" /* load inr0, 0-1 */ + "fmla v24.4s, v10.4s, v6.4s \n" /* out00 = w22 * inr42 */ + "fmla v25.4s, v11.4s, v6.4s \n" /* out01 = w22 * inr43 */ + "fmla v26.4s, v12.4s, v6.4s \n" /* out02 = w22 * inr44 */ + "fmla v27.4s, v13.4s, v6.4s \n" /* out03 = w22 * inr45 */ + "ldp q4, q5, [%[wc]], #-384 \n" /* load w24 */ + "fmla v28.4s, v18.4s, v6.4s \n" /* out10 = w22 * inr52 */ + "fmla v29.4s, v19.4s, v6.4s \n" /* out11 = w22 * inr53 */ + "fmla v30.4s, v20.4s, v6.4s \n" /* out12 = w22 * inr54 */ + "fmla v31.4s, v21.4s, v6.4s \n" /* out13 = w22 * inr55 */ + "ldp q0, q1, [%[wc]], #32 \n" /* load w0-w1 */ + "fmla v24.4s, v11.4s, v7.4s \n" /* out00 = w23 * inr43 */ + "fmla v25.4s, v12.4s, v7.4s \n" /* out01 = w23 * inr44 */ + "fmla v26.4s, v13.4s, v7.4s \n" /* out02 = w23 * inr45 */ + "fmla v27.4s, v14.4s, v7.4s \n" /* out03 = w23 * inr46 */ + "ldp q2, q3, [%[wc]], #32 \n" /* load w1-w2 */ + "fmla v28.4s, v19.4s, v7.4s \n" /* out10 = w23 * inr53 */ + "fmla v29.4s, v20.4s, v7.4s \n" /* out11 = w23 * inr54 */ + "fmla v30.4s, v21.4s, v7.4s \n" /* out12 = w23 * inr55 */ + "fmla v31.4s, v22.4s, v7.4s \n" /* out13 = w23 * inr56 */ + "ldp q10, q11, [%[inr0]], #32\n" /* load inr0, 2-3 */ + "fmla v24.4s, v12.4s, v4.4s \n" /* out00 = w24 * inr44 */ + "fmla v25.4s, v13.4s, v4.4s \n" /* out01 = w24 * inr45 */ + "fmla v26.4s, v14.4s, v4.4s \n" /* out02 = w24 * inr46 */ + "fmla v27.4s, v15.4s, v4.4s \n" /* out03 = w24 * inr47 */ + "stp q24, q25, [%[out0]], #32\n" /* store outr0, 0-1 */ + "fmla v28.4s, v20.4s, v4.4s \n" /* out10 = w24 * inr54 */ + "fmla v29.4s, v21.4s, v4.4s \n" /* out11 = w24 * inr55 */ + "stp q26, q27, [%[out0]], #32\n" /* store outr0, 2-3 */ + "fmla v30.4s, v22.4s, v4.4s \n" /* out12 = w24 * inr56 */ + "fmla v31.4s, v23.4s, v4.4s \n" /* out13 = w24 * inr57 */ + "ldr q24, [%[bias]] \n" /* load bias to out00 */ + "subs %w[cnt], %w[cnt], #1\n" /* cnt = cnt - 1 */ + "stp q28, q29, [%[out1]], #32\n" /* store outr1, 0-1 */ + "stp q30, q31, [%[out1]], #32\n" /* store outr1, 2-3 */ + "bne 1b\n" + : [cnt] "+r"(cnt), + [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [inr5] "+r"(inr5), + [wc] "+r"(wptr), + [out0] "+r"(ptr_out0), + [out1] "+r"(ptr_out1) + : [bias] "r"(bias_local) + : "cc","memory", + "v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13", + "v14","v15","v16","v17","v18","v19", + "v20","v21","v22","v23","v24","v25", + "v26","v27","v28","v29","v30","v31" + ); + // clang-format on + block_inr0 = block_inr2; + block_inr1 = block_inr3; + block_inr2 = block_inr4; + block_inr3 = block_inr5; + block_inr4 = block_inr3 + in_len; + block_inr5 = block_inr4 + in_len; + } + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, + chout, + hout, + wout, + flag_relu, + ptr_write, + &act_param); } } } - free(din_new); } -#endif // __aarch64__ - -void conv_depthwise_5x5s1_fp32(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, +#else // __aarch64__ +void conv_depthwise_5x5s1_fp32(float* dout, + const float* din, const float* weights, const float* bias, - int pad, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + const operators::ConvParam& param, ARMContext* ctx) { - if (win < 4) { - if (flag_relu) { - conv_depthwise_5x5s1_small_relu_impl(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s1_small_impl(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); - } - } else { - if (flag_relu) { - conv_depthwise_5x5s1_relu_impl(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s1_impl(din, - dout, - num, + const int threads = ctx->threads(); + int llc_size = ctx->llc_size() / 4; + auto act_param = param.activation_param; + const int hout_c_block = 4; + const int hout_r_kernel = 1; + const int wout_block = 4; + const int wout_round = ((wout + wout_block - 1) / wout_block) * wout_block; + const int win_round = wout_round + 4; + + //! get h block + //! llc_size = threads * win_round * hout_c_block * hin_r_block * + //! sizeof(float) + //! + wout_round * hout_c_block * hout_r_block * threads * sizeof(float) + //! win_round = wout_round + 4 + //! hin_r_block = hout_r_block + 4 + int hout_r_block = (llc_size - 16 * win_round * hout_c_block * threads) / + (win_round * hout_c_block * threads * 4 + + hout_c_block * wout_round * threads * 4); + hout_r_block = hout_r_block > hout ? hout : hout_r_block; + hout_r_block = + ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; + hout_r_block = hout_r_block < hout_r_kernel ? hout_r_kernel : hout_r_block; + + const int hin_r_block = hout_r_block + 4; + + float* tmp_work_space = ctx->workspace_data(); + float ptr_zero[win_round]; // NOLINT + memset(ptr_zero, 0, sizeof(float) * win_round); + float ptr_write[wout_round]; // NOLINT + + int in_len = win_round * hout_c_block; + int pre_in_size = hin_r_block * in_len; + pre_in_size = ROUNDUP(pre_in_size, 4); + int pre_out_size = hout_c_block * hout_r_block * wout_round; + + float* tmp_din = tmp_work_space; + + int size_in_channel = win * hin; + int size_out_channel = wout * hout; + int w_stride = 25; // kernel_w * kernel_h; + + int ws = -padw; + int we = ws + win_round; + int w_loop = wout_round / 4; + int chout = chin; + + int out_row_stride = hout_c_block * wout_round; + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * chin * size_in_channel; + float* dout_batch = dout + n * chout * size_out_channel; + for (int h = 0; h < hout; h += hout_r_block) { + int h_kernel = hout_r_block; + if (h + hout_r_block > hout) { + h_kernel = hout - h; + } + int hs = h - padh; + int he = hs + h_kernel + 4; + +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < chout; c += hout_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = + tmp_din + omp_get_thread_num() * (pre_in_size + pre_out_size); + float* pre_out = pre_din + pre_in_size; +#else + float* pre_din = tmp_din; + float* pre_out = pre_din + pre_in_size; +#endif + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, chin, win, hin, ptr_zero); + const float* block_inr0 = pre_din; + const float* block_inr1 = block_inr0 + in_len; + const float* block_inr2 = block_inr1 + in_len; + const float* block_inr3 = block_inr2 + in_len; + const float* block_inr4 = block_inr3 + in_len; + + const float* weight_c = weights + c * w_stride; + float bias_local[4] = {0, 0, 0, 0}; + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } + for (int hk = 0; hk < h_kernel; hk += hout_r_kernel) { + int cnt = w_loop; + const float* inr0 = block_inr0; + const float* inr1 = block_inr1; + const float* inr2 = block_inr2; + const float* inr3 = block_inr3; + const float* inr4 = block_inr4; + + float* ptr_out0 = pre_out + hk * out_row_stride; + // clang-format off + auto wptr = weight_c; + asm volatile( + "vld1.32 {d24-d25}, [%[bias]] \n" /* load bias to out00 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w0-w1 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w2-w3 */ + "vld1.32 {d8-d11}, [%[inr0]]! \n" /* load inr0, 0-1 */ + "vld1.32 {d12-d15}, [%[inr0]]! \n" /* load inr0, 2-3 */ + "1:\n" + "vld1.32 {d16-d19}, [%[inr0]]! \n" /* load inr0, 4-5 */ + "vmov.u32 q13, q12 \n" /* mov bias to out01 */ + "vmov.u32 q14, q12 \n" /* mov bias to out02 */ + "vmov.u32 q15, q12 \n" /* mov bias to out03 */ + // out row0 + "vmla.f32 q12, q4, q0 \n" /* out00 = w0 * inr00 */ + "vmla.f32 q13, q5, q0 \n" /* out01 = w0 * inr01 */ + "vmla.f32 q14, q6, q0 \n" /* out02 = w0 * inr02 */ + "vmla.f32 q15, q7, q0 \n" /* out03 = w0 * inr03 */ + "vld1.32 {d20-d23}, [%[inr0]]! \n" /* load inr0, 6-7 */ + "sub %[inr0], %[inr0], #64 \n" /* inr0 -= 64 */ + "vmla.f32 q12, q5, q1 \n" /* out00 = w1 * inr01 */ + "vmla.f32 q13, q6, q1 \n" /* out01 = w1 * inr02 */ + "vmla.f32 q14, q7, q1 \n" /* out02 = w1 * inr03 */ + "vmla.f32 q15, q8, q1 \n" /* out03 = w1 * inr04 */ + "vld1.32 {d8-d11}, [%[inr1]]!\n" /* load inr1, 0-1 */ + "vmla.f32 q12, q6, q2 \n" /* out00 = w2 * inr02 */ + "vmla.f32 q13, q7, q2 \n" /* out01 = w2 * inr03 */ + "vmla.f32 q14, q8, q2 \n" /* out02 = w2 * inr04 */ + "vmla.f32 q15, q9, q2 \n" /* out03 = w2 * inr05 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w4-w5 */ + "vmla.f32 q12, q7, q3 \n" /* out00 = w3 * inr03 */ + "vmla.f32 q13, q8, q3 \n" /* out01 = w3 * inr04 */ + "vmla.f32 q14, q9, q3 \n" /* out02 = w3 * inr05 */ + "vmla.f32 q15, q10, q3 \n" /* out03 = w3 * inr06 */ + "vld1.32 {d12-d15}, [%[inr1]]!\n" /* load inr1, 2-3 */ + "vmla.f32 q12, q8, q0 \n" /* out00 = w4 * inr04 */ + "vmla.f32 q13, q9, q0 \n" /* out01 = w4 * inr05 */ + "vmla.f32 q14, q10, q0 \n" /* out02 = w4 * inr06 */ + "vmla.f32 q15, q11, q0 \n" /* out03 = w4 * inr07 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w6-w7 */ + // out row1 + "vmla.f32 q12, q4, q1 \n" /* out00 = w5 * inr10 */ + "vmla.f32 q13, q5, q1 \n" /* out01 = w5 * inr11 */ + "vmla.f32 q14, q6, q1 \n" /* out02 = w5 * inr12 */ + "vmla.f32 q15, q7, q1 \n" /* out03 = w5 * inr13 */ + "vld1.32 {d16-d19}, [%[inr1]]!\n" /* load inr1, 4-5 */ + "vmla.f32 q12, q5, q2 \n" /* out00 = w6 * inr11 */ + "vmla.f32 q13, q6, q2 \n" /* out01 = w6 * inr12 */ + "vmla.f32 q14, q7, q2 \n" /* out02 = w6 * inr13 */ + "vmla.f32 q15, q8, q2 \n" /* out03 = w6 * inr14 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w8-w9 */ + "vmla.f32 q12, q6, q3 \n" /* out00 = w7 * inr12 */ + "vmla.f32 q13, q7, q3 \n" /* out01 = w7 * inr13 */ + "vld1.32 {d20-d23}, [%[inr1]]!\n" /* load inr1, 6-7 */ + "vmla.f32 q14, q8, q3 \n" /* out02 = w7 * inr14 */ + "vmla.f32 q15, q9, q3 \n" /* out03 = w7 * inr15 */ + "sub %[inr1], %[inr1], #64 \n" /* inr1 -= 64 */ + "vmla.f32 q12, q7, q0 \n" /* out00 = w8 * inr13 */ + "vmla.f32 q13, q8, q0 \n" /* out01 = w8 * inr14 */ + "vld1.32 {d8-d11}, [%[inr2]]!\n" /* load inr2, 0-1 */ + "vmla.f32 q14, q9, q0 \n" /* out02 = w8 * inr15 */ + "vmla.f32 q15, q10, q0 \n" /* out03 = w8 * inr16 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w10-w11 */ + "vmla.f32 q12, q8, q1 \n" /* out00 = w9 * inr14 */ + "vmla.f32 q13, q9, q1 \n" /* out01 = w9 * inr15 */ + "vld1.32 {d12-d15}, [%[inr2]]!\n" /* load inr2, 2-3 */ + "vmla.f32 q14, q10, q1 \n" /* out02 = w9 * inr16 */ + "vmla.f32 q15, q11, q1 \n" /* out03 = w9 * inr17 */ + // out row3 + "vmla.f32 q12, q4, q2 \n" /* out00 = w10 * inr20 */ + "vmla.f32 q13, q5, q2 \n" /* out01 = w10 * inr21 */ + "vld1.32 {d16-d19}, [%[inr2]]!\n" /* load inr2, 4-5 */ + "vmla.f32 q14, q6, q2 \n" /* out02 = w10 * inr22 */ + "vmla.f32 q15, q7, q2 \n" /* out03 = w10 * inr23 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w12-w13 */ + "vmla.f32 q12, q5, q3 \n" /* out00 = w11 * inr21 */ + "vmla.f32 q13, q6, q3 \n" /* out01 = w11 * inr22 */ + "vld1.32 {d20-d23}, [%[inr2]]!\n" /* load inr2, 6-7 */ + "vmla.f32 q14, q7, q3 \n" /* out02 = w11 * inr23 */ + "vmla.f32 q15, q8, q3 \n" /* out03 = w11 * inr24 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w14-w15 */ + "sub %[inr2], %[inr2], #64 \n" /* inr2 -= 64 */ + "vmla.f32 q12, q6, q0 \n" /* out00 = w12 * inr22 */ + "vmla.f32 q13, q7, q0 \n" /* out01 = w12 * inr23 */ + "vmla.f32 q14, q8, q0 \n" /* out02 = w12 * inr24 */ + "vmla.f32 q15, q9, q0 \n" /* out03 = w12 * inr25 */ + "vld1.32 {d8-d11}, [%[inr3]]!\n" /* load inr3, 0-1 */ + "vmla.f32 q12, q7, q1 \n" /* out00 = w13 * inr23 */ + "vmla.f32 q13, q8, q1 \n" /* out01 = w13 * inr24 */ + "vmla.f32 q14, q9, q1 \n" /* out02 = w13 * inr25 */ + "vmla.f32 q15, q10, q1 \n" /* out03 = w13 * inr26 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w16-w17 */ + "vmla.f32 q12, q8, q2 \n" /* out00 = w14 * inr24 */ + "vmla.f32 q13, q9, q2 \n" /* out01 = w14 * inr25 */ + "vld1.32 {d12-d15}, [%[inr3]]!\n" /* load inr3, 2-3 */ + "vmla.f32 q14, q10, q2 \n" /* out02 = w14 * inr26 */ + "vmla.f32 q15, q11, q2 \n" /* out03 = w14 * inr27 */ + // out row3 + "vmla.f32 q12, q4, q3 \n" /* out00 = w15 * inr30 */ + "vmla.f32 q13, q5, q3 \n" /* out01 = w15 * inr31 */ + "vld1.32 {d16-d19}, [%[inr3]]!\n" /* load inr3, 4-5 */ + "vmla.f32 q14, q6, q3 \n" /* out02 = w15 * inr32 */ + "vmla.f32 q15, q7, q3 \n" /* out03 = w15 * inr33 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w18-w19 */ + "vmla.f32 q12, q5, q0 \n" /* out00 = w16 * inr31 */ + "vmla.f32 q13, q6, q0 \n" /* out01 = w16 * inr32 */ + "vld1.32 {d20-d23}, [%[inr3]]!\n" /* load inr3, 6-7 */ + "vmla.f32 q14, q7, q0 \n" /* out02 = w16 * inr33 */ + "vmla.f32 q15, q8, q0 \n" /* out03 = w16 * inr34 */ + "sub %[inr3], %[inr3], #64 \n" /* inr3 -= 64 */ + "vmla.f32 q12, q6, q1 \n" /* out00 = w17 * inr32 */ + "vmla.f32 q13, q7, q1 \n" /* out01 = w17 * inr33 */ + "vmla.f32 q14, q8, q1 \n" /* out02 = w17 * inr34 */ + "vmla.f32 q15, q9, q1 \n" /* out03 = w17 * inr35 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w20-w21 */ + "vmla.f32 q12, q7, q2 \n" /* out00 = w18 * inr33 */ + "vmla.f32 q13, q8, q2 \n" /* out01 = w18 * inr34 */ + "vmla.f32 q14, q9, q2 \n" /* out02 = w18 * inr35 */ + "vmla.f32 q15, q10, q2 \n" /* out03 = w18 * inr36 */ + "vld1.32 {d8-d11}, [%[inr4]]!\n" /* load inr4, 0-1 */ + "vmla.f32 q12, q8, q3 \n" /* out00 = w19 * inr34 */ + "vmla.f32 q13, q9, q3 \n" /* out01 = w19 * inr35 */ + "vld1.32 {d12-d15}, [%[inr4]]!\n" /* load inr4, 2-3 */ + "vmla.f32 q14, q10, q3 \n" /* out02 = w19 * inr36 */ + "vmla.f32 q15, q11, q3 \n" /* out03 = w19 * inr37 */ + // out row4 + "vmla.f32 q12, q4, q0 \n" /* out00 = w20 * inr40 */ + "vmla.f32 q13, q5, q0 \n" /* out01 = w20 * inr41 */ + "vld1.32 {d16-d19}, [%[inr4]]!\n" /* load inr4, 4-5 */ + "vmla.f32 q14, q6, q0 \n" /* out02 = w20 * inr42 */ + "vmla.f32 q15, q7, q0 \n" /* out03 = w20 * inr43 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w22-w23 */ + "vmla.f32 q12, q5, q1 \n" /* out00 = w21 * inr41 */ + "vmla.f32 q13, q6, q1 \n" /* out01 = w21 * inr42 */ + "vmla.f32 q14, q7, q1 \n" /* out02 = w21 * inr43 */ + "vmla.f32 q15, q8, q1 \n" /* out03 = w21 * inr44 */ + "vld1.32 {d20-d23}, [%[inr4]]!\n" /* load inr4, 6-7 */ + "vmla.f32 q12, q6, q2 \n" /* out00 = w22 * inr42 */ + "vmla.f32 q13, q7, q2 \n" /* out01 = w22 * inr43 */ + "vmla.f32 q14, q8, q2 \n" /* out02 = w22 * inr44 */ + "vmla.f32 q15, q9, q2 \n" /* out03 = w22 * inr45 */ + "vld1.32 {d4-d5}, [%[wc]] \n" /* load w24 */ + "sub %[inr4], %[inr4], #64 \n" /* inr4 -= 64 */ + "vmla.f32 q12, q7, q3 \n" /* out00 = w23 * inr43 */ + "vmla.f32 q13, q8, q3 \n" /* out01 = w23 * inr44 */ + "vld1.32 {d8-d11}, [%[inr0]]!\n" /* load inr0, 0-1 */ + "sub %[wc], %[wc], #384 \n" /* wptr = wptr - 384 */ + "vmla.f32 q14, q9, q3 \n" /* out02 = w23 * inr45 */ + "vmla.f32 q15, q10, q3 \n" /* out03 = w23 * inr46 */ + "vld1.32 {d0-d3}, [%[wc]]! \n" /* load w0-w1 */ + "vmla.f32 q12, q8, q2 \n" /* out00 = w24 * inr44 */ + "vmla.f32 q13, q9, q2 \n" /* out01 = w24 * inr45 */ + "vld1.32 {d12-d15}, [%[inr0]]!\n" /* load inr0, 2-3 */ + "vmla.f32 q14, q10, q2 \n" /* out02 = w24 * inr46 */ + "vmla.f32 q15, q11, q2 \n" /* out03 = w24 * inr47 */ + "vst1.32 {d24-d27}, [%[out0]]!\n" /* store out00, out01 */ + "vld1.32 {d4-d7}, [%[wc]]! \n" /* load w2-w3 */ + "subs %[cnt], %[cnt], #1 \n" /* cnt = cnt - 1 */ + "vst1.32 {d28-d31}, [%[out0]]!\n" /* store out02, out03 */ + "vld1.32 {d24-d25}, [%[bias]] \n" /* load bias to out00 */ + "bne 1b\n" + : [cnt] "+r"(cnt), + [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc] "+r"(wptr), + [out0] "+r"(ptr_out0) + : [bias] "r"(bias_local) + : "cc","memory", + "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15" + ); + // clang-format on + block_inr0 = block_inr1; + block_inr1 = block_inr2; + block_inr2 = block_inr3; + block_inr3 = block_inr4; + block_inr4 = block_inr3 + in_len; + } + write_to_output_c4_fp32(pre_out, + dout_batch, + c, + c + hout_c_block, + h, + h + h_kernel, + 0, + wout_round, chout, hout, wout, - chin, - hin, - win, - weights, - bias, - pad, - flag_bias, flag_relu, - ctx); + ptr_write, + &act_param); + } } } } - +#endif // __aarch64__ } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc index 802082048c86beeeecfe64a0de09880b1b9b0137..ed3dad300804dc90fac874999ac5d0a420cff4a4 100644 --- a/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc +++ b/lite/backends/arm/math/conv5x5s1_depthwise_int8.cc @@ -709,7 +709,6 @@ void conv_depthwise_5x5s1_int8(Dtype* dout, "q15"); #endif // clang-format on - int32_t* ptr_tmp = ptr_out0 - w_loop * 32; block_inr0 = block_inr1; block_inr1 = block_inr2; block_inr2 = block_inr3; diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index 6286b887c0cd55b37b998077de8dc0f99dc12923..5524732029f07a0cd4d31f3c28a2435d45b50d67 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -198,24 +198,24 @@ namespace math { "fmin v20.4s, v20.4s, %[vsix].4s\n" \ "fmin v21.4s, v21.4s, %[vsix].4s\n" \ "fmin v22.4s, v22.4s, %[vsix].4s\n" -#define LEAKY_RELU /* LeakyRelu */ \ - "movi v0.4s, #0\n" /* for relu */ \ - "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ - "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ - "cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ - "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ - "cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ - "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ - "cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ - "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ - "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ - "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ - "bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \ - "bif v19.16b, v8.16b, v7.16b \n" /* choose*/ -#define STORE /* save result */ \ - "str q19, [%[outc0]], #16\n" \ - "str q20, [%[outc1]], #16\n" \ - "str q21, [%[outc2]], #16\n" \ +#define LEAKY_RELU /* LeakyRelu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ + "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ + "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ + "bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \ + "bif v19.16b, v8.16b, v7.16b \n" /* choose*/ +#define STORE /* save result */ \ + "str q19, [%[outc0]], #16\n" \ + "str q20, [%[outc1]], #16\n" \ + "str q21, [%[outc2]], #16\n" \ "str q22, [%[outc3]], #16\n" #else diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 240679898f498d5bc1c8cf44aef0d43c2d025625..85404d6a6e2e6246677857be8231e15afa86210d 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -614,16 +614,16 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din, "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ #define NCHWC1_TRANS_FP32_LEAKY_RELU \ - "cmhs v4.4s, v0.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v5.4s, v1.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ - "fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \ - "fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \ + "fcmge v4.4s, v0.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v5.4s, v1.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v8.4s, v0.4s, %[scale].4s \n" /* mul */ \ + "fmul v9.4s, v1.4s, %[scale].4s \n" /* mul */ \ "fmul v10.4s, v2.4s, %[scale].4s \n" /* mul */ \ "fmul v11.4s, v3.4s, %[scale].4s \n" /* mul */ \ - "bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \ - "bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \ + "bif v0.16b, v8.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v9.16b, v5.16b \n" /* choose*/ \ "bif v2.16b, v10.16b, v6.16b \n" /* choose*/ \ "bif v3.16b, v11.16b, v7.16b \n" /* choose*/ @@ -674,15 +674,15 @@ inline void prepack_input_nxwc8_int8_dw(const int8_t* din, "vbif q3, q12, q8 @ choose \n" #define NCHWC1_TRANS_FP32_STORE \ - "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \ - "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \ + "vst1.32 {d0-d1}, [%[doutc0r0]]! @ store result \n" \ + "vst1.32 {d2-d3}, [%[doutc0r0]]! @ store result, \n" \ "subs %[cnt], %[cnt], #1 @ loop count - 1\n" \ \ - "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \ - "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \ + "vld1.32 {d0-d3}, [%[ptr_din]]! @ load data \n" \ + "vst1.32 {d4-d5}, [%[doutc0r0]]! @ store result \n" \ "vst1.32 {d6-d7}, [%[doutc0r0]]! @ store result, \n" \ \ - "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \ + "vld1.32 {d4-d7}, [%[ptr_din]]! @ load data \n" \ \ "bne 1b @ jump to main loop\n" #endif @@ -934,12 +934,12 @@ inline bool write_to_output_c1_fp32(const float* din, "fmin v2.4s, v2.4s, %[six].4s \n" /* relu6 */ \ "fmin v3.4s, v3.4s, %[six].4s \n" /* relu6 */ -#define NCHWC2_TRANS_FP32_LEAKY_RELU \ - "cmhs v6.4s, v2.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v7.4s, v3.4s, v20.4s \n" /* vcgeq_u32 */ \ - "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \ - "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \ - "bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \ +#define NCHWC2_TRANS_FP32_LEAKY_RELU \ + "fcmge v6.4s, v2.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v7.4s, v3.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v4.4s, v2.4s, %[scale].4s \n" /* mul */ \ + "fmul v5.4s, v3.4s, %[scale].4s \n" /* mul */ \ + "bif v2.16b, v4.16b, v6.16b \n" /* choose*/ \ "bif v3.16b, v5.16b, v7.16b \n" /* choose*/ #define NCHWC2_TRANS_FP32_STORE \ @@ -1275,19 +1275,19 @@ inline bool write_to_output_c2_fp32(const float* din, "fmin v18.4s, v18.4s, %[six].4s \n" /* relu6 */ \ "fmin v19.4s, v19.4s, %[six].4s \n" /* relu6 */ -#define NCHWC4_TRANS_FP32_LEAKY_RELU \ - "cmhs v8.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v9.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v10.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v11.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ - "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \ - "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \ - "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \ - "fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \ - "bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \ - "bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \ - "bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \ - "bif v19.16b, v7.16b, v11.16b \n" /* choose*/ +#define NCHWC4_TRANS_FP32_LEAKY_RELU \ + "fcmge v8.4s, v16.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v9.4s, v17.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v10.4s, v18.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fcmge v11.4s, v19.4s, v20.4s \n" /* vcgeq_f32 */ \ + "fmul v4.4s, v16.4s, %[scale].4s \n" /* mul */ \ + "fmul v5.4s, v17.4s, %[scale].4s \n" /* mul */ \ + "fmul v6.4s, v18.4s, %[scale].4s \n" /* mul */ \ + "fmul v7.4s, v19.4s, %[scale].4s \n" /* mul */ \ + "bif v16.16b, v4.16b, v8.16b \n" /* choose*/ \ + "bif v17.16b, v5.16b, v9.16b \n" /* choose*/ \ + "bif v18.16b, v6.16b, v10.16b \n" /* choose*/ \ + "bif v19.16b, v7.16b, v11.16b \n" /* choose*/ #define NCHWC4_TRANS_FP32_STORE \ "str q16, [%[doutc0r0]], #16 \n" /* store c0r0*/ \ @@ -1754,15 +1754,15 @@ inline bool write_to_output_c4_fp32(const float* din, "fmin v13.4s, v13.4s, %[six].4s \n" /*relu6*/ #define NCHWC8_TRANS_FP32_LEAKY_RELU \ - "cmhs v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v10.4s, v16.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v11.4s, v17.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v14.4s, v18.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v15.4s, v19.4s, v20.4s \n" /* vcgeq_u32 */ \ \ - "cmhs v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \ - "cmhs v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v21.4s, v8.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v22.4s, v9.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v23.4s, v12.4s, v20.4s \n" /* vcgeq_u32 */ \ + "fcmge v24.4s, v13.4s, v20.4s \n" /* vcgeq_u32 */ \ \ "fmul v25.4s, v16.4s, %[scale].4s \n" /* mul */ \ "fmul v26.4s, v17.4s, %[scale].4s \n" /* mul */ \ @@ -1839,7 +1839,7 @@ inline bool write_to_output_c4_fp32(const float* din, "vmin.f32 q7, q7, %q[six] @ relu6\n" #define NCHWC8_TRANS_FP32_LEAKY_RELU \ - "vcge.f32 q9, q0, q15 @ q0 > 0 \n" \ + "vcge.f32 q9, q0, q15 @ q0 > 0 \n" \ "vcge.f32 q10, q1, q15 @ q0 > 0 \n" \ "vcge.f32 q11, q2, q15 @ q0 > 0 \n" \ "vcge.f32 q12, q3, q15 @ q0 > 0 \n" \ @@ -2168,19 +2168,19 @@ inline void act_switch_c8_fp32(const float* din_ptr, "fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */ -#define DO_LEAKY_RELU \ - "cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ - "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ - "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \ - "bif v3.16b, v11.16b, v10.16b \n" /* choose*/ +#define DO_LEAKY_RELU \ + "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ + "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v10.16b \n" /* choose*/ #define DO_STORE \ "subs %w[cnt], %w[cnt], #1 \n" \ "st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ @@ -2217,7 +2217,7 @@ inline void act_switch_c8_fp32(const float* din_ptr, "vbif q3, q8, q7 @ choose \n" \ "vbif q4, q10, q9 @ choose \n" \ "vbif q5, q12, q11 @ choose \n" \ - "vbif q6, q13, q13 @ choose \n" + "vbif q6, q14, q13 @ choose \n" #define DO_STORE \ "subs %[cnt], #1 \n" \ "vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \ diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index b5dd1b58c497582f78f1e3961d7c4b0a066219f1..186115890d79ec676c85f0bc13dfbe75fc1a551a 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -123,20 +123,21 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, int padh, ARMContext* ctx); -void conv_depthwise_5x5s1_fp32(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, +void conv_depthwise_5x5s1_fp32(float* dout, + const float* din, const float* weights, const float* bias, - int pad, bool flag_bias, bool flag_relu, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int padw, + int padh, + const operators::ConvParam& param, ARMContext* ctx); void conv_depthwise_5x5s2_fp32(const float* din, diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index d4d24fdd903eabd7ca7b7a7264ea3d4ce8b4566b..f2fe954d5f53768c2a5497fa9c35764bad186476 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -188,7 +188,6 @@ void conv1x1s1_gemm(const float* i_data, if (n > 1) { weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; } - //! use gemv when the output channel size = 1 for (int b = 0; b < num; ++b) { // dC @@ -210,8 +209,11 @@ void conv1x1s1_gemm(const float* i_data, k, flag_bias, bias_group, - flag_relu, - ctx); + act_param.has_active, + act_param.active_type, + ctx, + act_param.Relu_clipped_coef, + act_param.Leaky_relu_alpha); } else { sgemm_prepack(false, m, @@ -410,8 +412,11 @@ void conv_im2col_gemm(const float* i_data, k, flag_bias, bias_group, - flag_relu, - ctx); + act_param.has_active, + act_param.active_type, + ctx, + act_param.Relu_clipped_coef, + act_param.Leaky_relu_alpha); } else { int ldb = n; sgemm_prepack(false, @@ -677,7 +682,8 @@ void conv_depthwise_5x5_fp32(const void* din, const float* scale) { auto paddings = *param.paddings; auto act_param = param.activation_param; - int pad = paddings[0]; + int pad_h = paddings[0]; + int pad_w = paddings[2]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; @@ -698,20 +704,21 @@ void conv_depthwise_5x5_fp32(const void* din, act_param, ctx); } else if (stride == 1) { - conv_depthwise_5x5s1_fp32(reinterpret_cast(din), - reinterpret_cast(dout), - num, - ch_out, - h_out, - w_out, - ch_in, - h_in, - w_in, + conv_depthwise_5x5s1_fp32(reinterpret_cast(dout), + reinterpret_cast(din), reinterpret_cast(weights), bias, - pad, flag_bias, flag_relu, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + pad_w, + pad_h, + param, ctx); } else { LOG(FATAL) << "unsupport this type 5x5 dw conv"; diff --git a/lite/backends/arm/math/fill_bias_relu.cc b/lite/backends/arm/math/fill_bias_relu.cc index c585548bf1ed5b0a49a60371f4617424fe0195d1..d816c2f549c2c074a35885931a585ff51ae97f6f 100644 --- a/lite/backends/arm/math/fill_bias_relu.cc +++ b/lite/backends/arm/math/fill_bias_relu.cc @@ -136,19 +136,19 @@ void fill_bias_relu(int* tensor, "fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ "fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \ "fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */ -#define FILL_LEAKY_RELU \ - "cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ - "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ - "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ - "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \ - "bif v3.16b, v11.16b, v10.16b \n" /* choose*/ +#define FILL_LEAKY_RELU \ + "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "fcmge v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + "fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \ + "bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \ + "bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \ + "bif v3.16b, v11.16b, v10.16b \n" /* choose*/ #define FILL_STORE \ "subs %w[cnt], %w[cnt], #1 \n" \ "st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \ diff --git a/lite/backends/arm/math/sgemv.cc b/lite/backends/arm/math/sgemv.cc index 1830423136cc883d30d4eecad0eb9fcfc9ded6ba..98404fe60fdb1384d390458e10dac8c967fd2b21 100644 --- a/lite/backends/arm/math/sgemv.cc +++ b/lite/backends/arm/math/sgemv.cc @@ -22,35 +22,87 @@ namespace lite { namespace arm { namespace math { -void sgemv(const bool transA, - const int M, +void sgemv(const int M, const int N, const float *A, const float *x, - float *y); - -void sgemv_relu(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y); + float *y, + bool flag_bias, + const float *bias); -void sgemv_bias(const bool transA, - const int M, +void sgemv_relu(const int M, const int N, const float *A, const float *x, float *y, + bool flag_bias, const float *bias); -void sgemv_bias_relu(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y, - const float *bias); +void sgemv_relu6(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float six); + +void sgemv_leakey_relu(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float alpha); + +void sgemv_trans(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha); + +bool sgemv(const float *A, + const float *x, + float *y, + bool transA, + int M, + int N, + bool is_bias, + const float *bias, + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha) { + if (transA) { + sgemv_trans(M, N, A, x, y, is_bias, bias, flag_act, act, ctx, six, alpha); + } else { + if (flag_act) { + if (act == lite_api::ActivationType::kRelu) { + sgemv_relu(M, N, A, x, y, is_bias, bias); + } else if (act == lite_api::ActivationType::kRelu6) { + sgemv_relu6(M, N, A, x, y, is_bias, bias, six); + } else if (act == lite_api::ActivationType::kLeakyRelu) { + sgemv_leakey_relu(M, N, A, x, y, is_bias, bias, alpha); + } else { + LOG(FATAL) + << "sgemv no transA only support relu, relu6, leakey relu fusion"; + } + } else { + sgemv(M, N, A, x, y, is_bias, bias); + } + } + return true; +} + #ifdef __aarch64__ void sgemv_trans(const int M, const int N, @@ -59,8 +111,11 @@ void sgemv_trans(const int M, float *y, bool flag_bias, const float *bias, - bool flag_relu, - const ARMContext *ctx) { + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha) { int m_cnt16 = M >> 4; int m_cnt8 = (M & 15) >> 3; int m_cnt4 = (M & 15 & 7) >> 2; @@ -281,26 +336,70 @@ void sgemv_trans(const int M, valid_ths = rdc_ths; rdc_ths = rdc_ths >> 1; } - if (flag_relu) { + if (flag_act) { float *in_y = y_buf; float32x4_t vzero = vdupq_n_f32(0.f); - if (cnt4 > 0) { - int cnt = cnt4; - asm volatile( - "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ - "1:\n" - "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */ - "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ - "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ - "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ - "bne 1b \n" /* branch to label 1*/ - "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ - : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) - : [vzero] "w"(vzero) - : "v0", "v1", "cc", "memory"); - } - for (int r = 0; r < remain; ++r) { - y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + if (act == lite_api::ActivationType::kRelu) { + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "1:\n" + "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu */ + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "v0", "v1", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + } + } else if (act == lite_api::ActivationType::kRelu6) { + float32x4_t vsix = vdupq_n_f32(six); + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "1:\n" + "fmax v1.4s, v0.4s, %[vzero].4s \n" /* v0 relu6 */ + "fmin v1.4s, v1.4s, %[vsix].4s \n" /* v1 relu6 */ + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v1.4s}, [%[out_y]], #16 \n" /* store v1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [vsix] "w"(vsix) + : "v0", "v1", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + y[r] = y[r] > six ? six : y[r]; + } + } else if (act == lite_api::ActivationType::kLeakyRelu) { + float32x4_t valpha = vdupq_n_f32(alpha); + if (cnt4 > 0) { + int cnt = cnt4; + asm volatile( + "1:\n" + "ld1 {v0.4s}, [%[in_y]], #16 \n" /* load y to v0 */ + "fcmge v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_f32 */ + "fmul v5.4s, v0.4s, %[valpha].4s \n" /* vmulq_f32 */ + "bif v0.16b, v5.16b, v4.16b \n" /* choose */ + "subs %w[cnt], %w[cnt], #1 \n" /* sub cnt */ + "st1 {v0.4s}, [%[out_y]], #16 \n" /* store v0 to y */ + "bne 1b \n" /* branch to label 1*/ + : [cnt] "+r"(cnt), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [valpha] "w"(valpha) + : "v0", "v4", "v5", "cc", "memory"); + } + for (int r = 0; r < remain; ++r) { + y[r] = in_y[r] < 0.f ? alpha * in_y[r] : in_y[r]; + } } } else { memcpy(y, y_buf, M * sizeof(float)); @@ -314,8 +413,11 @@ void sgemv_trans(const int M, float *y, bool flag_bias, const float *bias, - bool flag_relu, - const ARMContext *ctx) { + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six, + float alpha) { int m_cnt8 = M >> 3; int m_cnt4 = (M & 7) >> 2; int m_remain = M & 7 & 3; @@ -497,43 +599,73 @@ void sgemv_trans(const int M, valid_ths = rdc_ths; rdc_ths = rdc_ths >> 1; } - if (flag_relu) { + // do activation + if (flag_act) { float *in_y = y_buf; float32x4_t vzero = vdupq_n_f32(0.f); - if (m_cnt8 > 0) { - int cnt8 = m_cnt8; - asm volatile( - "vld1.32 {d0-d3}, [%[in_y]]! \n" /* load y to q0, q1 */ - "1:\n" - "vmax.f32 q2, q0, %q[vzero] \n" /* q0 relu */ - "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ - "vmax.f32 q3, q1, %q[vzero] \n" /* q1 relu */ - "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ - "vst1.32 {d4-d7}, [%[out_y]]! \n" /* store q0, q1 to y*/ - "vld1.32 {d2-d3}, [%[in_y]]! \n" /* load y to q0 */ - "bne 1b \n" /* branch to label 1*/ - "sub %[in_y], %[in_y], #32 \n" /* restore in_y */ - : [cnt] "+r"(cnt8), [in_y] "+r"(in_y), [out_y] "+r"(y) - : [vzero] "w"(vzero) - : "q0", "q1", "q2", "q3", "cc", "memory"); - } - if (m_cnt4 > 0) { - int cnt4 = m_cnt4; - asm volatile( - "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ - "1:\n" - "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */ - "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ - "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ - "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ - "bne 1b \n" /* branch to label 1*/ - "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ - : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) - : [vzero] "w"(vzero) - : "q0", "q1", "cc", "memory"); - } - for (int r = 0; r < m_remain; ++r) { - y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + m_cnt4 = M >> 2; + m_remain = M & 3; + if (act == lite_api::ActivationType::kRelu) { + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "1:\n" + "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu */ + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + } + } else if (act == lite_api::ActivationType::kRelu6) { + float32x4_t vsix = vdupq_n_f32(six); + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "1:\n" + "vmax.f32 q1, q0, %q[vzero] \n" /* q0 relu6 */ + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "vmin.f32 q1, q1, %q[vsix] \n" /* q0 relu6 */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d2-d3}, [%[out_y]]! \n" /* store q1 to y */ + "bne 1b \n" /* branch to label 1*/ + "sub %[in_y], %[in_y], #16 \n" /* restore in_y */ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [vsix] "w"(vsix) + : "q0", "q1", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] > 0.f ? in_y[r] : 0.f; + y[r] = y[r] > six ? six : y[r]; + } + } else if (act == lite_api::ActivationType::kLeakyRelu) { + float32x4_t valpha = vdupq_n_f32(alpha); + if (m_cnt4 > 0) { + int cnt4 = m_cnt4; + asm volatile( + "1:\n" + "vld1.32 {d0-d1}, [%[in_y]]! \n" /* load y to q0 */ + "vcge.f32 q3, q0, %q[vzero] \n" /* vcgeq_f32 */ + "vmul.f32 q4, q0, %q[valpha] \n" /* vmulq_f32 */ + "vbif q0, q4, q3 \n" /* choose */ + "subs %[cnt], %[cnt], #1 \n" /* sub cnt */ + "vst1.32 {d0-d1}, [%[out_y]]! \n" /* store q0 to y */ + "bne 1b \n" /* branch to label 1*/ + : [cnt] "+r"(cnt4), [in_y] "+r"(in_y), [out_y] "+r"(y) + : [vzero] "w"(vzero), [valpha] "w"(valpha) + : "q0", "q3", "q4", "cc", "memory"); + } + for (int r = 0; r < m_remain; ++r) { + y[r] = in_y[r] < 0.f ? alpha * in_y[r] : in_y[r]; + } } } else { memcpy(y, y_buf, M * sizeof(float)); @@ -541,41 +673,6 @@ void sgemv_trans(const int M, } #endif // __aarch64__ -bool sgemv(const float *A, - const float *x, - float *y, - bool transA, - int M, - int N, - bool is_bias, - const float *bias, - bool is_relu, - const ARMContext *ctx) { - if (transA) { - sgemv_trans(M, N, A, x, y, is_bias, bias, is_relu, ctx); - } else { - if (is_bias) { - //! with bias - if (is_relu) { - //! with relu - sgemv_bias_relu(transA, M, N, A, x, y, bias); - } else { - //! without relu - sgemv_bias(transA, M, N, A, x, y, bias); - } - } else { - //! without bias - if (is_relu) { - //! with relu - sgemv_relu(transA, M, N, A, x, y); - } else { - //! without relu - sgemv(transA, M, N, A, x, y); - } - } - } - return true; -} // clang-format off //! define compute kernel #ifdef __aarch64__ @@ -715,19 +812,19 @@ bool sgemv(const float *A, #define SGEMV_KERNEL_1 \ /* check main loop */ \ "cmp %w[cnt], #1 \n" /* check whether has main loop */ \ - "blt 2f \n" /* jump to tail */ /* main loop */ \ - "1: \n" /* main loop */ \ - "ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \ - "ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \ - "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ - "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ - "fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \ + "blt 2f \n" /* jump to tail */ \ + "1: \n" /* main loop */ \ + "ldp q8, q9, [%[in]], #32 \n" /* load input 8 float */ \ + "ldp q10, q11, [%[w0]], #32 \n" /* load w0 8 float */ \ + "fmla v0.4s, v8.4s, v10.4s \n" /* mul + add*/ \ + "subs %w[cnt], %w[cnt], #1 \n" /* sub main loop count */ \ + "fmla v1.4s, v9.4s, v11.4s \n" /* mul + add*/ \ "bne 1b \n" /* jump to main loop */ \ /* pair add to final result */ \ "2: \n" /* reduce to scale */ \ "fadd v9.4s, v0.4s, v1.4s \n" /* add 2 vector */ \ "faddp v10.4s, v9.4s, v9.4s\n" /* pair add to vector */ \ - "faddp s8, v10.2s \n" /* pair add to scale */ /* check tails */ \ + "faddp s8, v10.2s \n" /* pair add to scale */ \ "cmp %w[tail], #1 \n" /* check whether has tail */ \ "blt 4f \n" /* jump to end */ \ "3: \n" /* tail loop */ \ @@ -737,43 +834,100 @@ bool sgemv(const float *A, "subs %w[tail], %w[tail], #1\n" /* sub tail loop count */ \ "bne 3b \n" /* jump to tail loop */ -#define SGEMV_OUT_8 \ - /* end */ \ - "4: \n" /* end */ \ - "stp s8, s9, [%[out]] \n" /* save result */ \ - "stp s10, s11, [%[out], #8] \n" /* save result */ \ - "stp s12, s13, [%[out], #16]\n" /* save result */ \ - "stp s14, s15, [%[out], #24]\n" /* save result */ +#define SGEMV_OUT_8 \ + /* end */ \ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "stp q8, q9, [%[out]] \n" /* save result */ #define SGEMV_OUT_8_RELU \ /* end */ \ - "4: \n" /* end */ \ - "movi d0, #0 \n" /* zero data for relu */ \ - "fmax s8, s8, s0 \n" /* relu */ \ - "fmax s9, s9, s0 \n" /* relu */ \ - "fmax s10, s10, s0 \n" /* relu */ \ - "fmax s11, s11, s0 \n" /* relu */ \ - "fmax s12, s12, s0 \n" /* relu */ \ - "fmax s13, s13, s0 \n" /* relu */ \ - "fmax s14, s14, s0 \n" /* relu */ \ - "fmax s15, s15, s0 \n" /* relu */ \ - "stp s8, s9, [%[out]] \n" /* save result */ \ - "stp s10, s11, [%[out], #8] \n" /* save result */ \ - "stp s12, s13, [%[out], #16]\n" /* save result */ \ - "stp s14, s15, [%[out], #24]\n" /* save result */ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "movi v2.4s, #0 \n" /* zero data for relu */\ + "fmax v8.4s, v8.4s, v2.4s \n" /* relu */ \ + "fmax v9.4s, v9.4s, v2.4s \n" /* relu */ \ + "stp q8, q9, [%[out]] \n" /* save result */ -#define SGEMV_OUT_1 \ - /* end */ \ - "4: \n" /* end */ \ +#define SGEMV_OUT_8_RELU6 \ + /* end */ \ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "movi v2.4s, #0 \n" /* zero data for relu6 */\ + "fmax v8.4s, v8.4s, v2.4s \n" /* relu6 */ \ + "fmax v9.4s, v9.4s, v2.4s \n" /* relu6 */ \ + "fmin v8.4s, v8.4s, %[vsix].4s \n" /* relu */ \ + "fmin v9.4s, v9.4s, %[vsix].4s \n" /* relu */ \ + "stp q8, q9, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_8_LEAKEY_RELU \ + /* end */ \ + "4: \n" /* end */ \ + "mov v8.s[1], v9.s[0] \n" /* ins s9 to v8[1]*/ \ + "mov v8.s[2], v10.s[0] \n" /* ins s10 to v8[2]*/ \ + "mov v8.s[3], v11.s[0] \n" /* ins s11 to v8[3]*/ \ + "mov v9.s[0], v12.s[0] \n" /* ins s12 to v9[0]*/ \ + "mov v9.s[1], v13.s[0] \n" /* ins s13 to v9[1]*/ \ + "mov v9.s[2], v14.s[0] \n" /* ins s14 to v9[2]*/ \ + "mov v9.s[3], v15.s[0] \n" /* ins s15 to v9[3]*/ \ + "movi v2.4s, #0 \n" /* zero data for leakey relu */ \ + "fcmge v4.4s, v8.4s, v2.4s \n" /* vcgeq_f32 */ \ + "fmul v5.4s, v8.4s, %[valpha].4s \n" /* vmulq_f32 */ \ + "fcmge v6.4s, v9.4s, v2.4s \n" /* vcgeq_f32 */ \ + "fmul v7.4s, v9.4s, %[valpha].4s \n" /* vmulq_f32 */ \ + "bif v8.16b, v5.16b, v4.16b \n" /* choose*/ \ + "bif v9.16b, v7.16b, v6.16b \n" /* choose*/ \ + "stp q8, q9, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_1 \ + /* end */ \ + "4: \n" /* end */ \ "str s8, [%[out]] \n" /* save result */ #define SGEMV_OUT_1_RELU \ /* end */ \ "4: \n" /* end */ \ - "movi d0, #0 \n" /* zero data for relu */ \ - "fmax s8, s8, s0 \n" /* relu */ \ + "movi d1, #0 \n" /* zero data for relu */ \ + "fmax s8, s8, s1 \n" /* relu */ \ + "str s8, [%[out]] \n" /* save result */ + +#define SGEMV_OUT_1_RELU6 \ + /* end */ \ + "4: \n" /* end */ \ + "movi d1, #0 \n" /* zero data for relu6 */ \ + "fmov s2, %w[six] \n" /* mov six to s2 */ \ + "fmax s8, s8, s1 \n" /* relu6 */ \ + "fmin s8, s8, s2 \n" /* relu6 */ \ "str s8, [%[out]] \n" /* save result */ +#define SGEMV_OUT_1_LEAKEY_RELU \ + /* end */ \ + "4: \n" /* end */ \ + "fmov s1, %w[alpha] \n" /* mov alpha to s1 */ \ + "fcmp s8, #0 \n" /* cmp with zero*/ \ + "bge 5f \n" /* if ge zero */ \ + "fmul s8, s8, s1 \n" /* out * alpha */ \ + "5: \n" /* leakey relu label */ \ + "str s8, [%[out]] \n" /* save result */ + #else // __aarch64__ #define SGEMV_IN_4 \ @@ -841,14 +995,13 @@ bool sgemv(const float *A, "vmla.f32 q2, q5, q11 @ mul add\n" \ "vmla.f32 q3, q5, q13 @ mul add\n" \ "bne 1b @ jump to main loop\n" \ - /* pair add to final result */ \ "2: @ pair add \n" \ "vpadd.f32 d8, d0, d1 @ pair add, first step\n" \ "vpadd.f32 d9, d2, d3 @ pair add, first step\n" \ "vpadd.f32 d10, d4, d5 @ pair add, first step\n" \ "vpadd.f32 d11, d6, d7 @ pair add, first step\n" \ "vpadd.f32 d0, d8, d9 @ pair add, second step\n" \ - "vpadd.f32 d1, d10, d11 @ pair add, second step\n" /* check tails */ \ + "vpadd.f32 d1, d10, d11 @ pair add, second step\n" \ "cmp %[tail], #1 @ check whether has tail\n" \ "blt 4f @ jump to end\n" \ "3: @ tail loop\n" \ @@ -876,7 +1029,7 @@ bool sgemv(const float *A, "bne 1b @ jump to main loop\n" \ "2: @ end processing\n" \ "vpadd.f32 d2, d0, d1 @ pair add, first step\n" \ - "vpadd.f32 d0, d2, d2 @ pair add, final step\n"/*check tails*/ \ + "vpadd.f32 d0, d2, d2 @ pair add, final step\n" \ "cmp %[tail], #1 @ check whether has mid cols\n" \ "blt 4f @ jump to end\n" \ "3: @ tail loop\n" \ @@ -898,6 +1051,25 @@ bool sgemv(const float *A, "vmax.f32 q0, q0, q1 @ relu\n" \ "vst1.32 {d0-d1}, [%[out]] @ save result\n" +#define SGEMV_OUT_4_RELU6 \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 q1, #0 @ zero for relu6\n" \ + "vdup.f32 q2, %[six] @ six for relu6\n" \ + "vmax.f32 q0, q0, q1 @ relu6\n" \ + "vmin.f32 q0, q0, q2 @ relu6\n" \ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + +#define SGEMV_OUT_4_LEAKEY_RELU \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 q1, #0 @ zero for leakey relu\n" \ + "vdup.f32 q2, %[alpha] @ alpha for leakey relu\n" \ + "vcge.f32 q3, q0, q1 @ vcgeq_f32 \n" \ + "vmul.f32 q4, q0, q2 @ vmulq_f32 \n" \ + "vbif q0, q4, q3 @ choose \n" \ + "vst1.32 {d0-d1}, [%[out]] @ save result\n" + #define SGEMV_OUT_1 \ /* end */ \ "4: @ end\n" \ @@ -909,14 +1081,36 @@ bool sgemv(const float *A, "vmov.i32 d1, #0 @ zero for relu\n" \ "vmax.f32 d0, d0, d1 @ relu\n" \ "vst1.32 {d0[0]}, [%[out]] @ save result\n" + +#define SGEMV_OUT_1_RELU6 \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 d1, #0 @ zero for relu6\n" \ + "vdup.f32 d4, %[six] @ six for relu6\n" \ + "vmax.f32 d0, d0, d1 @ relu6\n" \ + "vmin.f32 d0, d0, d4 @ relu6\n" \ + "vst1.32 {d0[0]}, [%[out]] @ save result\n" + +#define SGEMV_OUT_1_LEAKEY_RELU \ + /* end */ \ + "4: @ end\n" \ + "vmov.i32 d2, #0 @ zero for leakey relu\n" \ + "vdup.f32 d3, %[alpha] @ alpha for leakey relu\n" \ + "vcge.f32 d6, d0, d2 @ vcgeq_f32 \n" \ + "vmul.f32 d8, d0, d3 @ vmulq_f32 \n" \ + "vbif d0, d8, d6 @ choose \n" \ + "vst1.32 {d0[0]}, [%[out]] @ save result\n" + #endif // clang-format on -void sgemv(const bool transA, - const int M, + +void sgemv(const int M, const int N, const float *A, const float *x, - float *y) { + float *y, + bool flag_bias, + const float *bias) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; @@ -926,7 +1120,6 @@ void sgemv(const bool transA, #ifdef __aarch64__ int out_cnt = M >> 3; - #pragma omp parallel for for (int j = 0; j < out_cnt; j++) { int out_idx = j * 8; @@ -940,9 +1133,22 @@ void sgemv(const bool transA, const float *ptr_w5 = ptr_w4 + N; const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; + const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8 + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -954,35 +1160,12 @@ void sgemv(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -992,24 +1175,17 @@ void sgemv(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float tmp[4]; - float tmp1[4]; - float tmp2[4]; - float tmp3[4]; - float tmp4[4]; - asm volatile( - SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1 - : [in] "+r"(ptr_in), - [w0] "+r"(ptr_w0), - [cnt] "+r"(cnt_loop), - [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), - [tmp] "r"(tmp), - [tmp1] "r"(tmp1), - [tmp2] "r"(tmp2), - [tmp3] "r"(tmp3), - [tmp4] "r"(tmp4) - : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0) + : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc"); } #else // __aarch64__ int out_cnt = M >> 2; @@ -1022,10 +1198,20 @@ void sgemv(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4 + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1033,23 +1219,16 @@ void sgemv(const bool transA, [w3] "+r"(ptr_w3), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + : [out] "r"(ptr_out), + [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1059,23 +1238,28 @@ void sgemv(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1 + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) + : [out] "r"(ptr_out), [bias0] "r"(bias0) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } -void sgemv_relu(const bool transA, - const int M, +void sgemv_relu(const int M, const int N, const float *A, const float *x, - float *y) { + float *y, + bool flag_bias, + const float *bias) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; @@ -1098,9 +1282,22 @@ void sgemv_relu(const bool transA, const float *ptr_w5 = ptr_w4 + N; const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; + const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8 SGEMV_KERNEL_8 SGEMV_OUT_8_RELU + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1112,35 +1309,12 @@ void sgemv_relu(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1150,13 +1324,17 @@ void sgemv_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } asm volatile( - SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) + : [out] "r"(ptr_out), [bias0] "r"(bias0) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } #else // __aarch64__ @@ -1170,10 +1348,20 @@ void sgemv_relu(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4 SGEMV_KERNEL_4 SGEMV_OUT_4_RELU + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1181,23 +1369,16 @@ void sgemv_relu(const bool transA, [w3] "+r"(ptr_w3), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + : [out] "r"(ptr_out), + [bias0] "r"(bias0), + [bias1] "r"(bias1), + [bias2] "r"(bias2), + [bias3] "r"(bias3) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1207,31 +1388,36 @@ void sgemv_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_1 SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out) + : [out] "r"(ptr_out), [bias0] "r"(bias0) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } -void sgemv_bias(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y, - const float *bias) { +void sgemv_relu6(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float six) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; int cnt = N >> 3; int tail = N & 7; - + float32x4_t vsix = vdupq_n_f32(six); #ifdef __aarch64__ int out_cnt = M >> 3; #pragma omp parallel for @@ -1248,9 +1434,21 @@ void sgemv_bias(const bool transA, const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8 + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1262,35 +1460,13 @@ void sgemv_bias(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local), + [vsix] "w" (vsix) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1300,14 +1476,17 @@ void sgemv_bias(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } asm volatile( - SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [six] "r"(six) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } #else // __aarch64__ @@ -1321,14 +1500,20 @@ void sgemv_bias(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - float bias0 = bias[out_idx]; - float bias1 = bias[out_idx + 1]; - float bias2 = bias[out_idx + 2]; - float bias3 = bias[out_idx + 3]; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4 + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1340,23 +1525,13 @@ void sgemv_bias(const bool transA, [bias0] "r"(bias0), [bias1] "r"(bias1), [bias2] "r"(bias2), - [bias3] "r"(bias3) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + [bias3] "r"(bias3), + [six] "r" (six) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1366,30 +1541,35 @@ void sgemv_bias(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; - asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1 + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU6 : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [six] "r"(six) : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } -void sgemv_bias_relu(const bool transA, - const int M, - const int N, - const float *A, - const float *x, - float *y, - const float *bias) { +void sgemv_leakey_relu(const int M, + const int N, + const float *A, + const float *x, + float *y, + bool flag_bias, + const float *bias, + const float alpha) { float *data_out = y; const float *data_in = x; const float *weights_ptr = A; int cnt = N >> 3; int tail = N & 7; + float32x4_t valpha = vdupq_n_f32(alpha); #ifdef __aarch64__ int out_cnt = M >> 3; #pragma omp parallel for @@ -1406,9 +1586,21 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w6 = ptr_w5 + N; const float *ptr_w7 = ptr_w6 + N; const float *bias_ptr = bias + out_idx; + float bias_local[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + if (flag_bias) { + bias_local[0] = bias_ptr[0]; + bias_local[1] = bias_ptr[1]; + bias_local[2] = bias_ptr[2]; + bias_local[3] = bias_ptr[3]; + bias_local[4] = bias_ptr[4]; + bias_local[5] = bias_ptr[5]; + bias_local[6] = bias_ptr[6]; + bias_local[7] = bias_ptr[7]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_RELU + // clang-format off + asm volatile(SGEMV_IN_8_BIAS SGEMV_KERNEL_8 SGEMV_OUT_8_LEAKEY_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1420,35 +1612,13 @@ void sgemv_bias_relu(const bool transA, [w7] "+r"(ptr_w7), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias_ptr] "r"(bias_ptr) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "cc", - "memory"); + : [out] "r"(ptr_out), [bias_ptr] "r"(bias_local), + [valpha] "w" (valpha) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1458,14 +1628,17 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } asm volatile( - SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_LEAKEY_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [cnt] "+r"(cnt_loop), [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [alpha] "r"(alpha) : "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17", "cc", "memory"); } #else // __aarch64__ @@ -1479,14 +1652,20 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w1 = ptr_w0 + N; const float *ptr_w2 = ptr_w1 + N; const float *ptr_w3 = ptr_w2 + N; - float bias0 = bias[out_idx]; - float bias1 = bias[out_idx + 1]; - float bias2 = bias[out_idx + 2]; - float bias3 = bias[out_idx + 3]; - + float bias0 = 0.f; + float bias1 = 0.f; + float bias2 = 0.f; + float bias3 = 0.f; + if (flag_bias) { + bias0 = bias[out_idx]; + bias1 = bias[out_idx + 1]; + bias2 = bias[out_idx + 2]; + bias3 = bias[out_idx + 3]; + } int cnt_loop = cnt; int tail_loop = tail; - asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_RELU + // clang-format off + asm volatile(SGEMV_IN_4_BIAS SGEMV_KERNEL_4 SGEMV_OUT_4_LEAKEY_RELU : [in] "+r"(ptr_in), [w0] "+r"(ptr_w0), [w1] "+r"(ptr_w1), @@ -1498,23 +1677,13 @@ void sgemv_bias_relu(const bool transA, [bias0] "r"(bias0), [bias1] "r"(bias1), [bias2] "r"(bias2), - [bias3] "r"(bias3) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "cc", + [bias3] "r"(bias3), + [alpha] "r" (alpha) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12", "q13", "cc", "memory"); + // clang-format on } //! deal with remains #pragma omp parallel for @@ -1524,14 +1693,18 @@ void sgemv_bias_relu(const bool transA, const float *ptr_w0 = weights_ptr + (N * j); int cnt_loop = cnt; int tail_loop = tail; - float bias0 = bias[j]; - asm volatile(SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_RELU - : [in] "+r"(ptr_in), - [w0] "+r"(ptr_w0), - [cnt] "+r"(cnt_loop), - [tail] "+r"(tail_loop) - : [out] "r"(ptr_out), [bias0] "r"(bias0) - : "q0", "q1", "q12", "q13", "q14", "q15", "cc", "memory"); + float bias0 = 0.f; + if (flag_bias) { + bias0 = bias[j]; + } + asm volatile( + SGEMV_IN_1_BIAS SGEMV_KERNEL_1 SGEMV_OUT_1_LEAKEY_RELU + : [in] "+r"(ptr_in), + [w0] "+r"(ptr_w0), + [cnt] "+r"(cnt_loop), + [tail] "+r"(tail_loop) + : [out] "r"(ptr_out), [bias0] "r"(bias0), [alpha] "r"(alpha) + : "q0", "q1", "q3", "q4", "q12", "q13", "q14", "q15", "cc", "memory"); } #endif // __aarch64__ } diff --git a/lite/backends/arm/math/sgemv.h b/lite/backends/arm/math/sgemv.h index aa17349c99e61f7135090318be829149ecd6bb57..53b2c2ab55a2cee51f8535683c5cf34340fd6dab 100644 --- a/lite/backends/arm/math/sgemv.h +++ b/lite/backends/arm/math/sgemv.h @@ -17,23 +17,26 @@ #include #include "lite/core/context.h" #include "lite/core/device_info.h" +#include "lite/operators/op_params.h" namespace paddle { namespace lite { namespace arm { namespace math { -// TODO(xxx): fixme now only support transA = false -bool sgemv(const float* A, - const float* x, - float* y, +bool sgemv(const float *A, + const float *x, + float *y, bool transA, int M, int N, bool is_bias, - const float* bias, - bool is_relu, - const ARMContext* ctx); + const float *bias, + bool flag_act, + lite_api::ActivationType act, + const ARMContext *ctx, + float six = 6.f, + float alpha = 1.f); } // namespace math } // namespace arm diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 4afb8f020e1c9001428e83709d95c167900bbfd1..b58244d97202725fa104d9ee57b996d06740d64b 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -42,8 +42,6 @@ void ConvCompute::PrepareForRun() { int stride = param.strides[0]; int threads = ctx.threads(); - bool pads_equal = - ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); int chin = param.x->dims()[1]; int hin = param.x->dims()[2]; int win = param.x->dims()[3]; @@ -51,28 +49,28 @@ void ConvCompute::PrepareForRun() { int hout = param.output->dims()[2]; int wout = param.output->dims()[3]; + bool pads_equal = + ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); bool pads_all_equal = (pads_equal && paddings[0] == paddings[2]); - bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh); + bool ks_equal = (param.strides[0] == param.strides[1]) && (kw == kh); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); - bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2)); - bool flag_dw_5x5 = (paddings[0] == paddings[2]) && - ((kw == 5 && stride == 1) || (kw == 5 && stride == 2)); + + bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2); + bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; /// select conv impl - if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { - /// dw conv impl + if (param.groups == ic && ic == oc && ks_equal && no_dilation && flag_dw) { impl_ = new DepthwiseConv; // VLOG(3) << "invoking dw conv"; - } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && + } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal && no_dilation && pads_all_equal) { - /// winograd conv impl + // TODO(MyPandaShaoxiang): winograd conv support any pad impl_ = new WinogradConv; // VLOG(3) << "invoking winograd conv"; } else if (param.groups == 1 && kw == 3 && stride == 2 && - chin * chout < 4 * hin * win && kps_equal && no_dilation) { - /// direct conv impl + chin * chout < 4 * hin * win && ks_equal && no_dilation) { impl_ = new DirectConv; // VLOG(3) << "invoking direct conv"; } else { @@ -109,7 +107,7 @@ void ConvCompute::PrepareForRun() { bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); - bool flag_dw_5x5 = pads_all_equal && (kw == 5 && sw == 1); + bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; if (param.groups == ic && ic == oc && kps_equal && pads_equal && @@ -154,7 +152,7 @@ void ConvCompute::PrepareForRun() { bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); - bool flag_dw_5x5 = pads_all_equal && (kw == 5 && sw == 1); + bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; if (param.groups == ic && ic == oc && kps_equal && pads_equal && diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index 10c190806fa7cd09da66afc1c242da054c460dfb..6f641d0f27ad3d0a1c19a667a0874a62f2d68116 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -52,7 +52,10 @@ void DepthwiseConv::PrepareForRun() { impl_ = lite::arm::math::conv_depthwise_3x3_fp32; } else if (kw == 5) { // VLOG(5) << "invoke 5x5 dw conv fp32"; - if (param.strides[0] == 2) { // conv5x5s2_dw + auto strides = param.strides; + if ((strides[0] == 1 && strides[1] == 1) || + (strides[0] == 2 && strides[1] == 2)) { + // trans weights constexpr int cblock = 4; auto oc = w_dims[0]; auto kh = w_dims[2]; @@ -63,10 +66,11 @@ void DepthwiseConv::PrepareForRun() { lite::arm::math::conv_trans_weights_numc( w_data_in, w_data, oc, 1, cblock, kh * kw); flag_trans_weights_ = true; + impl_ = lite::arm::math::conv_depthwise_5x5_fp32; } else { - flag_trans_weights_ = false; + LOG(FATAL) + << "5x5 depthwise conv only support stride == 1 or stride == 2"; } - impl_ = lite::arm::math::conv_depthwise_5x5_fp32; } else { LOG(FATAL) << "this type dw conv not impl"; } diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index cc119d3802ef1b3a92002767e96845e4ddfba500..1269a259072b6ae54759794f06040340cc42e15e 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -93,9 +93,11 @@ void FcCompute::Run() { if (flag_trans_bias_) { b_data = bias_.data(); } - bool flag_relu = false; + bool flag_act = false; + lite_api::ActivationType act; if (param.activation_type == "relu") { - flag_relu = true; + act = lite_api::ActivationType::kRelu; + flag_act = true; } if (flag_gemm_) { operators::ActivationParam act_param; @@ -119,7 +121,7 @@ void FcCompute::Run() { &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); - lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_relu); + lite::arm::math::fill_bias_fc(o_data, b_data, m_, n_, flag_act); } } else { for (int i = 0; i < m_; ++i) { @@ -133,7 +135,8 @@ void FcCompute::Run() { k_, param.bias != nullptr, b_data, - flag_relu, + flag_act, + act, &ctx); } } diff --git a/lite/kernels/arm/matmul_compute.cc b/lite/kernels/arm/matmul_compute.cc index afcbe7267cb91f3d07e90c4f9d86253e4d270936..2841fa13f7a04026bc9040a8bd9fdc98dd7e149e 100644 --- a/lite/kernels/arm/matmul_compute.cc +++ b/lite/kernels/arm/matmul_compute.cc @@ -233,8 +233,17 @@ void MatMulCompute::Run() { int ldb = n_; int ldc = n_; if (n_ == 1) { - lite::arm::math::sgemv( - x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx); + lite::arm::math::sgemv(x_data, + y_data, + o_data, + false, + m_, + k_, + false, + nullptr, + false, + lite_api::ActivationType::kIndentity, + &ctx); if (fabsf(alpha - 1.f) > 1e-8f) { for (size_t i = 0; i < param.Out->dims().production(); ++i) { o_data[i] *= alpha; diff --git a/lite/kernels/arm/mul_compute.cc b/lite/kernels/arm/mul_compute.cc index a5de6c202c99502f9c4e6289ec411e4b8cf09e99..1321d001fd1d8a30b179d73979c4164cbe8916e1 100644 --- a/lite/kernels/arm/mul_compute.cc +++ b/lite/kernels/arm/mul_compute.cc @@ -50,8 +50,17 @@ void MulCompute::Run() { k_ = x_w; auto& ctx = this->ctx_->template As(); if (n_ == 1) { - lite::arm::math::sgemv( - x_data, y_data, o_data, false, m_, k_, false, nullptr, false, &ctx); + lite::arm::math::sgemv(x_data, + y_data, + o_data, + false, + m_, + k_, + false, + nullptr, + false, + lite_api::ActivationType::kIndentity, + &ctx); } else { constexpr bool is_tranposed_y = false; diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc index de7bbc2158ea7e604a9ddbe24fadfe34e0492f9e..bd6d86b599e3e5340c23cc072ce46cc17ddf15df 100644 --- a/lite/tests/kernels/fc_compute_test.cc +++ b/lite/tests/kernels/fc_compute_test.cc @@ -131,7 +131,7 @@ class FcOPTest : public arena::TestCase { 1.f, 0.f, true, - flag_bias, + static_cast(flag_bias), false); } else { basic_gemm(false, diff --git a/lite/tests/math/conv_compute_test.cc b/lite/tests/math/conv_compute_test.cc index 367eb6c34761b8d0989da0d2e99aa00442d0c76b..53a9a00ccf2ad80e5ccd9d9b3a7244be769c9d7a 100644 --- a/lite/tests/math/conv_compute_test.cc +++ b/lite/tests/math/conv_compute_test.cc @@ -46,14 +46,19 @@ DEFINE_int32(out_channel, 32, "output channel"); DEFINE_int32(group, 1, "group"); DEFINE_int32(kernel_h, 3, "kernel height"); DEFINE_int32(kernel_w, 3, "kernel width"); -DEFINE_int32(pad_h, 1, "pad height"); -DEFINE_int32(pad_w, 1, "pad width"); +DEFINE_int32(pad_h0, 1, "pad top"); +DEFINE_int32(pad_h1, 1, "pad bottom"); +DEFINE_int32(pad_w0, 1, "pad left"); +DEFINE_int32(pad_w1, 1, "pad right"); DEFINE_int32(stride_h, 1, "stride height"); DEFINE_int32(stride_w, 1, "stride width"); DEFINE_int32(dila_h, 1, "dilation height"); DEFINE_int32(dila_w, 1, "dilation width"); -DEFINE_bool(flag_relu, true, "do relu"); +DEFINE_int32(flag_act, + 0, + "do activation"); // 0-no act, 1-relu, 2-relu6, 4-leakyrelu +DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha"); DEFINE_bool(flag_bias, true, "with bias"); typedef paddle::lite::DDim DDim; @@ -98,9 +103,10 @@ void test_conv_fp32(const std::vector& input_dims, const std::vector& pads, const std::vector& dilas, bool flag_bias, - bool flag_relu, + int flag_act, const std::vector& thread_num, - const std::vector& power_mode) { + const std::vector& power_mode, + const float leakey_relu_scale) { #ifdef LITE_WITH_ARM paddle::lite::DeviceInfo::Init(); #endif @@ -118,13 +124,20 @@ void test_conv_fp32(const std::vector& input_dims, param.strides = strides; param.paddings = std::make_shared>(pads); param.dilations = std::make_shared>(dilas); - param.fuse_relu = flag_relu; param.groups = group; - if (flag_relu) { + const float six = 6.f; + if (flag_act > 0) { ActivationParam act_param; act_param.has_active = true; - act_param.active_type = - (paddle::lite_api::ActivationType)1; // 2-relu6 4-leakyrelu + act_param.active_type = (paddle::lite_api::ActivationType) + flag_act; // 1-relu, 2-relu6, 4-leakyrelu + if (flag_act == 1) { + param.fuse_relu = true; + } else if (flag_act == 2) { + act_param.Relu_clipped_coef = six; + } else if (flag_act == 4) { + act_param.Leaky_relu_alpha = leakey_relu_scale; + } param.activation_param = act_param; } @@ -205,7 +218,9 @@ void test_conv_fp32(const std::vector& input_dims, pads[2], pads[0], flag_bias, - flag_relu); + flag_act, + six, + leakey_relu_scale); } /// warm up for (int i = 0; i < FLAGS_warmup; ++i) { @@ -254,22 +269,20 @@ void test_conv_fp32(const std::vector& input_dims, << ", dila_: " << dilas[0] << ", " << dilas[1] << ", group: " << group << ", bias: " << (flag_bias ? "true" : "false") - << ", relu: " << (flag_relu ? "true" : "false") - << ", threads: " << th << ", power_mode: " << cls - << " failed!!\n"; + << ", act: " << flag_act << ", threads: " << th + << ", power_mode: " << cls << " failed!!\n"; } } } LOG(INFO) << "test fp32 conv: input: " << dim_in << ", output: " << dim_out << ", weight dim: " << weight_dim - << ", pad: " << pads[0] << ", " << pads[1] - << ", stride: " << strides[0] << ", " << strides[1] - << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2] + << ", " << pads[3] << ", stride: " << strides[0] << ", " + << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] << ", group: " << group << ", bias: " << (flag_bias ? "true" : "false") - << ", relu: " << (flag_relu ? "true" : "false") - << ", threads: " << th << ", power_mode: " << cls - << " successed!!\n"; + << ", act: " << flag_act << ", threads: " << th + << ", power_mode: " << cls << " successed!!\n"; } } } @@ -287,12 +300,14 @@ void test_conv_fp32(const std::vector& input_dims, const std::vector& pads, const std::vector& dilas, bool flag_bias, - bool flag_relu, + int flag_act, const std::vector& thread_num, - const std::vector& power_mode) {} + const std::vector& power_mode, + const float leakey_relu_scale) {} #endif // LITE_WITH_ARM -#if 1 /// 3x3dw +// TODO(chenjiaoAngel): fix me, diff: 3x3 depthwise conv +#if 0 /// 3x3dw TEST(TestConv3x3DW, test_conv3x3_depthwise) { if (FLAGS_basic_test) { for (auto& stride : {1, 2}) { @@ -301,7 +316,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) { for (auto& pad_top : {0, 1, 2}) { for (auto& pad_bottom : {0, 1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { for (auto& c : {1, 3, 5, 8, 16, 32}) { std::vector dims; DDim weights_dim({c, 1, 3, 3}); @@ -310,6 +325,7 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) { dims.push_back(DDim({batch, c, h, h})); } } + const float leakey_relu_scale = 8.88; test_conv_fp32(dims, weights_dim, c, @@ -317,9 +333,10 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) { {pad_top, pad_bottom, pad_left, pad_right}, {1, 1}, flag_bias, - flag_relu, + flag_act, {1, 2, 4}, - {FLAGS_power_mode}); + {FLAGS_power_mode}, + leakey_relu_scale); } } } @@ -335,28 +352,41 @@ TEST(TestConv3x3DW, test_conv3x3_depthwise) { #if 1 /// 5x5dw TEST(TestConv5x5DW, test_conv5x5_depthwise) { if (FLAGS_basic_test) { +#ifdef __aarch64__ + // TODO(chenjiaoAngel): fix me, diff: arm64 5x5s2 depthwise conv + for (auto& stride : {1}) { +#else for (auto& stride : {1, 2}) { - for (auto& pad : {0, 1, 2}) { - for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { - for (auto& c : {1, 3, 5, 8, 16, 32}) { - std::vector dims; - DDim weights_dim({c, 1, 5, 5}); - for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { - dims.push_back(DDim({batch, c, h, h})); +#endif + for (auto& pad_left : {0, 1, 2}) { + for (auto& pad_right : {0, 1, 2}) { + for (auto& pad_top : {0, 1, 2}) { + for (auto& pad_bottom : {0, 1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { + for (auto& c : {1, 15, 32}) { + std::vector dims; + DDim weights_dim({c, 1, 5, 5}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 15, 56}) { + dims.push_back(DDim({batch, c, h, h})); + } + } + const float leakey_relu_scale = 8.88; + test_conv_fp32(dims, + weights_dim, + c, + {stride, stride}, + {pad_left, pad_right, pad_top, pad_bottom}, + {1, 1}, + flag_bias, + flag_act, + {4}, + {FLAGS_power_mode}, + leakey_relu_scale); + } } } - test_conv_fp32(dims, - weights_dim, - c, - {stride, stride}, - {pad, pad, pad, pad}, - {1, 1}, - flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); } } } @@ -373,7 +403,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) { for (auto& cout : {1, 5, 16, 37}) { for (auto& g : {1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { std::vector dims; if (cin % g != 0 || cout % g != 0) { continue; @@ -384,6 +414,7 @@ TEST(TestConv1x1s1, test_conv1x1s1) { dims.push_back(DDim({batch, cin, h, h})); } } + const float leakey_relu_scale = 8.88; test_conv_fp32(dims, weights_dim, g, @@ -391,9 +422,10 @@ TEST(TestConv1x1s1, test_conv1x1s1) { {0, 0, 0, 0}, {1, 1}, flag_bias, - flag_relu, + flag_act, {1, 2, 4}, - {FLAGS_power_mode}); + {FLAGS_power_mode}, + leakey_relu_scale); } } } @@ -403,24 +435,29 @@ TEST(TestConv1x1s1, test_conv1x1s1) { } #endif /// conv1x1s1 -#if 1 /// conv3x3s1 +// TODO(MyPandaShaoxiang): fix me, diff: 3x3s1 winograd +#if 0 /// conv3x3s1 TEST(TestConv3x3s1, test_conv_3x3s1) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 32, 48}) { - for (auto& cout : {1, 5, 8, 32, 48}) { - for (auto& pad_left : {1, 2}) { - for (auto& pad_right : {1, 2}) { - for (auto& pad_top : {1, 2}) { - for (auto& pad_bottom : {1, 2}) { + for (auto& cin : {1, 3, 8, 8}) { + for (auto& cout : {1, 5, 32, 48}) { + for (auto& pad_left : {0, 1, 2}) { + for (auto& pad_right : {0, 1, 2}) { + for (auto& pad_top : {0, 1, 2}) { + for (auto& pad_bottom : {0, 1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { std::vector dims; DDim weights_dim({cout, cin, 3, 3}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 56, 32}) { + for (auto& h : {1, 3, 17, 33}) { dims.push_back(DDim({batch, cin, h, h})); } } + if (cin == 1 && cout ==1) { + continue; + } + const float leakey_relu_scale = 8.88; test_conv_fp32(dims, weights_dim, 1, @@ -428,9 +465,10 @@ TEST(TestConv3x3s1, test_conv_3x3s1) { {pad_top, pad_bottom, pad_left, pad_right}, {1, 1}, flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); + flag_act, + {4}, + {FLAGS_power_mode}, + leakey_relu_scale); } } } @@ -446,21 +484,25 @@ TEST(TestConv3x3s1, test_conv_3x3s1) { #if 1 /// conv3x3s2 TEST(TestConv3x3s2, test_conv_3x3s2) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 32}) { - for (auto& cout : {1, 5, 8, 32}) { - for (auto& pad_left : {1, 2}) { - for (auto& pad_right : {1, 2}) { - for (auto& pad_top : {1, 2}) { - for (auto& pad_bottom : {1, 2}) { + for (auto& cin : {1, 3, 8}) { + for (auto& cout : {1, 3, 9, 32}) { + for (auto& pad_left : {0, 1, 2}) { + for (auto& pad_right : {0, 1, 2}) { + for (auto& pad_top : {0, 1, 2}) { + for (auto& pad_bottom : {0, 1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { std::vector dims; DDim weights_dim({cout, cin, 3, 3}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { + for (auto& h : {3, 7, 15, 56, 32}) { dims.push_back(DDim({batch, cin, h, h})); } } + if (cin == 1 && cout == 1) { + continue; + } + const float leakey_relu_scale = 8.88; test_conv_fp32(dims, weights_dim, 1, @@ -468,9 +510,10 @@ TEST(TestConv3x3s2, test_conv_3x3s2) { {pad_top, pad_bottom, pad_left, pad_right}, {1, 1}, flag_bias, - flag_relu, + flag_act, {1, 2, 4}, - {FLAGS_power_mode}); + {FLAGS_power_mode}, + leakey_relu_scale); } } } @@ -486,29 +529,40 @@ TEST(TestConv3x3s2, test_conv_3x3s2) { #if 1 /// random param conv TEST(TestConvRand, test_conv_rand) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 16}) { - for (auto& cout : {1, 5, 8, 16}) { + for (auto& cin : {1, 3, 8}) { + for (auto& cout : {1, 5, 16}) { for (auto& g : {1, 2}) { for (auto& kw : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) { for (auto& stride : {1, 2}) { - for (auto& pad_left : {0, 1, 2}) { - for (auto& pad_right : {0, 1, 2}) { - for (auto& pad_top : {0, 1, 2}) { - for (auto& pad_bottom : {0, 1, 2}) { + for (auto& pad_left : {0, 2}) { + for (auto& pad_right : {0, 2}) { + for (auto& pad_top : {0, 2}) { + for (auto& pad_bottom : {0, 2}) { for (auto& dila : {1, 2}) { for (auto& flag_bias : {false, true}) { - for (auto& flag_relu : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { if (cin % g != 0 || cout % g != 0) { continue; } std::vector dims; DDim weights_dim({cout, cin / g, kh, kw}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 19, 32, 28}) { + for (auto& h : {1, 3, 19, 32}) { dims.push_back(DDim({batch, cin, h, h})); } } + // skip 3x3 depthwise conv + if (g == cin && cin == cout && kw == 3 && + kh == 3) { + break; + } + // skip 3x3s1 direct conv + if (g == 1 && (cin != 1 || cout != 1) && + kw == 3 && kh == 3 && stride == 1) { + break; + } + const float leakey_relu_scale = 8.88; test_conv_fp32( dims, weights_dim, @@ -517,9 +571,10 @@ TEST(TestConvRand, test_conv_rand) { {pad_top, pad_bottom, pad_left, pad_right}, {dila, dila}, flag_bias, - flag_relu, - {1, 2, 4}, - {FLAGS_power_mode}); + flag_act, + {4}, + {FLAGS_power_mode}, + leakey_relu_scale); } } } @@ -551,11 +606,12 @@ TEST(TestConvCustom, test_conv_fp32_custom_size) { FLAGS_kernel_w}), FLAGS_group, {FLAGS_stride_h, FLAGS_stride_w}, - {FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w}, + {FLAGS_pad_h0, FLAGS_pad_h1, FLAGS_pad_w0, FLAGS_pad_w1}, {FLAGS_dila_h, FLAGS_dila_w}, FLAGS_flag_bias, - FLAGS_flag_relu, + FLAGS_flag_act, {FLAGS_threads}, - {FLAGS_power_mode}); + {FLAGS_power_mode}, + FLAGS_leakey_relu_alpha); } #endif // custom diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc index 27c186d7ceffcaab3019cedf7c281c524be73e44..8e0094bc3f6b01bde4a338e3d531235bd21f328d 100644 --- a/lite/tests/math/conv_int8_compute_test.cc +++ b/lite/tests/math/conv_int8_compute_test.cc @@ -291,7 +291,7 @@ void test_conv_int8(const std::vector& input_dims, pads[2], pads[0], flag_bias, - flag_relu); + static_cast(flag_relu)); paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32, dout_basic_int8, scale_out.data(), @@ -362,6 +362,7 @@ void test_conv_int8(const std::vector& input_dims, << pads[2] << ", " << pads[3] << ", stride: " << strides[0] << ", " << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", group: " << group << ", bias: " << (flag_bias ? "true" : "false") << ", relu: " << (flag_relu ? "true" : "false") << ", threads: " << th << ", power_mode: " << cls @@ -467,7 +468,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { std::vector dims; DDim weights_dim({c, 1, 3, 3}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 15, 19, 75, 32, 28}) { + for (auto& h : {1, 3, 15, 33}) { dims.push_back(DDim({batch, c, h, h})); } } @@ -479,7 +480,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { {1, 1}, flag_bias, flag_relu, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } } @@ -494,14 +495,14 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { if (FLAGS_basic_test) { for (auto& stride : {1}) { - for (auto& pad : {0, 1, 2}) { + for (auto& pad : {0, 1, 2, 3, 4}) { for (auto& flag_bias : {false, true}) { for (auto& flag_relu : {false, true}) { - for (auto& c : {1, 3, 5, 8, 16, 32}) { + for (auto& c : {1, 5, 15, 33}) { std::vector dims; DDim weights_dim({c, 1, 5, 5}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 15, 19, 28, 32, 75}) { + for (auto& h : {1, 3, 15, 33}) { dims.push_back(DDim({batch, c, h, h})); } } @@ -513,7 +514,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { {1, 1}, flag_bias, flag_relu, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } } @@ -527,8 +528,8 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { #if 1 /// conv1x1s1 TEST(TestConv1x1s1Int8, test_conv1x1s1) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 11, 32}) { - for (auto& cout : {1, 5, 16, 37}) { + for (auto& cin : {1, 3, 8, 32}) { + for (auto& cout : {1, 5, 17}) { for (auto& g : {1, 2}) { for (auto& flag_bias : {false, true}) { for (auto& flag_relu : {false, true}) { @@ -538,7 +539,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { } DDim weights_dim({cout, cin / g, 1, 1}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 28, 32, 56, 1}) { + for (auto& h : {1, 9, 16, 33}) { dims.push_back(DDim({batch, cin, h, h})); } } @@ -550,7 +551,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { {1, 1}, flag_bias, flag_relu, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } } @@ -564,8 +565,8 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { #if 1 /// conv3x3s1 TEST(TestConv3x3s1Int8, test_conv_3x3s1) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 32, 48}) { - for (auto& cout : {1, 5, 8, 32, 48}) { + for (auto& cin : {1, 3, 8, 33}) { + for (auto& cout : {1, 5, 33}) { for (auto& pad_top : {1, 2}) { for (auto& pad_bottom : {1, 2}) { for (auto& pad_left : {1, 2}) { @@ -575,7 +576,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { std::vector dims; DDim weights_dim({cout, cin, 3, 3}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 56, 32}) { + for (auto& h : {1, 7, 17, 33}) { dims.push_back(DDim({batch, cin, h, h})); } } @@ -587,7 +588,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { {1, 1}, flag_bias, flag_relu, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } } @@ -604,8 +605,8 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { #if 1 /// conv3x3s2 TEST(TestConv3x3s2Int8, test_conv_3x3s2) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 32}) { - for (auto& cout : {1, 5, 8, 32}) { + for (auto& cin : {1, 3, 31}) { + for (auto& cout : {1, 5, 33}) { for (auto& pad_top : {1, 2}) { for (auto& pad_bottom : {1, 2}) { for (auto& pad_left : {1, 2}) { @@ -615,7 +616,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { std::vector dims; DDim weights_dim({cout, cin, 3, 3}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 7, 19, 28, 75, 56, 32}) { + for (auto& h : {1, 7, 19, 33}) { dims.push_back(DDim({batch, cin, h, h})); } } @@ -627,7 +628,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { {1, 1}, flag_bias, flag_relu, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } } @@ -644,8 +645,8 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { #if 1 /// random param conv TEST(TestConvRandInt8, test_conv_rand) { if (FLAGS_basic_test) { - for (auto& cin : {1, 3, 8, 16}) { - for (auto& cout : {1, 5, 8, 16}) { + for (auto& cin : {1, 17}) { + for (auto& cout : {1, 8, 17}) { for (auto& g : {1, 2}) { for (auto& kw : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) { @@ -658,12 +659,12 @@ TEST(TestConvRandInt8, test_conv_rand) { for (auto& flag_bias : {false, true}) { for (auto& flag_relu : {false, true}) { if (cin % g != 0 || cout % g != 0) { - continue; + break; } std::vector dims; DDim weights_dim({cout, cin / g, kh, kw}); for (auto& batch : {1, 2}) { - for (auto& h : {1, 3, 19, 32, 28}) { + for (auto& h : {1, 3, 5, 19}) { dims.push_back(DDim({batch, cin, h, h})); } } @@ -676,7 +677,7 @@ TEST(TestConvRandInt8, test_conv_rand) { {dila, dila}, flag_bias, flag_relu, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } } diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc index 623615c8da16326da3c233687915935aa5a88d64..25879a15184965b128bfa100a2b41a17aa842860 100644 --- a/lite/tests/math/gemv_int8_compute_test.cc +++ b/lite/tests/math/gemv_int8_compute_test.cc @@ -37,7 +37,7 @@ DEFINE_int32(power_mode, DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); -DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(basic_test, true, "do all tests"); DEFINE_bool(check_result, true, "check the result"); DEFINE_int32(M, 512, "gemv: M"); diff --git a/lite/tests/math/sgemm_c4_compute_test.cc b/lite/tests/math/sgemm_c4_compute_test.cc index 886dba6ac5a390c5eca4a9b499bfb57e2b077a32..3e5577e03075502bab30aa03a50241b817fa8742 100644 --- a/lite/tests/math/sgemm_c4_compute_test.cc +++ b/lite/tests/math/sgemm_c4_compute_test.cc @@ -37,7 +37,7 @@ DEFINE_int32(power_mode, DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); -DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(basic_test, true, "do all tests"); DEFINE_bool(check_result, true, "check the result"); DEFINE_int32(M, 512, "gemm_c4: M"); diff --git a/lite/tests/math/sgemv_compute_test.cc b/lite/tests/math/sgemv_compute_test.cc index 5dd2d322955d2c628366075a6dddb31bed2338ee..91a1fe1770dfa3eeb3f3b94fcd2361f1c1634b1e 100644 --- a/lite/tests/math/sgemv_compute_test.cc +++ b/lite/tests/math/sgemv_compute_test.cc @@ -38,11 +38,19 @@ DEFINE_int32(K, 512, "sgemv: K"); DEFINE_bool(traA, false, "gemv: A transpose"); -DEFINE_bool(flag_relu, false, "do relu"); +DEFINE_int32(flag_act, 0, "do act"); DEFINE_bool(flag_bias, false, "with bias"); - -bool test_sgemv( - bool tra, int m, int k, bool has_bias, bool has_relu, int cls, int ths) { +DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha"); +DEFINE_double(clipped_coef, 6.0, "clipped relu coef"); +bool test_sgemv(bool tra, + int m, + int k, + bool has_bias, + int flag_act, + int cls, + int ths, + float six = 6.f, + float alpha = 1.f) { Tensor ta; Tensor tb; Tensor tc; @@ -68,8 +76,7 @@ bool test_sgemv( fill_tensor_rand(tbias, -1.f, 1.f); LOG(INFO) << "sgemv M: " << m << ", K: " << k - << ", transA: " << (tra ? "true" : "false") - << ", relu: " << (has_relu ? "true" : "false") + << ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act << ", bias: " << (has_bias ? "true" : "false"); #ifdef LITE_WITH_ARM @@ -78,10 +85,29 @@ bool test_sgemv( auto dc = tc.mutable_data(); auto dc_basic = tc_basic.mutable_data(); auto dbias = tbias.mutable_data(); - + paddle::lite_api::ActivationType act = + paddle::lite_api::ActivationType::kIndentity; + if (flag_act == 1) { + act = paddle::lite_api::ActivationType::kRelu; + } else if (flag_act == 2) { + act = paddle::lite_api::ActivationType::kRelu6; + } else if (flag_act == 4) { + act = paddle::lite_api::ActivationType::kLeakyRelu; + } if (FLAGS_check_result) { - basic_gemv( - m, k, da, db, dbias, dc_basic, 1.f, 0.f, tra, has_bias, has_relu); + basic_gemv(m, + k, + da, + db, + dbias, + dc_basic, + 1.f, + 0.f, + tra, + has_bias, + flag_act, + six, + alpha); } paddle::lite::profile::Timer t0; //! compute @@ -92,15 +118,37 @@ bool test_sgemv( ctx.SetRunMode(static_cast(cls), ths); /// warmup for (int j = 0; j < FLAGS_warmup; ++j) { - paddle::lite::arm::math::sgemv( - da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx); + paddle::lite::arm::math::sgemv(da, + db, + dc, + tra, + m, + k, + has_bias, + dbias, + flag_act > 0, + act, + &ctx, + six, + alpha); } t0.Reset(); for (int i = 0; i < FLAGS_repeats; ++i) { t0.Start(); - paddle::lite::arm::math::sgemv( - da, db, dc, tra, m, k, has_bias, dbias, has_relu, &ctx); + paddle::lite::arm::math::sgemv(da, + db, + dc, + tra, + m, + k, + has_bias, + dbias, + flag_act > 0, + act, + &ctx, + six, + alpha); t0.Stop(); } LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls @@ -125,7 +173,7 @@ bool test_sgemv( tensor_diff(tc_basic, tc, tdiff); LOG(INFO) << "basic result: "; print_tensor(tc_basic); - LOG(INFO) << "saber result: "; + LOG(INFO) << "lite result: "; print_tensor(tc); LOG(INFO) << "diff result: "; print_tensor(tdiff); @@ -144,22 +192,31 @@ TEST(TestLiteSgemv, Sgemv) { LOG(INFO) << "run basic sgemv test"; for (auto& m : {1, 3, 8, 21, 32, 397}) { for (auto& k : {1, 3, 8, 17, 59, 234}) { - for (auto& tra : {true, false}) { + for (auto& tra : {false, true}) { for (auto& has_bias : {false, true}) { - for (auto& has_relu : {false, true}) { + for (auto& flag_act : {0, 1, 2, 4}) { for (auto& th : {1, 2, 4}) { - auto flag = test_sgemv( - tra, m, k, has_bias, has_relu, FLAGS_cluster, th); + float six = 6.f; + float alpha = 8.88f; + auto flag = test_sgemv(tra, + m, + k, + has_bias, + flag_act, + FLAGS_cluster, + th, + six, + alpha); if (flag) { LOG(INFO) << "test m = " << m << ", k=" << k << ", bias: " << (has_bias ? "true" : "false") - << ", relu: " << (has_relu ? "true" : "false") + << ", flag act: " << flag_act << ", trans A: " << (tra ? "true" : "false") << ", threads: " << th << " passed\n"; } else { LOG(FATAL) << "test m = " << m << ", k=" << k << ", bias: " << (has_bias ? "true" : "false") - << ", relu: " << (has_relu ? "true" : "false") + << ", flag_act: " << flag_act << ", trans A: " << (tra ? "true" : "false") << ", threads: " << th << " failed\n"; } @@ -180,15 +237,17 @@ TEST(TestSgemvCustom, Sgemv_custom) { FLAGS_M, FLAGS_K, FLAGS_flag_bias, - FLAGS_flag_relu, + FLAGS_flag_act, FLAGS_cluster, - FLAGS_threads); + FLAGS_threads, + FLAGS_clipped_coef, + FLAGS_leakey_relu_alpha); if (!flag) { LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias - << ", relu: " << FLAGS_flag_relu << " failed!!"; + << ", act: " << FLAGS_flag_act << " failed!!"; } LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias - << ", relu: " << FLAGS_flag_relu << " passed!!"; + << ", act: " << FLAGS_flag_act << " passed!!"; } diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h index 91e398c5a9d9b20a4cd3ffb9b32090fc93af7781..e5ef77ca061d31a0b9b735d49cda9bbeda53c294 100644 --- a/lite/tests/utils/naive_math_impl.h +++ b/lite/tests/utils/naive_math_impl.h @@ -177,7 +177,9 @@ static void basic_gemv(int m, type2 beta, bool trans_a = false, bool flag_bias = false, - bool flag_relu = false) { + int flag_act = false, + float six = 6.f, + float leakey_relu_alpha = 1.f) { #pragma omp parallel for for (int i = 0; i < m; ++i) { auto bias_data = static_cast(0); @@ -195,8 +197,15 @@ static void basic_gemv(int m, sum += av * b[j]; } type2 tmp = alpha * sum + beta * c[i] + bias_data; - if (flag_relu) { - c[i] = tmp > (type2)0 ? tmp : (type2)0; + if (flag_act > 0) { + if (flag_act == 1) { // relu + c[i] = tmp > (type2)0 ? tmp : (type2)0; + } else if (flag_act == 2) { // relu 6 + c[i] = tmp > (type2)0 ? tmp : (type2)0; + c[i] = c[i] < six ? c[i] : six; + } else if (flag_act == 4) { // leakey relu + c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; + } } else { c[i] = tmp; } @@ -230,7 +239,9 @@ static void conv_basic(const Dtype1* din, int pad_w, int pad_h, bool flag_bias, - bool flag_relu) { + int act_type, + float six = 6.f, + float scale = 1.f) { Dtype2 beta = 0; auto src_data = din; auto dst_data_ref = dout; @@ -280,10 +291,27 @@ static void conv_basic(const Dtype1* din, } } } - if (flag_relu) { - dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 - ? dst_data_ref[out_idx] - : (Dtype2)0; + if (act_type > 0) { + // 1-relu 2-relu6 4-leakyrelu + if (act_type == 1) { + dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 + ? dst_data_ref[out_idx] + : (Dtype2)0; + } else if (act_type == 2) { + dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 + ? dst_data_ref[out_idx] + : (Dtype2)0; + dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six + ? dst_data_ref[out_idx] + : (Dtype2)six; + } else if (act_type == 4) { + dst_data_ref[out_idx] = + dst_data_ref[out_idx] > (Dtype2)0 + ? dst_data_ref[out_idx] + : (Dtype2)(dst_data_ref[out_idx] * scale); + } else { + printf("this act type: %d does not support \n", act_type); + } } } }