From c35d8e1402961badfe496933adbecbf1f111636b Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Thu, 16 Jan 2020 12:11:36 +0800 Subject: [PATCH] [arm] add conv_5x5s2_dw to support any padding (#2770) 1. add conv_5x5s2_dw to support any padding 2. add 1x1s2pooling impl 3. fix conv dw 3x3 s1p01 bug --- .../arm/math/conv3x3s1p01_depthwise_fp32.cc | 107 +- .../arm/math/conv3x3s2p01_depthwise_fp32.cc | 45 +- .../arm/math/conv5x5s2_depthwise_fp32.cc | 4594 ++++------------- lite/backends/arm/math/conv_depthwise.h | 21 +- lite/backends/arm/math/conv_impl.cc | 16 +- lite/backends/arm/math/pooling.cc | 113 + lite/backends/arm/math/pooling.h | 10 + lite/core/mir/fusion/conv_activation_fuser.cc | 3 + lite/kernels/arm/conv_compute.cc | 7 +- lite/kernels/arm/conv_depthwise.cc | 25 +- lite/kernels/arm/pool_compute.cc | 17 +- lite/operators/conv_op.h | 4 + 12 files changed, 1187 insertions(+), 3775 deletions(-) diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc index 6ea9c4dcdb..510cb2334a 100644 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @@ -2339,17 +2339,29 @@ void conv_depthwise_3x3s1p1_bias(float *dout, int size_out_channel = w_out * h_out; int w_stride = 9; - int tile_w = (w_in + 3) >> 2; - int cnt_col = tile_w - 2; + int tile_w = w_out >> 2; + int remain = w_out % 4; + int cnt_col = tile_w - 1; - unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); + unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in); + const unsigned int remian_idx[4] = {0, 1, 2, 3}; + + if (remain == 0 && size_pad_right == 5) { + size_pad_right = 1; + cnt_col -= 1; + remain = 4; + } else if (remain == 0 && size_pad_right == 6) { + size_pad_right = 2; + cnt_col -= 1; + remain = 4; + } uint32x4_t vmask_rp1 = vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); uint32x4_t vmask_rp2 = vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); uint32x4_t vmask_result = - vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); + vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx)); unsigned int vmask[8]; vst1q_u32(vmask, vmask_rp1); @@ -2398,7 +2410,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, const float *din_ptr5 = dr5; float *ptr_zero = const_cast(zero); #ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { + for (int i = 0; i < h_out; i += 4) { //! process top pad pad_h = 1 din_ptr0 = dr0; din_ptr1 = dr1; @@ -2484,7 +2496,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, dout_ptr = dout_ptr + 4 * w_out; } #else - for (int i = 0; i < h_in; i += 2) { + for (int i = 0; i < h_out; i += 2) { //! 
process top pad pad_h = 1 din_ptr0 = dr0; din_ptr1 = dr1; @@ -2883,39 +2895,57 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, wbias = vdupq_n_f32(0.f); } - int hs = -1; - int he = 3; - float out_buf1[4]; float out_buf2[4]; float trash_buf[4]; - int h_cnt = (h_out + 1) >> 1; float *doutr0 = dout_channel; float *doutr1 = dout_channel + w_out; - for (int j = 0; j < h_cnt; ++j) { - const float *dr0 = din_channel + hs * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; + const float *dr0 = din_channel; + const float *dr1 = dr0 + w_in; + const float *dr2 = dr1 + w_in; + const float *dr3 = dr2 + w_in; - if (hs == -1) { - dr0 = zero; + for (int j = 0; j < h_out; j += 2) { + const float *dr0_ptr = dr0; + const float *dr1_ptr = dr1; + const float *dr2_ptr = dr2; + const float *dr3_ptr = dr3; + if (j == 0) { + dr0_ptr = zero; + dr1_ptr = dr0; + dr2_ptr = dr1; + dr3_ptr = dr2; + dr0 = dr1; + dr1 = dr2; + } else { + dr0 = dr2; + dr1 = dr3; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + //! process bottom pad + if (j + 3 > h_in) { + switch (j + 3 - h_in) { + case 3: + dr1_ptr = zero; + case 2: + dr2_ptr = zero; + case 1: + dr3_ptr = zero; + default: + break; + } } - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; + //! process bottom remain + if (j + 2 > h_out) { + doutr1 = trash_buf; } - act_switch_3x3s1p1_s(dr0, - dr1, - dr2, - dr3, + act_switch_3x3s1p1_s(dr0_ptr, + dr1_ptr, + dr2_ptr, + dr3_ptr, out_buf1, out_buf2, wr0, @@ -2931,8 +2961,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, } doutr0 = doutr1; doutr1 += w_out; - hs += 2; - he += 2; } // end of processing heights } // end of processing channels } // end of processing batchs @@ -3458,6 +3486,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout, unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); const int remian_idx[4] = {0, 1, 2, 3}; + if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0 + tile_w -= 1; + remain = 4; + size_pad_right = 2; + } + uint32x4_t vmask_rp1 = vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); uint32x4_t vmask_rp2 = @@ -4016,22 +4050,21 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, doutr0 = dout_channel + j * w_out; doutr1 = doutr0 + w_out; - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { + if (j + 4 > h_in) { + switch (j + 4 - h_in) { case 3: dr1 = zero_ptr; case 2: dr2 = zero_ptr; case 1: dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - doutr1 = trash_buf; default: break; } } + if (j + 2 > h_out) { + doutr1 = trash_buf; + } unsigned int *vmask_ptr = vmask; act_switch_3x3s1p0_s(dr0, dr1, diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index a17d87b47d..dbfc0dc7b3 100644 --- a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -1202,15 +1202,17 @@ void conv_depthwise_3x3s2p1_bias(float* dout, int out_pad_idx[4] = {0, 1, 2, 3}; int size_pad_bottom = h_out * 2 - h_in; - int cnt_col = (w_out >> 2) - 2; - int size_right_remain = w_in - (7 + cnt_col * 8); - if (size_right_remain >= 9) { - cnt_col++; - size_right_remain -= 8; - } - int cnt_remain = (size_right_remain == 8) ? 
4 : (w_out % 4); + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; - int size_right_pad = w_out * 2 - w_in; + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + int cnt_col = tile_w - 1; uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), vld1q_s32(right_pad_idx)); // 0 2 4 6 @@ -1276,7 +1278,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, float* doutr1_ptr = nullptr; #ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { + for (int i = 0; i < h_out; i += 2) { din0_ptr = dr0; din1_ptr = dr1; din2_ptr = dr2; @@ -1303,8 +1305,8 @@ void conv_depthwise_3x3s2p1_bias(float* dout, dr4 = dr3 + w_in; //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { + if (i * 2 + 4 > h_in) { + switch (i * 2 + 4 - h_in) { case 4: din1_ptr = zero_ptr; case 3: @@ -1318,7 +1320,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, } } //! process output pad - if (i / 2 + 2 > h_out) { + if (i + 2 > h_out) { doutr1_ptr = write_ptr; } int cnt = cnt_col; @@ -1343,7 +1345,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, doutr0 = doutr0 + 2 * w_out; } #else - for (int i = 0; i < h_in; i += 2) { + for (int i = 0; i < h_out; i++) { din0_ptr = dr0; din1_ptr = dr1; din2_ptr = dr2; @@ -1641,7 +1643,8 @@ void act_switch_3x3s2p0(const float* din0_ptr, "ld1 {v20.4s}, [%[inptr3]] \n" "ld1 {v21.4s}, [%[inptr4]] \n" "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - MID_COMPUTE_S2 MID_RESULT_S2_RELU6 + "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 "cmp %w[remain], #1 \n" "blt 4f \n" RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6 @@ -1700,7 +1703,8 @@ void act_switch_3x3s2p0(const float* din0_ptr, "ld1 {v20.4s}, [%[inptr3]] \n" "ld1 {v21.4s}, [%[inptr4]] \n" "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU + "ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2 + MID_RESULT_S2_LEAKY_RELU "cmp %w[remain], #1 \n" "blt 4f \n" RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU @@ -1718,7 +1722,7 @@ void act_switch_3x3s2p0(const float* din0_ptr, [w1] "w"(wr1), [w2] "w"(wr2), [remain] "r"(cnt_remain), - [six_ptr] "r"(vscale), + [scale_ptr] "r"(vscale), [mask1] "w"(vmask_rp1), [mask2] "w"(vmask_rp2), [wmask] "w"(wmask), @@ -1834,7 +1838,14 @@ void conv_depthwise_3x3s2p0_bias(float* dout, int tile_w = w_out >> 2; int cnt_remain = w_out % 4; - unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); + unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), vld1q_s32(right_pad_idx)); // 0 2 4 6 diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc index dced24db72..6a6a2dcc39 100644 --- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc @@ -13,3732 +13,932 @@ // limitations under the License. 
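The 3x3 hunks above all make the same change: tiling is now derived from w_out rather than w_in, and a full tile that would load past the input edge is folded into the masked right-pad path. A scalar sketch of the s2p1 arithmetic as introduced here (the struct and helper names are illustrative, not part of the patch):

    struct TileS2P1 {
      int tile_w;      // number of 4-wide output tiles
      int cnt_col;     // mid-loop count; tile 0 also covers the left pad
      int cnt_remain;  // outputs written through the masked right-pad path
      unsigned remain; // valid lanes in the final 8-wide input load
    };

    TileS2P1 tile_3x3s2p1(int w_in, int w_out) {
      TileS2P1 t;
      t.tile_w = w_out >> 2;
      t.cnt_remain = w_out % 4;
      t.remain = 8u - (unsigned)(7 + (t.tile_w << 3) - w_in);
      // w_out % 4 == 0 with no spare input lanes: retire the last full
      // tile through the masked path so the main loop never over-reads.
      if (t.cnt_remain == 0 && t.remain == 0) {
        t.cnt_remain = 4;
        t.tile_w -= 1;
        t.remain = 8;
      }
      t.cnt_col = t.tile_w - 1;
      return t;
    }

The s1p1/s1p0 hunks apply the same idea with 4-wide loads (size_pad_right) instead of 8-wide ones.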
#include +#include "lite/backends/arm/math/conv_block_utils.h" #include "lite/backends/arm/math/conv_depthwise.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif namespace paddle { namespace lite { namespace arm { namespace math { +#ifdef __aarch64__ +#define COMPUTE \ + "ldp q0, q1, [%[inr0]], #32\n" /* load r0, 0-1 */ \ + "and v19.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q2, q3, [%[inr0]], #32\n" /* load r0, 2-3 */ \ + "and v20.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q4, q5, [%[inr0]], #32\n" /* load r0, 4-5 */ \ + "and v21.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q6, q7, [%[inr0]], #32\n" /* load r0, 6-7 */ \ + "and v22.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q8, q9, [%[inr0]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldr q10, [%[inr0]] \n" /* load r0, 10 */ \ + "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "sub %[inr0], %[inr0], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr1]], #32\n" /* load r1, 0-1 */ \ + "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , %[w3].4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w3].4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w3].4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w3].4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr1]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , %[w4].4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w4].4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w4].4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w4].4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr1]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr1]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr1]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr1], %[inr1], #32\n" /* inr1 -= 32 */ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 
= w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr2]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr2]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr2]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr2]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr2], %[inr2], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr3]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr3]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr3]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr3]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr3]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr3]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr3], %[inr3], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr4]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * 
r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr4]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr4]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr4]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr4]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr4]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr4], %[inr4], #32\n" /* inr0 -= 32 */ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "sub %[wc0], %[wc0], #320\n" /* weight -= 320 */ \ + "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ +#define RELU /* relu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" +#define RELU6 /* relu6 */ \ + "fmin v19.4s, v19.4s, %[vsix].4s\n" \ + "fmin v20.4s, v20.4s, %[vsix].4s\n" \ + "fmin v21.4s, v21.4s, %[vsix].4s\n" \ + "fmin v22.4s, v22.4s, %[vsix].4s\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ + "cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ + "cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ + "cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ + "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ + "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ + "bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ + "bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \ + "bif v19.16b, v8.16b, v7.16b \n" /* choose*/ +#define 
STORE /* save result */ \ + "str q19, [%[outc0]], #16\n" \ + "str q20, [%[outc1]], #16\n" \ + "str q21, [%[outc2]], #16\n" \ + "str q22, [%[outc3]], #16\n" -void conv_depthwise_5x5s2p2(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2p2_relu(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2p2_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2p2_relu_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx); - -void conv_depthwise_5x5s2_fp32(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, +#else +#define COMPUTE \ + /* fill with bias */ \ + "vld1.32 {d12-d13}, [%[bias]]\n" /* load bias */ /* load weights */ \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \ + "vand.i32 q12, q6, q6\n" \ + "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \ + "vand.i32 q13, q6, q6\n" \ + "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \ + "vand.i32 q14, q6, q6\n" \ + "vand.i32 q15, q6, q6\n" \ + "vld1.32 {d12-d13}, [%[r0]]!\n" /* load input r0, 6*/ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-q10 */ \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr6\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr4\n" \ + "vld1.32 {d0-d3}, [%[r0]]! \n" /* load r0, 7-8 */ \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vld1.32 {d4-d7}, [%[r0]] \n" /* load r0, 9-10 */ \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d4-d5}, [%[r1]]! @ load r1, 2\n" \ + "sub %[r0], %[r0], #16 @ r0 - 16 to nextline address\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 3, 4\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vld1.32 {d10-d13}, [%[r1]]! 
@ load r1, 5, 6\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr0\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr2\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 7, 8\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d4-d7}, [%[r1]] @ load r1, 9, 10\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "sub %[r1], %[r1], #16 @ r1 - 16 to nextline address\n" \ + "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.32 {d12-d13}, [%[r2]]! @ load r2, 6 \n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q4 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d4-d7}, [%[r2]] @ load r2, 9, 10\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[r2], %[r2], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r3]]! @ load r3, 0, 1\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, 2, 3\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r3]]! @ load r3, 4, 5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.32 {d12-d13}, [%[r3]]! @ load r3, 6, \n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vld1.32 {d4-d7}, [%[r3]] @ load r3, 9, 10\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q9, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q9, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[r3], %[r3], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r4]]! @ load r4, 0, 1\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d4-d7}, [%[r4]]! @ load r4, 2, 3\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r4]]! @ load r3, 4, 5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.32 {d12-d13}, [%[r4]]! @ load r3, 6, \n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r4]]! @ load r3, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vld1.32 {d4-d7}, [%[r4]] @ load r3, 9, 10\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q9, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q9, q2 @ w3 * inr9\n" \ + "sub %[wc0], %[wc0], #400 @ wc0 - 400 to start address\n" \ + "sub %[r4], %[r4], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \ + "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \ + "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \ + "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ + +#define RELU /* relu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[six_ptr]]\n" \ + "vmax.f32 q12, q12, q0\n" \ + "vmax.f32 q13, q13, q0\n" \ + "vmax.f32 q14, q14, q0\n" \ + "vmax.f32 q15, q15, q0\n" +#define RELU6 /* relu6 */ \ + "vmin.f32 q12, q12, q1\n" \ + "vmin.f32 q13, q13, q1\n" \ + "vmin.f32 q14, q14, q1\n" \ + "vmin.f32 q15, q15, q1\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[scale_ptr]]\n" \ + "vcge.f32 q2, q12, q0 @ q0 > 0 \n" \ + "vcge.f32 q4, q13, q0 @ q0 > 0 \n" \ + "vcge.f32 q6, q14, q0 @ q0 > 0 \n" \ + "vcge.f32 q8, q15, q0 @ q0 > 0 \n" \ + "vmul.f32 q3, q12, q1 @ mul \n" \ + "vmul.f32 q5, q13, q1 @ mul \n" \ + "vmul.f32 q7, q14, q1 @ mul \n" \ + "vmul.f32 q9, q15, q1 @ mul \n" \ + "vbif q12, q3, q2 @ choose \n" \ + "vbif q13, q5, q4 @ choose \n" \ + "vbif q14, q7, q6 @ choose \n" \ + "vbif q15, q9, q8 @ choose \n" +#define STORE /* save result */ \ + "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \ + "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \ + "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \ + "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ + +#endif + +void act_switch_5x5s2(const float* inr0, + const float* 
inr1, + const float* inr2, + const float* inr3, + const float* inr4, + float* outc0, + float* outc1, + float* outc2, + float* outc3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t vbias, + const float* weight_c, + float* bias_local, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(tmp); + float32x4_t vscale = vdupq_n_f32(ss); +#else + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias), + [vsix] "w"(vsix) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias), + [vscale] "w"(vscale) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", 
+ "v22"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} +void conv_depthwise_5x5s2_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, int win, const float* weights, const float* bias, - int pad, - bool flag_bias, - bool flag_relu, + const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx) { - if (pad == 2) { - if (win >= 9) { - if (flag_relu) { - conv_depthwise_5x5s2p2_relu(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s2p2(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); - } - } else { - if (flag_relu) { - conv_depthwise_5x5s2p2_relu_s(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); - } else { - conv_depthwise_5x5s2p2_s(din, - dout, - num, - chout, - hout, - wout, - chin, - hin, - win, - weights, - bias, - flag_bias, - flag_relu, - ctx); + auto paddings = *param.paddings; + int threads = ctx->threads(); + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + const int out_c_block = 4; + const int out_h_kernel = 1; + const int out_w_kernel = 4; + const int win_ext = ow * 2 + 3; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh * 2 + 3; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = threads * prein_size + win_round + ow_round; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + auto ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + 
int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 25; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; } - } - } -} - #ifdef __aarch64__ - -//! larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - const float* din5 = din4 + w_in; - const float* din6 = din5 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 2) { - //! 
(h * 2 - 2) + 6 > h_in - 1 - if (h * 2 + 5 > h_in) { - switch (h * 2 + 5 - h_in) { - case 6: - din1 = zero_ptr; - case 5: - din2 = zero_ptr; - case 4: - din3 = zero_ptr; - case 3: - din4 = zero_ptr; - case 2: - din5 = zero_ptr; - case 1: - din6 = zero_ptr; - default: - break; - } - } - if (h + 2 > h_out) { - switch (h + 2 - h_out) { - case 1: - dout1 = write_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - //! in r0, r1/r4, r2/r5, r3/r6: x 0 2 4 -- v8 v13 v18 v23 - //! in r0, r1/r4, r2/r5, r3/r6: x 1 3 5 -- v9 v14 v19 v24 - //! in r0, r1/r4, r2/r5, r3/r6: 0 2 4 6 -- v6 v11 v16 v21 - //! in r0, r1/r4, r2/r5, r3/r6: 1 3 5 7 -- v7 v12 v17 v22 - //! in r0, r1/r4, r2/r5, r3/r6: 2 4 6 8 -- v10 v15 v20 v25 - //! out r0, r1 -- v26, v27 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - // left - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" // r0 v6: 0 - // 2 4 6, - // v7: 1 3 - // 5 7 - "ext v8.16b, v31.16b, v6.16b, #12 \n" // r0 v8: x - // 0 2 4 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" // r1 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - "ext v9.16b, v31.16b, v7.16b, #12 \n" // r0 v9: x - // 1 3 5 - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load - // weights - // 0-7 - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ld1 {v10.s}[3], [%[din_ptr0]] \n" // r0 v10: - // 2 4 6 8 - "sub %[din_ptr0], %[din_ptr0], #8 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r1 v13: - // x 0 2 4 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" // r2 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r1 v14: - // x 1 3 5 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load - // weights - // 8-15 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" // r1 v15: - // 2 4 6 - "sub %[din_ptr1], %[din_ptr1], #8 \n" - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r2 v18: - // x 0 2 4 - "ld1 {v4.4s, v5.4s}, [%[weights]], #32 \n" // load - // weights - // 16-23 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r2 v19: - // x 1 3 5 - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" // r3 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" // r2 v20: - // 2 4 6 8 - "sub %[din_ptr2], %[din_ptr2], #8 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r3 v23: - // x 0 2 4 - "ld1 {v30.4s}, [%[weights]] \n" // load - // weights - // 24 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r3 v24: - // x 1 3 5 - "ld1 {v26.4s}, [%[vbias]] \n" // load - // bias to - // out_r0 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" // r2 v25: - // 2 4 6 8 - "sub %[din_ptr3], %[din_ptr3], #8 \n" - "mov v27.16b, v26.16b \n" // load - // bias to - // out_r1 - "mov v28.16b, v31.16b \n" // 
load - // zero to - // out_r0 - "mov v29.16b, v31.16b \n" // load - // zero to - // out_r1 - - "fmla v26.4s, v8.4s, v0.s[0] \n" // out r0: - // w0 - "fmla v28.4s, v9.4s, v0.s[1] \n" // out r0: - // w1 - "fmla v26.4s, v6.4s, v0.s[2] \n" // out r0: - // w2 - "fmla v28.4s, v7.4s, v0.s[3] \n" // out r0: - // w3 - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v8: 0 2 - // 4 6, v9: - // 1 3 5 7 - - "fmla v26.4s, v10.4s, v1.s[0] \n" // out r0: - // w4 - "fmla v28.4s, v13.4s, v1.s[1] \n" // out r0: - // w5 - "fmla v26.4s, v14.4s, v1.s[2] \n" // out r0: - // w6 - "fmla v28.4s, v11.4s, v1.s[3] \n" // out r0: - // w7 - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v6: 2 4 - // 6 8, v7: - // 3 5 7 9 - - "fmla v26.4s, v12.4s, v2.s[0] \n" // out r0: - // w8 - "fmla v28.4s, v15.4s, v2.s[1] \n" // out r0: - // w9 - "fmla v26.4s, v18.4s, v2.s[2] \n" // out r0: - // w10 - "fmla v28.4s, v19.4s, v2.s[3] \n" // out r0: - // w11 - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" // next r0 - // v10: 4 6 - // 8 10, - // v11: - // trash - // register - - "fmla v26.4s, v16.4s, v3.s[0] \n" // out r0: - // w12 - "fmla v28.4s, v17.4s, v3.s[1] \n" // out r0: - // w13 - "fmla v26.4s, v20.4s, v3.s[2] \n" // out r0: - // w14 - "fmla v28.4s, v23.4s, v3.s[3] \n" // out r0: - // w15 - "prfm pldl1keep, [%[din_ptr0]] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], #32 \n" // r4 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - - "fmla v26.4s, v24.4s, v4.s[0] \n" // out r0: - // w16 - "fmla v28.4s, v21.4s, v4.s[1] \n" // out r0: - // w17 - - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r4 v13: - // x 0 2 4 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r4 v14: - // x 1 3 5 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - - "fmla v26.4s, v22.4s, v4.s[2] \n" // out r0: - // w18 - "fmla v28.4s, v25.4s, v4.s[3] \n" // out r0: - // w19 - - "ld1 {v15.s}[3], [%[din_ptr4]] \n" // r4 v15: - // 2 4 6 - - "fmla v27.4s, v18.4s, v0.s[0] \n" // out r1: - // w0 - "fmla v29.4s, v19.4s, v0.s[1] \n" // out r1: - // w1 - - "sub %[din_ptr4], %[din_ptr4], #8 \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" // out r1: - // w2 - "fmla v29.4s, v17.4s, v0.s[3] \n" // out r1: - // w3 - "fmla v27.4s, v20.4s, v1.s[0] \n" // out r1: - // w4 - "fmla v29.4s, v23.4s, v1.s[1] \n" // out r1: - // w5 - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], #32 \n" // r5 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - - "fmla v27.4s, v24.4s, v1.s[2] \n" // out r1: - // w6 - "fmla v29.4s, v21.4s, v1.s[3] \n" // out r1: - // w7 - - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r5 v18: - // x 0 2 4 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r5 v19: - // x 1 3 5 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" // out r1: - // w8 - "fmla v29.4s, v25.4s, v2.s[1] \n" // out r1: - // w9 - - "ld1 {v20.s}[3], [%[din_ptr5]] \n" // r5 v20: - // 2 4 6 - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], #32 \n" // r6 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r6 v23: - // x 0 2 4 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r6 v24: - // x 1 3 5 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "sub %[din_ptr5], %[din_ptr5], #8 \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" // out r0: - // w22 - "fmla v28.4s, v12.4s, v5.s[3] \n" // out r0: - // w23 - - "ld1 {v25.s}[3], [%[din_ptr6]] \n" // r6 v25: - // 2 4 6 - - "fmla v26.4s, v13.4s, v5.s[0] \n" // out r0: - // w20 - "fmla v28.4s, v14.4s, v5.s[1] \n" // out r0: - // w21 - - "sub %[din_ptr6], %[din_ptr6], #8 \n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" // out r0: - // w24 - "fmla 
v27.4s, v13.4s, v2.s[2] \n" // out r1: - // w10 - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" // out r1: - // w11 - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v13: 0 2 - // 4 6, - // v14: 1 3 - // 5 7 - "fmla v27.4s, v11.4s, v3.s[0] \n" // out r1: - // w12 - "fmla v29.4s, v12.4s, v3.s[1] \n" // out r1: - // w13 - - "st1 {v26.4s}, [%[dout_ptr0]], %[s_16] \n" // store - // output - // r0 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v11: 2 4 - // 6 8, - // v12: 3 5 - // 7 9 - - "fmla v27.4s, v15.4s, v3.s[2] \n" // out r1: - // w14 - "fmla v29.4s, v16.4s, v4.s[1] \n" // out r1: - // w17 - "fmla v27.4s, v18.4s, v3.s[3] \n" // out r1: - // w15 - "fmla v29.4s, v19.4s, v4.s[0] \n" // out r1: - // w16 - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" // next r1 - // v15: 4 6 - // 8 10, - // v16: - // trash - // register - - "fmla v27.4s, v17.4s, v4.s[2] \n" // out r1: - // w18 - "fmla v29.4s, v20.4s, v4.s[3] \n" // out r1: - // w19 - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v18: 0 2 - // 4 6, - // v19: 1 3 - // 5 7 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v16: 2 4 - // 6 8, - // v11: 3 5 - // 7 9 - - "fmla v27.4s, v23.4s, v5.s[0] \n" // out r1: - // w20 - "fmla v29.4s, v21.4s, v5.s[2] \n" // out r1: - // w22 - "fmla v27.4s, v24.4s, v5.s[1] \n" // out r1: - // w21 - "fmla v29.4s, v22.4s, v5.s[3] \n" // out r1: - // w23 - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" // next r2 - // v20: 4 6 - // 8 10, - // v21: - // trash - // register - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v23: 0 2 - // 4 6, - // v24: 1 3 - // 5 7 - - "fmla v27.4s, v25.4s, v30.s[0] \n" // out r1: - // w24 - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v21: 2 4 - // 6 8, - // v22: 3 5 - // 7 9 - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" // next r3 - // v25: 4 6 - // 8 10, - // v26: - // trash - // register - - "fadd v27.4s, v27.4s, v29.4s \n" - "cmp %w[mid_cnt], #1 \n" - - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - "blt 2f \n" - - // mid loop - "1: \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v27.16b, v26.16b \n" - "mov v28.16b, v31.16b \n" - "mov v29.16b, v31.16b \n" - - // out_r0 r0-r3 - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr0]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v27.4s, v16.4s, v0.s[2] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - - "fmla v28.4s, v17.4s, v3.s[1] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] 
\n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - - "fmla v26.4s, v24.4s, v4.s[0] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - "fadd v28.4s, v26.4s, v28.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]], %[s_16] \n" - "mov v26.16b, v31.16b \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v29.4s, v20.4s, v4.s[3] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v27.4s, v21.4s, v5.s[2] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" - - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - - "fmla v29.4s, v22.4s, v5.s[3] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" - - "fmla v27.4s, v25.4s, v30.s[0] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - - "prfm pldl1keep, [%[din_ptr3]] \n" - - "fadd v27.4s, v27.4s, v29.4s \n" - - "st1 {v26.4s}, [%[dout_ptr0]], #16 \n" - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" - "subs %w[mid_cnt], %w[mid_cnt], #1 \n" - "bne 1b \n" - - "2: \n" - "ld2 {v26.4s, v27.4s}, [%[mask]], %[s_8] \n" - "ld2 {v28.4s, v29.4s}, [%[mask]], %[s_8] \n" - "bif v8.16b, v31.16b, v26.16b \n" - "bif v9.16b, v31.16b, v27.16b \n" - "bif v6.16b, v31.16b, v28.16b \n" - "bif v7.16b, v31.16b, v29.16b \n" - - "bif v13.16b, v31.16b, v26.16b \n" - "bif v14.16b, v31.16b, v27.16b \n" - "bif v11.16b, v31.16b, v28.16b \n" - "bif v12.16b, v31.16b, v29.16b \n" - - "bif v18.16b, v31.16b, v26.16b \n" - "bif v19.16b, v31.16b, v27.16b \n" - "bif v16.16b, v31.16b, v28.16b \n" - "bif v17.16b, v31.16b, v29.16b \n" - - "bif v23.16b, v31.16b, v26.16b \n" - "bif v24.16b, v31.16b, v27.16b \n" - "bif v21.16b, v31.16b, v28.16b \n" - "bif v22.16b, v31.16b, v29.16b \n" - - "ld2 {v28.4s, v29.4s}, [%[mask]] \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v29.16b, v31.16b \n" - - "bif v10.16b, v31.16b, v28.16b \n" - "bif v15.16b, v31.16b, v28.16b \n" - - "mov v27.16b, v26.16b \n" - - "bif v20.16b, v31.16b, v28.16b \n" - "bif v25.16b, v31.16b, v28.16b \n" - "mov v28.16b, v31.16b \n" - - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - 
"fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "sub %[mask], %[mask], #16 \n" - "ld2 {v6.4s, v7.4s}, [%[mask]], %[s_8] \n" - "ld2 {v8.4s, v9.4s}, [%[mask]], %[s_8] \n" - "ld2 {v10.4s, v11.4s}, [%[mask]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v28.4s, v17.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "bif v13.16b, v31.16b, v6.16b \n" - "bif v14.16b, v31.16b, v7.16b \n" - "bif v11.16b, v31.16b, v8.16b \n" - "bif v12.16b, v31.16b, v9.16b \n" - "bif v15.16b, v31.16b, v10.16b \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]] \n" - - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - "fmla v26.4s, v24.4s, v4.s[0] \n" - - "bif v18.16b, v31.16b, v6.16b \n" - "bif v19.16b, v31.16b, v7.16b \n" - "bif v16.16b, v31.16b, v8.16b \n" - "bif v17.16b, v31.16b, v9.16b \n" - "bif v20.16b, v31.16b, v10.16b \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - "fadd v28.4s, v28.4s, v26.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]] \n" - "mov v26.16b, v31.16b \n" - - "bif v23.16b, v31.16b, v6.16b \n" - "bif v24.16b, v31.16b, v7.16b \n" - "bif v21.16b, v31.16b, v8.16b \n" - "bif v22.16b, v31.16b, v9.16b \n" - "bif v25.16b, v31.16b, v10.16b \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v26.4s, v15.4s, v30.s[0] \n" - - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - - "st1 {v26.4s}, [%[out_buf0]] \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - "fmla v29.4s, v20.4s, v4.s[3] \n" - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "fmla v27.4s, v21.4s, v5.s[2] \n" - "fmla v29.4s, v22.4s, v5.s[3] \n" - "fmla v27.4s, v25.4s, v30.s[0] \n" - "fadd v27.4s, v27.4s, v29.4s \n" - - "st1 {v27.4s}, [%[out_buf1]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [dout_ptr1] "+r"(dout_ptr1), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [din_ptr6] "+r"(din_ptr6), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - 
"v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - dout_ptr1[i] = out_buf1[i]; - } - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din6 + w_in; - din4 = din3 + w_in; - din5 = din4 + w_in; - din6 = din5 + w_in; - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t vbias = vdupq_n_f32(0.f); + if (flag_bias) { + vbias = vld1q_f32(&bias[c]); // v28 } - } - } -} - -//! larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2_relu(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; - -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - const float* din5 = din4 + w_in; - const float* din6 = din5 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 2) { - //! 
(h * 2 - 2) + 6 > h_in - 1 - if (h * 2 + 5 > h_in) { - switch (h * 2 + 5 - h_in) { - case 6: - din1 = zero_ptr; - case 5: - din2 = zero_ptr; - case 4: - din3 = zero_ptr; + weight_c += 20; +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc0 = dout_c00 + h * ow; + float* outc1 = outc0 + size_out_channel; + float* outc2 = outc1 + size_out_channel; + float* outc3 = outc2 + size_out_channel; + const float* inr0 = pre_din + h * 2 * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + const float* inr4 = inr3 + row_len; + + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { case 3: - din4 = zero_ptr; + outc1 = ptr_write; case 2: - din5 = zero_ptr; + outc2 = ptr_write; case 1: - din6 = zero_ptr; + outc3 = ptr_write; default: break; } } - if (h + 2 > h_out) { - switch (h + 2 - h_out) { - case 1: - dout1 = write_ptr; - default: - break; + auto c0 = outc0; + auto c1 = outc1; + auto c2 = outc2; + auto c3 = outc3; + float pre_out[16]; + for (int w = 0; w < w_loop; ++w) { + bool flag_mask = (w == w_loop - 1) && flag_remain; + if (flag_mask) { + c0 = outc0; + c1 = outc1; + c2 = outc2; + c3 = outc3; + outc0 = pre_out; + outc1 = pre_out + 4; + outc2 = pre_out + 8; + outc3 = pre_out + 12; } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - const float* din_ptr5 = din5; - const float* din_ptr6 = din6; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - float* dout_ptr1 = dout1; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - //! in r0, r1/r4, r2/r5, r3/r6: x 0 2 4 -- v8 v13 v18 v23 - //! in r0, r1/r4, r2/r5, r3/r6: x 1 3 5 -- v9 v14 v19 v24 - //! in r0, r1/r4, r2/r5, r3/r6: 0 2 4 6 -- v6 v11 v16 v21 - //! in r0, r1/r4, r2/r5, r3/r6: 1 3 5 7 -- v7 v12 v17 v22 - //! in r0, r1/r4, r2/r5, r3/r6: 2 4 6 8 -- v10 v15 v20 v25 - //! 
out r0, r1 -- v26, v27 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - // left - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" // r0 v6: 0 - // 2 4 6, - // v7: 1 3 - // 5 7 - "ext v8.16b, v31.16b, v6.16b, #12 \n" // r0 v8: x - // 0 2 4 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" // r1 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - "ext v9.16b, v31.16b, v7.16b, #12 \n" // r0 v9: x - // 1 3 5 - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load - // weights - // 0-7 - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ld1 {v10.s}[3], [%[din_ptr0]] \n" // r0 v10: - // 2 4 6 8 - "sub %[din_ptr0], %[din_ptr0], #8 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r1 v13: - // x 0 2 4 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" // r2 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r1 v14: - // x 1 3 5 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load - // weights - // 8-15 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" // r1 v15: - // 2 4 6 - "sub %[din_ptr1], %[din_ptr1], #8 \n" - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r2 v18: - // x 0 2 4 - "ld1 {v4.4s, v5.4s}, [%[weights]], #32 \n" // load - // weights - // 16-23 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r2 v19: - // x 1 3 5 - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" // r3 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" // r2 v20: - // 2 4 6 8 - "sub %[din_ptr2], %[din_ptr2], #8 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r3 v23: - // x 0 2 4 - "ld1 {v30.4s}, [%[weights]] \n" // load - // weights - // 24 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r3 v24: - // x 1 3 5 - "ld1 {v26.4s}, [%[vbias]] \n" // load - // bias to - // out_r0 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" // r2 v25: - // 2 4 6 8 - "sub %[din_ptr3], %[din_ptr3], #8 \n" - "mov v27.16b, v26.16b \n" // load - // bias to - // out_r1 - "mov v28.16b, v31.16b \n" // load - // zero to - // out_r0 - "mov v29.16b, v31.16b \n" // load - // zero to - // out_r1 - - "fmla v26.4s, v8.4s, v0.s[0] \n" // out r0: - // w0 - "fmla v28.4s, v9.4s, v0.s[1] \n" // out r0: - // w1 - "fmla v26.4s, v6.4s, v0.s[2] \n" // out r0: - // w2 - "fmla v28.4s, v7.4s, v0.s[3] \n" // out r0: - // w3 - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v8: 0 2 - // 4 6, v9: - // 1 3 5 7 - - "fmla v26.4s, v10.4s, v1.s[0] \n" // out r0: - // w4 - "fmla v28.4s, v13.4s, v1.s[1] \n" // out r0: - // w5 - "fmla v26.4s, v14.4s, v1.s[2] \n" // out r0: - // w6 - "fmla v28.4s, v11.4s, v1.s[3] \n" // out r0: - // w7 - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" // next r0 - // v6: 2 4 - // 6 8, v7: - // 3 5 7 9 - - "fmla v26.4s, v12.4s, v2.s[0] \n" // out r0: - // w8 - "fmla v28.4s, v15.4s, v2.s[1] \n" // out r0: - // w9 - "fmla v26.4s, v18.4s, v2.s[2] \n" // out r0: - // w10 - "fmla v28.4s, v19.4s, v2.s[3] \n" // out r0: - // w11 - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" // next r0 - // v10: 4 6 - // 8 10, - // v11: - // trash - // register - - "fmla v26.4s, v16.4s, v3.s[0] \n" // out r0: - // w12 - "fmla v28.4s, v17.4s, v3.s[1] \n" // out r0: - // w13 - "fmla v26.4s, v20.4s, v3.s[2] \n" // 
out r0: - // w14 - "fmla v28.4s, v23.4s, v3.s[3] \n" // out r0: - // w15 - "prfm pldl1keep, [%[din_ptr0]] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], #32 \n" // r4 v11: - // 0 2 4 6, - // v12: 1 3 - // 5 7 - - "fmla v26.4s, v24.4s, v4.s[0] \n" // out r0: - // w16 - "fmla v28.4s, v21.4s, v4.s[1] \n" // out r0: - // w17 - - "ext v13.16b, v31.16b, v11.16b, #12 \n" // r4 v13: - // x 0 2 4 - "ext v14.16b, v31.16b, v12.16b, #12 \n" // r4 v14: - // x 1 3 5 - "ext v15.16b, v11.16b, v31.16b, #4 \n" - - "fmla v26.4s, v22.4s, v4.s[2] \n" // out r0: - // w18 - "fmla v28.4s, v25.4s, v4.s[3] \n" // out r0: - // w19 - - "ld1 {v15.s}[3], [%[din_ptr4]] \n" // r4 v15: - // 2 4 6 - - "fmla v27.4s, v18.4s, v0.s[0] \n" // out r1: - // w0 - "fmla v29.4s, v19.4s, v0.s[1] \n" // out r1: - // w1 - - "sub %[din_ptr4], %[din_ptr4], #8 \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" // out r1: - // w2 - "fmla v29.4s, v17.4s, v0.s[3] \n" // out r1: - // w3 - "fmla v27.4s, v20.4s, v1.s[0] \n" // out r1: - // w4 - "fmla v29.4s, v23.4s, v1.s[1] \n" // out r1: - // w5 - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], #32 \n" // r5 v16: - // 0 2 4 6, - // v17: 1 3 - // 5 7 - - "fmla v27.4s, v24.4s, v1.s[2] \n" // out r1: - // w6 - "fmla v29.4s, v21.4s, v1.s[3] \n" // out r1: - // w7 - - "ext v18.16b, v31.16b, v16.16b, #12 \n" // r5 v18: - // x 0 2 4 - "ext v19.16b, v31.16b, v17.16b, #12 \n" // r5 v19: - // x 1 3 5 - "ext v20.16b, v16.16b, v31.16b, #4 \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" // out r1: - // w8 - "fmla v29.4s, v25.4s, v2.s[1] \n" // out r1: - // w9 - - "ld1 {v20.s}[3], [%[din_ptr5]] \n" // r5 v20: - // 2 4 6 - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], #32 \n" // r6 v21: - // 0 2 4 6, - // v22: 1 3 - // 5 7 - - "ext v23.16b, v31.16b, v21.16b, #12 \n" // r6 v23: - // x 0 2 4 - "ext v24.16b, v31.16b, v22.16b, #12 \n" // r6 v24: - // x 1 3 5 - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "sub %[din_ptr5], %[din_ptr5], #8 \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" // out r0: - // w22 - "fmla v28.4s, v12.4s, v5.s[3] \n" // out r0: - // w23 - - "ld1 {v25.s}[3], [%[din_ptr6]] \n" // r6 v25: - // 2 4 6 - - "fmla v26.4s, v13.4s, v5.s[0] \n" // out r0: - // w20 - "fmla v28.4s, v14.4s, v5.s[1] \n" // out r0: - // w21 - - "sub %[din_ptr6], %[din_ptr6], #8 \n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" // out r0: - // w24 - "fmla v27.4s, v13.4s, v2.s[2] \n" // out r1: - // w10 - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" // out r1: - // w11 - "fmax v26.4s, v26.4s, v31.4s \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v13: 0 2 - // 4 6, - // v14: 1 3 - // 5 7 - "fmla v27.4s, v11.4s, v3.s[0] \n" // out r1: - // w12 - "fmla v29.4s, v12.4s, v3.s[1] \n" // out r1: - // w13 - - "st1 {v26.4s}, [%[dout_ptr0]], %[s_16] \n" // store - // output - // r0 - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] \n" // next r1 - // v11: 2 4 - // 6 8, - // v12: 3 5 - // 7 9 - - "fmla v27.4s, v15.4s, v3.s[2] \n" // out r1: - // w14 - "fmla v29.4s, v16.4s, v4.s[1] \n" // out r1: - // w17 - "fmla v27.4s, v18.4s, v3.s[3] \n" // out r1: - // w15 - "fmla v29.4s, v19.4s, v4.s[0] \n" // out r1: - // w16 - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" // next r1 - // v15: 4 6 - // 8 10, - // v16: - // trash - // register - - "fmla v27.4s, v17.4s, v4.s[2] \n" // out r1: - // w18 - "fmla v29.4s, v20.4s, v4.s[3] \n" // out r1: - // w19 - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v18: 0 2 - // 4 6, - // v19: 1 3 - // 5 7 - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" // next r2 - // v16: 2 4 - 
// 6 8, - // v11: 3 5 - // 7 9 - - "fmla v27.4s, v23.4s, v5.s[0] \n" // out r1: - // w20 - "fmla v29.4s, v21.4s, v5.s[2] \n" // out r1: - // w22 - "fmla v27.4s, v24.4s, v5.s[1] \n" // out r1: - // w21 - "fmla v29.4s, v22.4s, v5.s[3] \n" // out r1: - // w23 - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" // next r2 - // v20: 4 6 - // 8 10, - // v21: - // trash - // register - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v23: 0 2 - // 4 6, - // v24: 1 3 - // 5 7 - - "fmla v27.4s, v25.4s, v30.s[0] \n" // out r1: - // w24 - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" // next r3 - // v21: 2 4 - // 6 8, - // v22: 3 5 - // 7 9 - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" // next r3 - // v25: 4 6 - // 8 10, - // v26: - // trash - // register - - "fadd v27.4s, v27.4s, v29.4s \n" - "fmax v27.4s, v27.4s, v31.4s \n" - "cmp %w[mid_cnt], #1 \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - "blt 2f \n" - - // mid loop - "1: \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v27.16b, v26.16b \n" - "mov v28.16b, v31.16b \n" - "mov v29.16b, v31.16b \n" - - // out_r0 r0-r3 - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "ld2 {v8.4s, v9.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], %[s_8] \n" - - "fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - - "ld2 {v10.4s, v11.4s}, [%[din_ptr0]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr0]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v27.4s, v16.4s, v0.s[2] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - - "fmla v28.4s, v17.4s, v3.s[1] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr5]] \n" - - "fmla v26.4s, v24.4s, v4.s[0] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - "fadd v28.4s, v26.4s, v28.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]], %[s_16] \n" - "mov v26.16b, v31.16b \n" - "prfm pldl1keep, [%[din_ptr6]] \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr1]], %[s_8] \n" - - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], %[s_8] 
\n" - - "fmla v26.4s, v15.4s, v30.s[0] \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr1]], %[s_16] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v29.4s, v20.4s, v4.s[3] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], %[s_8] \n" - - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v27.4s, v21.4s, v5.s[2] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr2]], %[s_16] \n" - - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr3]], %[s_8] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - - "fmla v29.4s, v22.4s, v5.s[3] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], %[s_8] \n" - - "fmla v27.4s, v25.4s, v30.s[0] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - "fadd v27.4s, v27.4s, v29.4s \n" - "fmax v26.4s, v26.4s, v31.4s \n" - "fmax v27.4s, v27.4s, v31.4s \n" - - "prfm pldl1keep, [%[din_ptr3]] \n" - "st1 {v26.4s}, [%[dout_ptr0]], #16 \n" - "st1 {v27.4s}, [%[dout_ptr1]], #16 \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr3]], %[s_16] \n" - "subs %w[mid_cnt], %w[mid_cnt], #1 \n" - "bne 1b \n" - - "2: \n" - "ld2 {v26.4s, v27.4s}, [%[mask]], %[s_8] \n" - "ld2 {v28.4s, v29.4s}, [%[mask]], %[s_8] \n" - "bif v8.16b, v31.16b, v26.16b \n" - "bif v9.16b, v31.16b, v27.16b \n" - "bif v6.16b, v31.16b, v28.16b \n" - "bif v7.16b, v31.16b, v29.16b \n" - - "bif v13.16b, v31.16b, v26.16b \n" - "bif v14.16b, v31.16b, v27.16b \n" - "bif v11.16b, v31.16b, v28.16b \n" - "bif v12.16b, v31.16b, v29.16b \n" - - "bif v18.16b, v31.16b, v26.16b \n" - "bif v19.16b, v31.16b, v27.16b \n" - "bif v16.16b, v31.16b, v28.16b \n" - "bif v17.16b, v31.16b, v29.16b \n" - - "bif v23.16b, v31.16b, v26.16b \n" - "bif v24.16b, v31.16b, v27.16b \n" - "bif v21.16b, v31.16b, v28.16b \n" - "bif v22.16b, v31.16b, v29.16b \n" - - "ld2 {v28.4s, v29.4s}, [%[mask]] \n" - "ld1 {v26.4s}, [%[vbias]] \n" - "mov v29.16b, v31.16b \n" - - "bif v10.16b, v31.16b, v28.16b \n" - "bif v15.16b, v31.16b, v28.16b \n" - - "mov v27.16b, v26.16b \n" - - "bif v20.16b, v31.16b, v28.16b \n" - "bif v25.16b, v31.16b, v28.16b \n" - "mov v28.16b, v31.16b \n" - - "fmla v26.4s, v8.4s, v0.s[0] \n" - "fmla v28.4s, v9.4s, v0.s[1] \n" - "fmla v26.4s, v6.4s, v0.s[2] \n" - "fmla v28.4s, v7.4s, v0.s[3] \n" - - "fmla v26.4s, v10.4s, v1.s[0] \n" - "fmla v28.4s, v13.4s, v1.s[1] \n" - "fmla v26.4s, v14.4s, v1.s[2] \n" - "fmla v28.4s, v11.4s, v1.s[3] \n" - - "sub %[mask], %[mask], #16 \n" - "ld2 {v6.4s, v7.4s}, [%[mask]], %[s_8] \n" - "ld2 {v8.4s, v9.4s}, [%[mask]], %[s_8] \n" - "ld2 {v10.4s, v11.4s}, [%[mask]] \n" - - "fmla v26.4s, v12.4s, v2.s[0] \n" - "fmla v28.4s, v15.4s, v2.s[1] \n" - - "ld2 {v13.4s, v14.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v26.4s, v16.4s, v3.s[0] \n" - "fmla v28.4s, v17.4s, v3.s[1] \n" - - "ld2 {v11.4s, v12.4s}, [%[din_ptr4]], %[s_8] \n" - - "fmla v27.4s, v16.4s, v0.s[2] \n" - "fmla v29.4s, v17.4s, v0.s[3] \n" - - "ld2 {v15.4s, v16.4s}, [%[din_ptr4]] \n" - - "fmla v26.4s, v18.4s, v2.s[2] \n" - "fmla v28.4s, v19.4s, v2.s[3] \n" - "fmla v27.4s, v18.4s, v0.s[0] \n" - "fmla v29.4s, v19.4s, v0.s[1] \n" - - "bif v13.16b, v31.16b, v6.16b \n" - "bif v14.16b, v31.16b, v7.16b \n" - "bif v11.16b, v31.16b, v8.16b \n" - "bif v12.16b, v31.16b, v9.16b \n" - "bif v15.16b, v31.16b, v10.16b \n" - - "ld2 {v18.4s, v19.4s}, [%[din_ptr5]], %[s_8] \n" - - "fmla v26.4s, v20.4s, v3.s[2] \n" - "fmla v27.4s, v20.4s, v1.s[0] \n" - - "ld2 {v16.4s, v17.4s}, [%[din_ptr5]], 
%[s_8] \n" - - "fmla v29.4s, v21.4s, v1.s[3] \n" - "fmla v28.4s, v21.4s, v4.s[1] \n" - - "ld2 {v20.4s, v21.4s}, [%[din_ptr5]] \n" - - "fmla v28.4s, v23.4s, v3.s[3] \n" - "fmla v29.4s, v23.4s, v1.s[1] \n" - "fmla v27.4s, v24.4s, v1.s[2] \n" - "fmla v26.4s, v24.4s, v4.s[0] \n" - - "bif v18.16b, v31.16b, v6.16b \n" - "bif v19.16b, v31.16b, v7.16b \n" - "bif v16.16b, v31.16b, v8.16b \n" - "bif v17.16b, v31.16b, v9.16b \n" - "bif v20.16b, v31.16b, v10.16b \n" - - "ld2 {v23.4s, v24.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v27.4s, v22.4s, v2.s[0] \n" - "fmla v26.4s, v22.4s, v4.s[2] \n" - - "ld2 {v21.4s, v22.4s}, [%[din_ptr6]], %[s_8] \n" - - "fmla v28.4s, v25.4s, v4.s[3] \n" - "fmla v29.4s, v25.4s, v2.s[1] \n" - "fadd v28.4s, v28.4s, v26.4s \n" - - "ld2 {v25.4s, v26.4s}, [%[din_ptr6]] \n" - "mov v26.16b, v31.16b \n" - - "bif v23.16b, v31.16b, v6.16b \n" - "bif v24.16b, v31.16b, v7.16b \n" - "bif v21.16b, v31.16b, v8.16b \n" - "bif v22.16b, v31.16b, v9.16b \n" - "bif v25.16b, v31.16b, v10.16b \n" - - "fmla v26.4s, v13.4s, v5.s[0] \n" - "fmla v28.4s, v14.4s, v5.s[1] \n" - "fmla v26.4s, v11.4s, v5.s[2] \n" - "fmla v28.4s, v12.4s, v5.s[3] \n" - "fmla v26.4s, v15.4s, v30.s[0] \n" - - "fmla v27.4s, v13.4s, v2.s[2] \n" - "fmla v29.4s, v14.4s, v2.s[3] \n" - "fmla v27.4s, v11.4s, v3.s[0] \n" - "fmla v29.4s, v12.4s, v3.s[1] \n" - - "fadd v26.4s, v26.4s, v28.4s \n" - "fmla v27.4s, v15.4s, v3.s[2] \n" - "fmla v29.4s, v18.4s, v3.s[3] \n" - "fmla v27.4s, v19.4s, v4.s[0] \n" - "fmla v29.4s, v16.4s, v4.s[1] \n" - - "fmax v26.4s, v26.4s, v31.4s \n" - "fmla v27.4s, v17.4s, v4.s[2] \n" - "fmla v29.4s, v20.4s, v4.s[3] \n" - "fmla v27.4s, v23.4s, v5.s[0] \n" - "fmla v29.4s, v24.4s, v5.s[1] \n" - - "st1 {v26.4s}, [%[out_buf0]] \n" - "fmla v27.4s, v21.4s, v5.s[2] \n" - "fmla v29.4s, v22.4s, v5.s[3] \n" - "fmla v27.4s, v25.4s, v30.s[0] \n" - "fadd v27.4s, v27.4s, v29.4s \n" - - "fmax v27.4s, v27.4s, v31.4s \n" - "st1 {v27.4s}, [%[out_buf1]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [dout_ptr1] "+r"(dout_ptr1), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [din_ptr6] "+r"(din_ptr6), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - dout_ptr1[i] = out_buf1[i]; - } - din0 = din4; - din1 = din5; - din2 = din6; - din3 = din6 + w_in; - din4 = din3 + w_in; - din5 = din4 + w_in; - din6 = din5 + w_in; - dout0 = dout1 + w_out; - dout1 = dout0 + w_out; - } - } - } -} - -//! 
small depthwise, win < 9; -void conv_depthwise_5x5s2p2_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - //! in r0/r4, r1, r2, r3: x 0 2 4 -- v8 v13 v18 v23 v28 - //! in r0/r4, r1, r2, r3: x 1 3 5 -- v9 v14 v19 v24 v29 - //! in r0/r4, r1, r2, r3: 0 2 4 6 -- v6 v11 v16 v21 v26 - //! in r0/r4, r1, r2, r3: 1 3 5 7 -- v7 v12 v17 v22 v27 - //! in r0/r4, r1, r2, r3: 2 4 6 8 -- v10 v15 v20 v25 v30 - //! out r0 -- v4 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - - //! load mask - "ld2 {v0.4s, v1.4s}, [%[mask]], %[s_8] \n" - "ld2 {v2.4s, v3.4s}, [%[mask]], %[s_8] \n" - "ld2 {v4.4s, v5.4s}, [%[mask]] \n" - - //! 
load and extract input - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" - "ld2 {v26.4s, v27.4s}, [%[din_ptr4]], #32 \n" - - "ext v8.16b, v31.16b, v6.16b, #12 \n" - "ext v9.16b, v31.16b, v7.16b, #12 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" - "ext v14.16b, v31.16b, v12.16b, #12 \n" - - "ext v18.16b, v31.16b, v16.16b, #12 \n" - "ext v19.16b, v31.16b, v17.16b, #12 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" - "ext v24.16b, v31.16b, v22.16b, #12 \n" - "ext v28.16b, v31.16b, v26.16b, #12 \n" - "ext v29.16b, v31.16b, v27.16b, #12 \n" - - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ext v30.16b, v26.16b, v31.16b, #4 \n" - - "bif v8.16b, v31.16b, v0.16b \n" - "bif v9.16b, v31.16b, v1.16b \n" - "bif v6.16b, v31.16b, v2.16b \n" - "bif v7.16b, v31.16b, v3.16b \n" - - "bif v13.16b, v31.16b, v0.16b \n" - "bif v14.16b, v31.16b, v1.16b \n" - "bif v11.16b, v31.16b, v2.16b \n" - "bif v12.16b, v31.16b, v3.16b \n" - - "bif v18.16b, v31.16b, v0.16b \n" - "bif v19.16b, v31.16b, v1.16b \n" - "bif v16.16b, v31.16b, v2.16b \n" - "bif v17.16b, v31.16b, v3.16b \n" - - "ld1 {v10.s}[3], [%[din_ptr0]] \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" - "ld1 {v30.s}[3], [%[din_ptr4]] \n" - - "bif v23.16b, v31.16b, v0.16b \n" - "bif v24.16b, v31.16b, v1.16b \n" - "bif v21.16b, v31.16b, v2.16b \n" - "bif v22.16b, v31.16b, v3.16b \n" - - "bif v28.16b, v31.16b, v0.16b \n" - "bif v29.16b, v31.16b, v1.16b \n" - "bif v26.16b, v31.16b, v2.16b \n" - "bif v27.16b, v31.16b, v3.16b \n" - - "bif v10.16b, v31.16b, v4.16b \n" - "bif v15.16b, v31.16b, v4.16b \n" - "bif v20.16b, v31.16b, v4.16b \n" - "bif v25.16b, v31.16b, v4.16b \n" - "bif v30.16b, v31.16b, v4.16b \n" - - "ld1 {v4.4s}, [%[vbias]] \n" - "mov v5.16b, v31.16b \n" - - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load weights 0-7 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load weights 8-15 - - //! 
compute - "fmla v4.4s, v8.4s, v0.s[0] \n" // out r0: w0 - "fmla v5.4s, v9.4s, v0.s[1] \n" // out r0: w1 - "fmla v4.4s, v6.4s, v0.s[2] \n" // out r0: w2 - "fmla v5.4s, v7.4s, v0.s[3] \n" // out r0: w3 - - "fmla v4.4s, v10.4s, v1.s[0] \n" // out r0: w4 - "fmla v5.4s, v13.4s, v1.s[1] \n" // out r0: w5 - "fmla v4.4s, v14.4s, v1.s[2] \n" // out r0: w6 - "fmla v5.4s, v11.4s, v1.s[3] \n" // out r0: w7 - - "ld1 {v6.4s, v7.4s}, [%[weights]], #32 \n" // load weights 16-23 - "ld1 {v8.s}[0], [%[weights]] \n" // load weights 24 - - "fmla v4.4s, v12.4s, v2.s[0] \n" // out r0: w8 - "fmla v5.4s, v15.4s, v2.s[1] \n" // out r0: w9 - "fmla v4.4s, v18.4s, v2.s[2] \n" // out r0: w10 - "fmla v5.4s, v19.4s, v2.s[3] \n" // out r0: w11 - - "fmla v4.4s, v16.4s, v3.s[0] \n" // out r0: w12 - "fmla v5.4s, v17.4s, v3.s[1] \n" // out r0: w13 - "fmla v4.4s, v20.4s, v3.s[2] \n" // out r0: w14 - "fmla v5.4s, v23.4s, v3.s[3] \n" // out r0: w15 - - "fmla v4.4s, v24.4s, v6.s[0] \n" // out r0: w16 - "fmla v5.4s, v21.4s, v6.s[1] \n" // out r0: w17 - "fmla v4.4s, v22.4s, v6.s[2] \n" // out r0: w18 - "fmla v5.4s, v25.4s, v6.s[3] \n" // out r0: w19 - - "fmla v4.4s, v28.4s, v7.s[0] \n" // out r0: w20 - "fmla v5.4s, v29.4s, v7.s[1] \n" // out r0: w21 - "fmla v4.4s, v26.4s, v7.s[2] \n" // out r0: w22 - "fmla v5.4s, v27.4s, v7.s[3] \n" // out r0: w23 - "fmla v4.4s, v30.4s, v8.s[0] \n" // out r0: w24 - - "fadd v4.4s, v4.4s, v5.4s \n" // add out to v4 - "st1 {v4.4s}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! 
small depthwise, win < 9; -void conv_depthwise_5x5s2p2_relu_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - //! in r0/r4, r1, r2, r3: x 0 2 4 -- v8 v13 v18 v23 v28 - //! in r0/r4, r1, r2, r3: x 1 3 5 -- v9 v14 v19 v24 v29 - //! in r0/r4, r1, r2, r3: 0 2 4 6 -- v6 v11 v16 v21 v26 - //! in r0/r4, r1, r2, r3: 1 3 5 7 -- v7 v12 v17 v22 v27 - //! in r0/r4, r1, r2, r3: 2 4 6 8 -- v10 v15 v20 v25 v30 - //! out r0 -- v4 - asm volatile( - "movi v31.4s, #0x0\n" - "prfm pldl1keep, [%[din_ptr0]] \n" - "prfm pldl1keep, [%[din_ptr1]] \n" - "prfm pldl1keep, [%[din_ptr2]] \n" - "prfm pldl1keep, [%[din_ptr3]] \n" - "prfm pldl1keep, [%[din_ptr4]] \n" - "prfm pldl1keep, [%[weights]] \n" - "prfm pldl1keep, [%[mask]] \n" - - //! load mask - "ld2 {v0.4s, v1.4s}, [%[mask]], %[s_8] \n" - "ld2 {v2.4s, v3.4s}, [%[mask]], %[s_8] \n" - "ld2 {v4.4s, v5.4s}, [%[mask]] \n" - - //! 
load and extract input - "ld2 {v6.4s, v7.4s}, [%[din_ptr0]], #32 \n" - "ld2 {v11.4s, v12.4s}, [%[din_ptr1]], #32 \n" - "ld2 {v16.4s, v17.4s}, [%[din_ptr2]], #32 \n" - "ld2 {v21.4s, v22.4s}, [%[din_ptr3]], #32 \n" - "ld2 {v26.4s, v27.4s}, [%[din_ptr4]], #32 \n" - - "ext v8.16b, v31.16b, v6.16b, #12 \n" - "ext v9.16b, v31.16b, v7.16b, #12 \n" - "ext v13.16b, v31.16b, v11.16b, #12 \n" - "ext v14.16b, v31.16b, v12.16b, #12 \n" - - "ext v18.16b, v31.16b, v16.16b, #12 \n" - "ext v19.16b, v31.16b, v17.16b, #12 \n" - "ext v23.16b, v31.16b, v21.16b, #12 \n" - "ext v24.16b, v31.16b, v22.16b, #12 \n" - "ext v28.16b, v31.16b, v26.16b, #12 \n" - "ext v29.16b, v31.16b, v27.16b, #12 \n" - - "ext v10.16b, v6.16b, v31.16b, #4 \n" - "ext v15.16b, v11.16b, v31.16b, #4 \n" - "ext v20.16b, v16.16b, v31.16b, #4 \n" - "ext v25.16b, v21.16b, v31.16b, #4 \n" - "ext v30.16b, v26.16b, v31.16b, #4 \n" - - "bif v8.16b, v31.16b, v0.16b \n" - "bif v9.16b, v31.16b, v1.16b \n" - "bif v6.16b, v31.16b, v2.16b \n" - "bif v7.16b, v31.16b, v3.16b \n" - - "bif v13.16b, v31.16b, v0.16b \n" - "bif v14.16b, v31.16b, v1.16b \n" - "bif v11.16b, v31.16b, v2.16b \n" - "bif v12.16b, v31.16b, v3.16b \n" - - "bif v18.16b, v31.16b, v0.16b \n" - "bif v19.16b, v31.16b, v1.16b \n" - "bif v16.16b, v31.16b, v2.16b \n" - "bif v17.16b, v31.16b, v3.16b \n" - - "ld1 {v10.s}[3], [%[din_ptr0]] \n" - "ld1 {v15.s}[3], [%[din_ptr1]] \n" - "ld1 {v20.s}[3], [%[din_ptr2]] \n" - "ld1 {v25.s}[3], [%[din_ptr3]] \n" - "ld1 {v30.s}[3], [%[din_ptr4]] \n" - - "bif v23.16b, v31.16b, v0.16b \n" - "bif v24.16b, v31.16b, v1.16b \n" - "bif v21.16b, v31.16b, v2.16b \n" - "bif v22.16b, v31.16b, v3.16b \n" - - "bif v28.16b, v31.16b, v0.16b \n" - "bif v29.16b, v31.16b, v1.16b \n" - "bif v26.16b, v31.16b, v2.16b \n" - "bif v27.16b, v31.16b, v3.16b \n" - - "bif v10.16b, v31.16b, v4.16b \n" - "bif v15.16b, v31.16b, v4.16b \n" - "bif v20.16b, v31.16b, v4.16b \n" - "bif v25.16b, v31.16b, v4.16b \n" - "bif v30.16b, v31.16b, v4.16b \n" - - "ld1 {v4.4s}, [%[vbias]] \n" - "mov v5.16b, v31.16b \n" - - "ld1 {v0.4s, v1.4s}, [%[weights]], #32 \n" // load weights 0-7 - "ld1 {v2.4s, v3.4s}, [%[weights]], #32 \n" // load weights 8-15 - - //! 
compute - "fmla v4.4s, v8.4s, v0.s[0] \n" // out r0: w0 - "fmla v5.4s, v9.4s, v0.s[1] \n" // out r0: w1 - "fmla v4.4s, v6.4s, v0.s[2] \n" // out r0: w2 - "fmla v5.4s, v7.4s, v0.s[3] \n" // out r0: w3 - - "fmla v4.4s, v10.4s, v1.s[0] \n" // out r0: w4 - "fmla v5.4s, v13.4s, v1.s[1] \n" // out r0: w5 - "fmla v4.4s, v14.4s, v1.s[2] \n" // out r0: w6 - "fmla v5.4s, v11.4s, v1.s[3] \n" // out r0: w7 - - "ld1 {v6.4s, v7.4s}, [%[weights]], #32 \n" // load weights 16-23 - "ld1 {v8.s}[0], [%[weights]] \n" // load weights 24 - - "fmla v4.4s, v12.4s, v2.s[0] \n" // out r0: w8 - "fmla v5.4s, v15.4s, v2.s[1] \n" // out r0: w9 - "fmla v4.4s, v18.4s, v2.s[2] \n" // out r0: w10 - "fmla v5.4s, v19.4s, v2.s[3] \n" // out r0: w11 - - "fmla v4.4s, v16.4s, v3.s[0] \n" // out r0: w12 - "fmla v5.4s, v17.4s, v3.s[1] \n" // out r0: w13 - "fmla v4.4s, v20.4s, v3.s[2] \n" // out r0: w14 - "fmla v5.4s, v23.4s, v3.s[3] \n" // out r0: w15 - - "fmla v4.4s, v24.4s, v6.s[0] \n" // out r0: w16 - "fmla v5.4s, v21.4s, v6.s[1] \n" // out r0: w17 - "fmla v4.4s, v22.4s, v6.s[2] \n" // out r0: w18 - "fmla v5.4s, v25.4s, v6.s[3] \n" // out r0: w19 - - "fmla v4.4s, v28.4s, v7.s[0] \n" // out r0: w20 - "fmla v5.4s, v29.4s, v7.s[1] \n" // out r0: w21 - "fmla v4.4s, v26.4s, v7.s[2] \n" // out r0: w22 - "fmla v5.4s, v27.4s, v7.s[3] \n" // out r0: w23 - "fmla v4.4s, v30.4s, v8.s[0] \n" // out r0: w24 - - "fadd v4.4s, v4.4s, v5.4s \n" // add out to v4 - "fmax v4.4s, v4.4s, v31.4s \n" - "st1 {v4.4s}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [out_buf1] "r"(out_buf1), - [s_8] "r"(s_8) - : "memory", - "cc", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25", - "v26", - "v27", - "v28", - "v29", - "v30", - "v31"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - +#ifdef __aarch64__ + act_switch_5x5s2(inr0, + inr1, + inr2, + inr3, + inr4, + outc0, + outc1, + outc2, + outc3, + w0, + w1, + w2, + w3, + w4, + vbias, + weight_c, + bias_local, + act_param); #else - -//! 
larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - // printf("invoke 5x5s2p2 armv7\n"); - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float* dout0 = dout_ch; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "pld [%[mask]] \n" - - // left - "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" - "vld1.32 {d26-d29}, [%[vbias]] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! \n" - "sub %[din_ptr0], #8 \n" - - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r1 - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! 
\n" - "sub %[din_ptr1], #8 \n" - - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r2 - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[din_ptr2], #8 \n" - - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r3 - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" - "sub %[din_ptr3], #8 \n" - - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "sub %[din_ptr4], #8 \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmov.32 q12, %q[w0] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - "vmla.f32 q13, q10, %e[w0][0] \n" - "vadd.f32 q13, q13, q14 \n" - "vmov.32 %q[w0], q12 \n" - "cmp %[mid_cnt], #1 \n" - "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! \n" - "pld [%[din_ptr0]] \n" - "blt 2f \n" - - // mid - "1: \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr1]], %[s_16] \n" - - // r1 - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "pld [%[din_ptr1]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr2]], %[s_16] \n" - - // r2 - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "pld [%[din_ptr2]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr3]], %[s_16] \n" - - // r3 - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "pld [%[din_ptr3]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr4]], %[s_16] \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - "pld [%[din_ptr4]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w0][0] \n" - - "vld2.32 
{d20-d23}, [%[din_ptr0]], %[s_16] \n" - - "vmov.32 %q[w0], q12 \n" - "vadd.f32 q13, q13, q14 \n" - "subs %[mid_cnt], #1 \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! \n" - "bne 1b \n" - - "2: \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d20-d23}, [%[din_ptr1]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d20-d23}, [%[din_ptr2]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d20-d23}, [%[din_ptr3]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d20-d23}, [%[din_ptr4]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [w0] "w"(w0), - 
[w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! larger depthwise, win >= 9; -void conv_depthwise_5x5s2p2_relu(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - // printf("invoke 5x5s2p2 armv7\n"); - CHECK_GE(w_in, 9) << "only support win >= 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int cnt = (w_out_round - 4) / 4; - int mid_cnt = cnt - 1; - int right_start = cnt * 2 * 4 - 2; - int mask_cnt = 12 - (w_in - right_start); - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float* dout0 = dout_ch; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - int loop = mid_cnt; - const int s_8 = 8; - const int s_16 = 16; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "pld [%[mask]] \n" - - // left - "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" - "vld1.32 {d26-d29}, [%[vbias]] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! 
\n" - "sub %[din_ptr0], #8 \n" - - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r1 - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" - "sub %[din_ptr1], #8 \n" - - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r2 - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[din_ptr2], #8 \n" - - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r3 - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" - "sub %[din_ptr3], #8 \n" - - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "sub %[din_ptr4], #8 \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmov.32 q12, %q[w0] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - "vmla.f32 q13, q10, %e[w0][0] \n" - "vadd.f32 q13, q13, q14 \n" - "vmov.f32 %q[w0], q12 \n" - "vmax.f32 q13, q13, q15 \n" - "cmp %[mid_cnt], #1 \n" - "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! 
\n" - "pld [%[din_ptr0]] \n" - "blt 2f \n" - - // mid - "1: \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w1][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr1]], %[s_16] \n" - - // r1 - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - "pld [%[din_ptr1]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w2][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr2]], %[s_16] \n" - - // r2 - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - "pld [%[din_ptr2]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w3][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr3]], %[s_16] \n" - - // r3 - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - "pld [%[din_ptr3]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - - "vmla.f32 q13, q10, %f[w4][1] \n" - - "vld2.32 {d20-d23}, [%[din_ptr4]], %[s_16] \n" - - // r4 - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - "pld [%[din_ptr4]] \n" - - "vld2.32 {d12-d15}, [%[din_ptr0]], %[s_8] \n" - "vld1.32 {%e[w0][0]}, [%[weights]] \n" - - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d16-d19}, [%[din_ptr0]], %[s_8] \n" - - "vmla.f32 q13, q10, %e[w0][0] \n" - - "vld2.32 {d20-d23}, [%[din_ptr0]], %[s_16] \n" - - "vmov.32 %q[w0], q12 \n" - "vadd.f32 q13, q13, q14 \n" - "vmax.f32 q13, q13, q15 \n" - "subs %[mid_cnt], #1 \n" - "vst1.32 {d26-d27}, [%[dout_ptr0]]! 
\n" - "bne 1b \n" - - "2: \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - - // r0 - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr1]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d20-d23}, [%[din_ptr1]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w1][1] \n" - "vmla.f32 q14, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w1][1] \n" - "vmla.f32 q14, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr2]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d20-d23}, [%[din_ptr2]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d12-d15}, [%[din_ptr3]], %[s_8] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr3]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d20-d23}, [%[din_ptr3]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w3][1] \n" - "vmla.f32 q14, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld2.32 {d12-d15}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w4][1] \n" - "vmla.f32 q14, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "sub %[mask], #16 \n" - "vld2.32 {d16-d19}, [%[din_ptr4]], %[s_8] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d20-d23}, [%[din_ptr4]] \n" - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vmax.f32 q13, q13, q15 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [mid_cnt] "+r"(loop), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5), - [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - 
[s_8] "r"(s_8), - [s_16] "r"(s_16) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - int remain_cnt = w_out - (mid_cnt + 1) * 4; - for (int i = 0; i < remain_cnt; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! small depthwise, win < 9; -void conv_depthwise_5x5s2p2_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; - } - } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - "vld2.32 {d16-d19}, [%[din_ptr0]]! \n" - - // r0 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! 
\n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %e[w1][1] \n" - "vmla.f32 q13, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %f[w1][1] \n" - "vmla.f32 q13, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %f[w3][1] \n" - "vmla.f32 q13, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %e[w4][1] \n" - "vmla.f32 q13, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [s_8] "r"(s_8), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; - } - } - } -} - -//! 
small depthwise, win < 9; -void conv_depthwise_5x5s2p2_relu_s(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - CHECK_LT(w_in, 9) << "only support win < 9\n"; - int w_out_round = (w_out + 3) / 4 * 4; - int mask_cnt = 12 - w_in - 2; - int mask[12]; - memset(mask, 0xff, 12 * sizeof(int)); - for (int i = 0; i < mask_cnt; ++i) { - mask[11 - i] = 0; - } - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - int in_spatial_size = w_in * h_in; - int out_spatial_size = w_out * h_out; - int weights_saptial_size = 25; - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * in_spatial_size * ch_in; - float* dout_batch = dout + n * out_spatial_size * ch_out; -#pragma omp parallel for - for (int c = 0; c < ch_in; ++c) { - const float* din_ch = din_batch + c * in_spatial_size; - float* dout_ch = dout_batch + c * out_spatial_size; - const float* din0 = zero_ptr; - const float* din1 = zero_ptr; - const float* din2 = din_ch; - const float* din3 = din2 + w_in; - const float* din4 = din3 + w_in; - - float out_buf0[4]; - float out_buf1[4]; - float* dout0 = dout_ch; - float* dout1 = dout0 + w_out; - - const float* weights_c = weights + c * weights_saptial_size; - float32x4_t w0 = vld1q_f32(weights_c); - float32x4_t w1 = vld1q_f32(weights_c + 4); - float32x4_t w2 = vld1q_f32(weights_c + 8); - float32x4_t w3 = vld1q_f32(weights_c + 12); - float32x4_t w4 = vld1q_f32(weights_c + 16); - float32x4_t w5 = vld1q_f32(weights_c + 20); - for (int h = 0; h < h_out; h += 1) { - //! (h * 2 - 2) + 4 > h_in - 1 - if (h * 2 + 3 > h_in) { - switch (h * 2 + 3 - h_in) { - case 4: - din1 = zero_ptr; - case 3: - din2 = zero_ptr; - case 2: - din3 = zero_ptr; - case 1: - din4 = zero_ptr; - default: - break; + act_switch_5x5s2(inr0, + inr1, + inr2, + inr3, + inr4, + outc0, + outc1, + outc2, + outc3, + vzero, + vzero, + vzero, + vzero, + vzero, + vzero, + weight_c, + bias_local, + act_param); +#endif + if (flag_mask) { + for (int i = 0; i < remain; ++i) { + c0[i] = pre_out[i]; + c1[i] = pre_out[i + 4]; + c2[i] = pre_out[i + 8]; + c3[i] = pre_out[i + 12]; + } } + inr0 += 32; + inr1 += 32; + inr2 += 32; + inr3 += 32; + inr4 += 32; + outc0 += 4; + outc1 += 4; + outc2 += 4; + outc3 += 4; } - const float* din_ptr0 = din0; - const float* din_ptr1 = din1; - const float* din_ptr2 = din2; - const float* din_ptr3 = din3; - const float* din_ptr4 = din4; - - const float* weights_ptr = weights_c + 24; - float* dout_ptr0 = dout0; - - float bias_c = 0.f; - if (flag_bias) { - bias_c = bias[c]; - } - float vbias[4] = {bias_c, bias_c, bias_c, bias_c}; - int* mask_ptr = mask; - const int s_8 = 8; - - asm volatile( - "vmov.i32 q15, #0x0 \n" - "pld [%[din_ptr0]] \n" - "pld [%[din_ptr1]] \n" - "pld [%[din_ptr2]] \n" - "pld [%[din_ptr3]] \n" - "pld [%[din_ptr4]] \n" - "vld1.32 {d26-d27}, [%[vbias]] \n" - "vmov.32 q14, q15 \n" - "vld2.32 {d16-d19}, [%[din_ptr0]]! 
\n" - - // r0 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr0]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w0][0] \n" - "vmla.f32 q14, q7, %e[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w0][0] \n" - "vmla.f32 q14, q9, %f[w0][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr1]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %e[w1][0] \n" - - // r1 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr1]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %e[w1][1] \n" - "vmla.f32 q13, q7, %f[w1][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %f[w1][1] \n" - "vmla.f32 q13, q9, %e[w2][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr2]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %e[w2][1] \n" - - // r2 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr2]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %f[w2][0] \n" - "vmla.f32 q14, q7, %f[w2][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %e[w3][0] \n" - "vmla.f32 q14, q9, %e[w3][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr3]]! \n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, %f[w3][0] \n" - - // r3 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr3]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q14, q6, %f[w3][1] \n" - "vmla.f32 q13, q7, %e[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q14, q8, %e[w4][1] \n" - "vmla.f32 q13, q9, %f[w4][0] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vld2.32 {d16-d19}, [%[din_ptr4]]! 
\n" - "sub %[mask], #16 \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q14, q10, %f[w4][1] \n" - - // r4 - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vext.32 q6, q15, q8, #3 \n" - "vext.32 q7, q15, q9, #3 \n" - "vext.32 q10, q8, q15, #1 \n" - "vld1.32 {d21[1]}, [%[din_ptr4]] \n" - - "vbif.32 q6, q15, q11 \n" - "vbif.32 q7, q15, q12 \n" - "vmla.f32 q13, q6, %e[w5][0] \n" - "vmla.f32 q14, q7, %e[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]], %[s_8] \n" - "vld1.32 {d12[0]}, [%[weights]] \n" - "vbif.32 q8, q15, q11 \n" - "vbif.32 q9, q15, q12 \n" - "vmla.f32 q13, q8, %f[w5][0] \n" - "vmla.f32 q14, q9, %f[w5][1] \n" - - "vld2.32 {d22-d25}, [%[mask]] \n" - "vbif.32 q10, q15, q11 \n" - "vmla.f32 q13, q10, d12[0] \n" - - "vadd.f32 q13, q13, q14 \n" - "vmax.f32 q13, q13, q15 \n" - "vst1.32 {d26-d27}, [%[out_buf0]] \n" - - : [dout_ptr0] "+r"(dout_ptr0), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [mask] "+r"(mask_ptr), - [weights] "+r"(weights_ptr) - : [vbias] "r"(vbias), - [out_buf0] "r"(out_buf0), - [s_8] "r"(s_8), - [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [w5] "w"(w5) - : "memory", - "cc", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - for (int i = 0; i < w_out; ++i) { - dout_ptr0[i] = out_buf0[i]; - } - din0 = din2; - din1 = din3; - din2 = din4; - din3 = din2 + w_in; - din4 = din3 + w_in; - dout0 += w_out; } } } } -#endif // __aarch64__ } // namespace math } // namespace arm diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index bb85e74774..b5dd1b58c4 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -150,11 +150,26 @@ void conv_depthwise_5x5s2_fp32(const float* din, int win, const float* weights, const float* bias, - int pad, - bool flag_bias, - bool flag_relu, + const operators::ConvParam& param, + const operators::ActivationParam act_param, ARMContext* ctx); +void conv_depthwise_5x5s2p2_fp32(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + bool flag_relu, + ARMContext* ctx); + template void conv_depthwise_5x5s1_int8(Dtype* dout, const int8_t* din, diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index c327c715c5..d4d24fdd90 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -589,10 +589,9 @@ void conv_depthwise_3x3_fp32(const void* din, int stride = param.strides[1]; int pad = pad_w; bool flag_bias = param.bias != nullptr; - bool pads_equal = - ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); + bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); if (stride == 1) { - if (pads_equal && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] + if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s1_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -624,9 +623,8 @@ void conv_depthwise_3x3_fp32(const void* din, act_param, ctx); } - } else if (stride == 2) { - if (pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] + if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] conv_depthwise_3x3s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -678,12 +676,13 @@ void conv_depthwise_5x5_fp32(const void* 
din, ARMContext* ctx, const float* scale) { auto paddings = *param.paddings; + auto act_param = param.activation_param; int pad = paddings[0]; int stride = param.strides[1]; bool flag_relu = param.fuse_relu; bool flag_bias = param.bias != nullptr; ctx->ExtendWorkspace((w_in + w_out) * sizeof(float)); - if (pad == 2 && stride == 2) { + if (stride == 2) { conv_depthwise_5x5s2_fp32(reinterpret_cast(din), reinterpret_cast(dout), num, @@ -695,9 +694,8 @@ void conv_depthwise_5x5_fp32(const void* din, w_in, reinterpret_cast(weights), bias, - pad, - flag_bias, - flag_relu, + param, + act_param, ctx); } else if (stride == 1) { conv_depthwise_5x5s1_fp32(reinterpret_cast(din), diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index 9d42fd98df..06620da798 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -898,6 +898,119 @@ void pooling_global_avg(const float* din, } } +void pooling1x1s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + int win_ext = w_unroll_size * 8; + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + memset(zero_ptr, 0, win * sizeof(float)); + auto write_ptr = + static_cast(TargetMalloc(TARGET(kARM), wout * sizeof(float))); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + for (int h = 0; h < hout; h += 4) { + const float* din0_ptr = data_in_channel + h * 2 * win; + const float* din1_ptr = din0_ptr + 2 * win; + const float* din2_ptr = din1_ptr + 2 * win; + const float* din3_ptr = din2_ptr + 2 * win; + + float* doutr0 = data_out_channel + h * wout; + float* doutr1 = doutr0 + wout; + float* doutr2 = doutr1 + wout; + float* doutr3 = doutr2 + wout; + if (h + 4 > hout) { + switch (h + 4 - hout) { + case 3: + doutr1 = write_ptr; + case 2: + doutr2 = write_ptr; + case 1: + doutr3 = write_ptr; + default: + break; + } + } + if (h * 2 + 4 >= hin) { + switch (h * 2 + 4 - hin) { + case 4: + din0_ptr = zero_ptr; + case 3: + case 2: + din1_ptr = zero_ptr; + case 1: + case 0: + din2_ptr = zero_ptr; + din3_ptr = zero_ptr; + default: + break; + } + } + for (int i = 0; i < w_unroll_size; i++) { + float32x4x2_t din0 = vld2q_f32(din0_ptr); + float32x4x2_t din1 = vld2q_f32(din1_ptr); + float32x4x2_t din2 = vld2q_f32(din2_ptr); + float32x4x2_t din3 = vld2q_f32(din3_ptr); + din0_ptr += 8; + din1_ptr += 8; + din2_ptr += 8; + din3_ptr += 8; + + vst1q_f32(doutr0, din0.val[0]); + vst1q_f32(doutr1, din1.val[0]); + vst1q_f32(doutr2, din2.val[0]); + vst1q_f32(doutr3, din3.val[0]); + + doutr0 += 4; + doutr1 += 4; + doutr2 += 4; + doutr3 += 4; + } + int j = win_ext; + for (int i = 0; i < w_unroll_remian; i++) { + if (j >= win) { + *doutr0++ = 0.f; + *doutr1++ = 0.f; + *doutr2++ = 0.f; + *doutr3++ = 0.f; + } else { + *doutr0++ = *din0_ptr; + *doutr1++ = *din1_ptr; + *doutr2++ = *din2_ptr; + *doutr3++ = *din3_ptr; + din0_ptr += 2; + din1_ptr += 2; + din2_ptr += 2; + din3_ptr += 2; + } + j += 2; + } 
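+          // Tail note: the scalar loop above handles the last (wout % 4)
+          // columns one strided element at a time; once j reaches win the
+          // input row is exhausted and zeros are written instead, which is
+          // consistent with the zero_ptr rows substituted for bottom padding.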
+ } + } + } + TargetFree(TARGET(kARM), zero_ptr); + TargetFree(TARGET(kARM), write_ptr); +} + void pooling2x2s2_max(const float* din, float* dout, int num, diff --git a/lite/backends/arm/math/pooling.h b/lite/backends/arm/math/pooling.h index 9288f27bbc..701732cb45 100644 --- a/lite/backends/arm/math/pooling.h +++ b/lite/backends/arm/math/pooling.h @@ -64,6 +64,16 @@ void pooling_global_avg(const float* din, int hin, int win); +void pooling1x1s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + void pooling2x2s2_max(const float* din, float* dout, int num, diff --git a/lite/core/mir/fusion/conv_activation_fuser.cc b/lite/core/mir/fusion/conv_activation_fuser.cc index 6ba11a6a4e..993fe4e944 100644 --- a/lite/core/mir/fusion/conv_activation_fuser.cc +++ b/lite/core/mir/fusion/conv_activation_fuser.cc @@ -79,6 +79,9 @@ cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetAttr("act_type", act_type_); if (act_type_ == "relu") { op_desc.SetAttr("fuse_relu", true); + } else if (act_type_ == "relu6") { + float alpha = act_op_desc.GetAttr("threshold"); + op_desc.SetAttr("fuse_brelu_threshold", alpha); } else if (act_type_ == "leaky_relu") { float alpha = act_op_desc.GetAttr("alpha"); op_desc.SetAttr("leaky_relu_alpha", alpha); diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 52849a026e..4afb8f020e 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -56,13 +56,12 @@ void ConvCompute::PrepareForRun() { bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2)); - bool flag_dw_5x5 = pads_all_equal && ((kw == 5 && stride == 1) || - (kw == 5 && stride == 2 && pad == 2)); + bool flag_dw_5x5 = (paddings[0] == paddings[2]) && + ((kw == 5 && stride == 1) || (kw == 5 && stride == 2)); bool flag_dw = flag_dw_3x3 || flag_dw_5x5; /// select conv impl - if (param.groups == ic && ic == oc && kps_equal && pads_equal && - no_dilation && flag_dw) { + if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { /// dw conv impl impl_ = new DepthwiseConv; // VLOG(3) << "invoking dw conv"; diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc index adaae92472..10c190806f 100644 --- a/lite/kernels/arm/conv_depthwise.cc +++ b/lite/kernels/arm/conv_depthwise.cc @@ -28,16 +28,13 @@ void DepthwiseConv::PrepareForRun() { auto& ctx = this->ctx_->template As(); auto w_dims = param.filter->dims(); auto kw = w_dims[3]; + auto paddings = *param.paddings; // select dw conv kernel if (kw == 3) { // VLOG(5) << "invoke 3x3 dw conv fp32"; - auto paddings = *param.paddings; - bool pads_equal = - ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); - - if (pads_equal && paddings[0] == paddings[2] && + bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); + if (pads_less && paddings[0] == paddings[2] && (paddings[0] == 0 || paddings[0] == 1)) { - impl_ = lite::arm::math::conv_depthwise_3x3_fp32; flag_trans_weights_ = false; } else { // trans weights @@ -50,11 +47,25 @@ void DepthwiseConv::PrepareForRun() { auto w_data_in = param.filter->data(); lite::arm::math::conv_trans_weights_numc( w_data_in, w_data, oc, 1, cblock, kh * kw); - impl_ = lite::arm::math::conv_depthwise_3x3_fp32; flag_trans_weights_ = true; } + impl_ = 
lite::arm::math::conv_depthwise_3x3_fp32; } else if (kw == 5) { // VLOG(5) << "invoke 5x5 dw conv fp32"; + if (param.strides[0] == 2) { // conv5x5s2_dw + constexpr int cblock = 4; + auto oc = w_dims[0]; + auto kh = w_dims[2]; + auto cround = ROUNDUP(oc, cblock); + weights_.Resize({cround, 1, kh, kw}); + auto w_data = weights_.mutable_data(); + auto w_data_in = param.filter->data(); + lite::arm::math::conv_trans_weights_numc( + w_data_in, w_data, oc, 1, cblock, kh * kw); + flag_trans_weights_ = true; + } else { + flag_trans_weights_ = false; + } impl_ = lite::arm::math::conv_depthwise_5x5_fp32; } else { LOG(FATAL) << "this type dw conv not impl"; diff --git a/lite/kernels/arm/pool_compute.cc b/lite/kernels/arm/pool_compute.cc index f97d58f964..7ff4222563 100644 --- a/lite/kernels/arm/pool_compute.cc +++ b/lite/kernels/arm/pool_compute.cc @@ -85,7 +85,22 @@ void PoolCompute::Run() { return; } } else { - if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { + if (ksize[0] == 1 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { + auto& ctx = this->ctx_->template As(); + if (pooling_type == "max") { + lite::arm::math::pooling1x1s2p0_max(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3]); + return; + } + } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && + kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling2x2s2_max(din, dout, diff --git a/lite/operators/conv_op.h b/lite/operators/conv_op.h index 6e1c0bb3d4..63107022f1 100644 --- a/lite/operators/conv_op.h +++ b/lite/operators/conv_op.h @@ -85,6 +85,10 @@ class ConvOpLite : public OpLite { if (act_type == "relu") { param_.activation_param.active_type = lite_api::ActivationType::kRelu; param_.fuse_relu = true; + } else if (act_type == "relu6") { + param_.activation_param.active_type = lite_api::ActivationType::kRelu6; + param_.activation_param.Relu_clipped_coef = + op_desc.GetAttr("fuse_brelu_threshold"); // 6.f } else if (act_type == "leaky_relu") { param_.activation_param.active_type = lite_api::ActivationType::kLeakyRelu; -- GitLab
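The deleted assembly above leans on one recurring trick: a precomputed per-lane mask, reloaded through [%[mask]] and applied with vbif.32, zeroes the lanes that fall past the valid input width so the 4-lane multiply-accumulates never consume stale data. A minimal intrinsics sketch of the same idea follows; load_masked and its valid parameter are illustrative names, not helpers from this patch, and it assumes p[0..3] is at least readable (the real kernels guarantee this with their padded buffers).

    #include <arm_neon.h>

    // Zero the lanes of a 4-float load whose index is >= valid (valid in [0, 4]).
    // Intrinsic equivalent of the "load, then vbif against a lane-index mask"
    // pattern in the assembly above.
    static inline float32x4_t load_masked(const float* p, unsigned valid) {
      const uint32_t lane_idx[4] = {0, 1, 2, 3};
      uint32x4_t keep = vcltq_u32(vld1q_u32(lane_idx), vdupq_n_u32(valid));
      float32x4_t v = vld1q_f32(p);                 // may read past the logical end
      return vbslq_f32(keep, v, vdupq_n_f32(0.f));  // keep valid lanes, zero the rest
    }

The assembly builds the mask once per row and rewinds the pointer with "sub %[mask], #16" between taps, which keeps the comparison out of the hot loop.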
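pooling1x1s2p0_max, despite the name, needs no max reduction: a 1x1 window with stride 2 is a strided copy, out(h, w) = in(2h, 2w), with zeros written where the doubled index runs past the input (the zero_ptr and 0.f branches above). A scalar reference with the same semantics, for a single channel:

    // Scalar reference for 1x1, stride-2, pad-0 "max" pooling on one channel.
    void pooling1x1s2p0_ref(const float* in, float* out,
                            int hout, int wout, int hin, int win) {
      for (int h = 0; h < hout; ++h) {
        for (int w = 0; w < wout; ++w) {
          int hs = 2 * h;
          int ws = 2 * w;
          out[h * wout + w] = (hs < hin && ws < win) ? in[hs * win + ws] : 0.f;
        }
      }
    }

The vectorized path gets the stride-2 gather for free from vld2q_f32: the de-interleaving load splits even and odd columns, and .val[0] is exactly the stride-2 sample vector that gets stored.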
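conv_trans_weights_numc is an existing Lite helper reused here with cblock = 4; the sketch below is an assumption about the layout it produces for this kernel (channels rounded up to a multiple of 4 and interleaved per tap, so one vld1q_f32 yields a given tap for four channels at once), not a transcription of its implementation.

    // Hypothetical repack: [oc][ksize] -> [ROUNDUP(oc, 4) / 4][ksize][4], with
    // the channel remainder zero-padded. ksize is kh * kw (25 for 5x5).
    void repack_dw_weights_c4(const float* src, float* dst, int oc, int ksize) {
      int oc_round = (oc + 3) / 4 * 4;
      for (int c = 0; c < oc_round; ++c) {
        for (int k = 0; k < ksize; ++k) {
          dst[(c / 4) * ksize * 4 + k * 4 + (c % 4)] =
              (c < oc) ? src[c * ksize + k] : 0.f;
        }
      }
    }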
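The fuser and ConvOpLite changes wire relu6 through as a clipped ReLU: the fuser copies the relu6 op's "threshold" attribute into "fuse_brelu_threshold", and ConvOpLite reads it back into Relu_clipped_coef. The fused activation therefore reduces to the clamp below; the helper name is illustrative, not from the patch.

    #include <algorithm>

    // Clipped ReLU applied after the convolution when act_type == "relu6";
    // threshold is the forwarded attribute, typically 6.f.
    static inline float relu6(float x, float threshold) {
      return std::min(std::max(x, 0.f), threshold);
    }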