diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt
index 1c0e8e5bf9ad350e8948e06808c9510e476139bd..831c28bf528d34984aecf269fd340ba3d6f6fd6e 100644
--- a/lite/backends/arm/math/CMakeLists.txt
+++ b/lite/backends/arm/math/CMakeLists.txt
@@ -82,6 +82,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
             conv5x5s1_depthwise_fp32.cc
             conv5x5s2_depthwise_int8.cc
             conv5x5s2_depthwise_fp32.cc
+            conv5x5s2_depthwise_fp32_c4.cc
             conv3x3_winograd_fp32_c4.cc
             conv3x3_winograd_int8.cc
             conv_winograd_3x3.cc
diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc
index a72b7553e0c8fddcb9028b0e6125281a07e65387..a2cbf8b59105087a4f806855042f0e3156955ae0 100644
--- a/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32.cc
@@ -13,733 +13,1763 @@
 // limitations under the License.
 #include <arm_neon.h>
-#include "lite/backends/arm/math/conv_block_utils.h"
 #include "lite/backends/arm/math/conv_depthwise.h"
-#include "lite/core/context.h"
-#include "lite/operators/op_params.h"
-#ifdef ARM_WITH_OMP
-#include <omp.h>
-#endif
 namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
+void conv_depthwise_5x5s2_bias(float* dout,
+                               const float* din,
+                               const float* weights,
+                               const float* bias,
+                               bool flag_bias,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int pad_top,
+                               int pad_bottom,
+                               int pad_left,
+                               int pad_right,
+                               ARMContext* ctx);
+void conv_depthwise_5x5s2_bias_relu(float* dout,
+                                    const float* din,
+                                    const float* weights,
+                                    const float* bias,
+                                    bool flag_bias,
+                                    int num,
+                                    int chin,
+                                    int hin,
+                                    int win,
+                                    int hout,
+                                    int wout,
+                                    int pad_top,
+                                    int pad_bottom,
+                                    int pad_left,
+                                    int pad_right,
+                                    ARMContext* ctx);
+void conv_depthwise_5x5s2_bias_relu6(float* dout,
+                                     const float* din,
+                                     const float* weights,
+                                     const float* bias,
+                                     const float* six,
+                                     bool flag_bias,
+                                     int num,
+                                     int chin,
+                                     int hin,
+                                     int win,
+                                     int hout,
+                                     int wout,
+                                     int pad_top,
+                                     int pad_bottom,
+                                     int pad_left,
+                                     int pad_right,
+                                     ARMContext* ctx);
+void conv_depthwise_5x5s2_bias_leakyRelu(float* dout,
+                                         const float* din,
+                                         const float* weights,
+                                         const float* bias,
+                                         const float* scale,
+                                         bool flag_bias,
+                                         int num,
+                                         int chin,
+                                         int hin,
+                                         int win,
+                                         int hout,
+                                         int wout,
+                                         int pad_top,
+                                         int pad_bottom,
+                                         int pad_left,
+                                         int pad_right,
+                                         ARMContext* ctx);
+void conv_depthwise_5x5s2_fp32(float* dout,
+                               const float* din,
+                               const float* weights,
+                               const float* bias,
+                               bool flag_bias,
+                               bool flag_relu,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int pad_top,
+                               int pad_bottom,
+                               int pad_left,
+                               int pad_right,
+                               const operators::ActivationParam& act_param,
+                               ARMContext* ctx) {
+  bool has_active = act_param.has_active;
+  auto act_type = act_param.active_type;
+  float tmp = act_param.Relu_clipped_coef;
+  float ss = act_param.Leaky_relu_alpha;
+  float vsix[4] = {tmp, tmp, tmp, tmp};
+  float vscale[4] = {ss, ss, ss, ss};
+  if (has_active) {
+    switch (act_type) {
+      case lite_api::ActivationType::kRelu:
+        conv_depthwise_5x5s2_bias_relu(dout,
+                                       din,
+                                       weights,
+                                       bias,
+                                       flag_bias,
+                                       num,
+                                       chin,
+                                       hin,
+                                       win,
+                                       hout,
+                                       wout,
+                                       pad_top,
+                                       pad_bottom,
+                                       pad_left,
+                                       pad_right,
+                                       ctx);
+        break;
+      case lite_api::ActivationType::kRelu6:
+        conv_depthwise_5x5s2_bias_relu6(dout,
+                                        din,
+                                        weights,
+                                        bias,
+                                        vsix,
+                                        flag_bias,
+                                        num,
+                                        chin,
+                                        hin,
+                                        win,
+                                        hout,
+                                        wout,
+                                        pad_top,
+                                        pad_bottom,
+                                        pad_left,
+                                        pad_right,
+                                        ctx);
+        break;
+      case lite_api::ActivationType::kLeakyRelu:
+        conv_depthwise_5x5s2_bias_leakyRelu(dout,
+                                            din,
+                                            weights,
+                                            bias,
+                                            vscale,
+                                            flag_bias,
+                                            num,
+                                            chin,
+                                            hin,
+                                            win,
+                                            hout,
+                                            wout,
+                                            pad_top,
+                                            pad_bottom,
+                                            pad_left,
+                                            pad_right,
+                                            ctx);
+        break;
+      default:
+        LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
+                   << " fuse not support";
+    }
+  } else {
+    conv_depthwise_5x5s2_bias(dout,
+                              din,
+                              weights,
+                              bias,
+                              flag_bias,
+                              num,
+                              chin,
+                              hin,
+                              win,
+                              hout,
+                              wout,
+                              pad_top,
+                              pad_bottom,
+                              pad_left,
+                              pad_right,
+                              ctx);
+  }
+}
+// clang-format off
 #ifdef __aarch64__
-#define COMPUTE \ - "ldp q0, q1, [%[inr0]], #32\n" /* load r0, 0-1 */ \ - "and v19.16b, %[vbias].16b, %[vbias].16b\n" \ - "ldp q2, q3, [%[inr0]], #32\n" /* load r0, 2-3 */ \ - "and v20.16b, %[vbias].16b, %[vbias].16b\n" \ - "ldp q4, q5, [%[inr0]], #32\n" /* load r0, 4-5 */ \ - "and v21.16b, %[vbias].16b, %[vbias].16b\n" \ - "ldp q6, q7, [%[inr0]], #32\n" /* load r0, 6-7 */ \ - "and v22.16b, %[vbias].16b, %[vbias].16b\n" \ - "ldp q8, q9, [%[inr0]], #32\n" /* load r0, 8-9 */ \ - "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ - "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ - "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ - "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ - "ldr q10, [%[inr0]] \n" /* load r0, 10 */ \ - "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ - "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ - "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ - "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ - "sub %[inr0], %[inr0], #32\n" /* inr0 -= 32 */ \ - "ldp q0, q1, [%[inr1]], #32\n" /* load r1, 0-1 */ \ - "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ - "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ - "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ - "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ - "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ - "fmla v19.4s , %[w3].4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , %[w3].4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , %[w3].4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , %[w3].4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ - "ldp q2, q3, [%[inr1]], #32\n" /* load r1, 2-3 */ \ - "fmla v19.4s , %[w4].4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , %[w4].4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , %[w4].4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , %[w4].4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q4, q5, [%[inr1]], #32\n" /* load r1, 4-5 */ \ - "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ - "ldp q6, q7, [%[inr1]], #32\n" /* load r0, 6-7 */ \ - "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ - "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ - "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ - "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ - "ldp q8, q9, [%[inr1]], #32\n" /* load r0, 8-9 */ \ - "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ - "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ - "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ - "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ - "ldr q10, [%[inr1]] \n" /* load r0, 10 */ \ - "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ - "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ - "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ - "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = 
w0 * r0, 8*/ \ - "sub %[inr1], %[inr1], #32\n" /* inr1 -= 32 */ \ - "ldp q0, q1, [%[inr2]], #32\n" /* load r1, 0-1 */ \ - "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ - "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ - "ldp q2, q3, [%[inr2]], #32\n" /* load r1, 2-3 */ \ - "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q4, q5, [%[inr2]], #32\n" /* load r1, 4-5 */ \ - "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ - "ldp q6, q7, [%[inr2]], #32\n" /* load r0, 6-7 */ \ - "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ - "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ - "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ - "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ - "ldp q8, q9, [%[inr2]], #32\n" /* load r0, 8-9 */ \ - "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ - "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ - "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ - "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ - "ldr q10, [%[inr2]] \n" /* load r0, 10 */ \ - "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ - "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ - "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ - "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ - "sub %[inr2], %[inr2], #32\n" /* inr0 -= 32 */ \ - "ldp q0, q1, [%[inr3]], #32\n" /* load r1, 0-1 */ \ - "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ - "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ - "ldp q2, q3, [%[inr3]], #32\n" /* load r1, 2-3 */ \ - "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q4, q5, [%[inr3]], #32\n" /* load r1, 4-5 */ \ - "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ - "ldp q6, q7, [%[inr3]], #32\n" /* load r0, 6-7 */ \ - "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ - "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ - "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ - "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ - "ldp q8, q9, [%[inr3]], #32\n" /* load r0, 8-9 */ \ - "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ - "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ - "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ - "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ - "ldr q10, [%[inr3]] \n" /* load r0, 10 */ \ - "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ - "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ - "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ - "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ - "sub %[inr3], %[inr3], #32\n" /* inr0 -= 32 */ \ 
- "ldp q0, q1, [%[inr4]], #32\n" /* load r1, 0-1 */ \ - "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ - "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ - "ldp q2, q3, [%[inr4]], #32\n" /* load r1, 2-3 */ \ - "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ - "ldp q4, q5, [%[inr4]], #32\n" /* load r1, 4-5 */ \ - "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ - "ldp q6, q7, [%[inr4]], #32\n" /* load r0, 6-7 */ \ - "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ - "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ - "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ - "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ - "ldp q8, q9, [%[inr4]], #32\n" /* load r0, 8-9 */ \ - "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ - "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ - "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ - "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ - "ldr q10, [%[inr4]] \n" /* load r0, 10 */ \ - "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ - "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ - "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ - "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ - "sub %[inr4], %[inr4], #32\n" /* inr0 -= 32 */ \ - "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ - "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ - "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ - "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ - "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ - "sub %[wc0], %[wc0], #320\n" /* weight -= 320 */ \ - "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \ - "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \ - "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \ - "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \ - "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ - "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ - "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ - "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ -#define RELU /* relu */ \ - "movi v0.4s, #0\n" /* for relu */ \ - "fmax v19.4s, v19.4s, v0.4s\n" \ - "fmax v20.4s, v20.4s, v0.4s\n" \ - "fmax v21.4s, v21.4s, v0.4s\n" \ - "fmax v22.4s, v22.4s, v0.4s\n" -#define RELU6 /* relu6 */ \ - "fmin v19.4s, v19.4s, %[vsix].4s\n" \ - "fmin v20.4s, v20.4s, %[vsix].4s\n" \ - "fmin v21.4s, v21.4s, %[vsix].4s\n" \ - "fmin v22.4s, v22.4s, %[vsix].4s\n" -#define LEAKY_RELU /* LeakyRelu */ \ - "movi v0.4s, #0\n" /* for relu */ \ - "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \ - "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ - "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \ - "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ - "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \ - "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ - "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \ - "fmul v8.4s, v22.4s, 
%[vscale].4s \n" /* mul */ \ - "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ - "bif v20.16b, v4.16b, v3.16b \n" /* choose*/ \ - "bif v21.16b, v6.16b, v5.16b \n" /* choose*/ \ - "bif v22.16b, v8.16b, v7.16b \n" /* choose*/ -#define STORE /* save result */ \ - "str q19, [%[outc0]], #16\n" \ - "str q20, [%[outc1]], #16\n" \ - "str q21, [%[outc2]], #16\n" \ - "str q22, [%[outc3]], #16\n" - +#define COMPUTE_ONE_LINE_S2_PRE \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_TWO_LINE_S2_PRE \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_THREE_LINE_S2_PRE \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], 
v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_FOUR_LINE_S2_PRE \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr3]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr3]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, 
v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr3].s[0]\n" /*0246*wr3[0]*/ \ + "fmla v15.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_FIVE_LINE_S2 \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr3]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr3]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr3].s[0]\n" /*0246*wr3[0]*/ \ + "fmla v15.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr4]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr4]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[3]\n" 
/*46810*wr5[3]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr4].s[0]\n" /*0246*wr4[0]*/ \ + "fmla v16.4s, v10.4s, %[wr4].s[1]\n" /*1357*wr4[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr4].s[2]\n" /*2468*wr4[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr4].s[3]\n" /*3579*wr4[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_FIVE_LINE_S2_OUT2 \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ld1 {v17.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v17.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "fmul v18.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "fmla v17.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "fmla v18.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "fmla v17.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v18.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "fmla v17.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr3]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr3]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "fmla v18.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "fmla v17.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "fmla v18.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, 
%[wr3].s[0]\n" /*0246*wr3[0]*/ \ + "fmla v17.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v15.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ + "fmla v18.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr4]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr4]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ + "fmla v17.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ + "fmla v18.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ + "fmla v17.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr4].s[0]\n" /*0246*wr4[0]*/ \ + "fmla v18.4s, v9.4s, %[wr3].s[0]\n" /*0246*wr3[0]*/ \ + "fmla v16.4s, v10.4s, %[wr4].s[1]\n" /*1357*wr4[1]*/ \ + "fmla v17.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr5]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr5]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr4].s[2]\n" /*2468*wr4[2]*/ \ + "fmla v18.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr4].s[3]\n" /*3579*wr4[3]*/ \ + "fmla v17.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "fmla v18.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v17.4s, v9.4s, %[wr4].s[0]\n" /*0246*wr4[0]*/ \ + "fmla v18.4s, v10.4s, %[wr4].s[1]\n" /*1357*wr4[1]*/\ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v17.4s, v12.4s, %[wr4].s[2]\n" /*2468*wr4[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v18.4s, v13.4s, %[wr4].s[3]\n" /*3579*wr4[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v17.4s, v14.4s, %[wr6].s[0]\n" /*46810*wr6[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" \ + "fadd v18.4s, v18.4s, v17.4s\n" +#define COMPUTE_ONE_LINE_S2_POST \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_TWO_LINE_S2_POST \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, 
[%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_THREE_LINE_S2_POST \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "ext 
v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define COMPUTE_FOUR_LINE_S2_POST \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "1: \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "fmla v15.4s, v9.4s, %[wr0].s[0]\n" /*0246*wr0[0]*/ \ + "fmul v16.4s, v10.4s, %[wr0].s[1]\n" /*1357*wr0[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr1]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v12.4s, %[wr0].s[2]\n" /*2468*wr0[2]*/ \ + "ldr d22, [%[din_ptr1]]\n" /*891011*/ \ + "fmla v16.4s, v13.4s, %[wr0].s[3]\n" /*3579*wr0[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[0]\n" /*46810*wr5[0]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr1].s[0]\n" /*0246*wr1[0]*/ \ + "fmla v15.4s, v10.4s, %[wr1].s[1]\n" /*1357*wr1[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr2]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr2]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr1].s[2]\n" /*2468*wr1[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr1].s[3]\n" /*3579*wr1[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[1]\n" /*46810*wr5[1]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v15.4s, v9.4s, %[wr2].s[0]\n" /*0246*wr2[0]*/ \ + "fmla v16.4s, v10.4s, %[wr2].s[1]\n" /*1357*wr2[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr3]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr3]]\n" /*891011*/ \ + "fmla v15.4s, v12.4s, %[wr2].s[2]\n" /*2468*wr2[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v16.4s, v13.4s, %[wr2].s[3]\n" /*3579*wr2[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v15.4s, v14.4s, %[wr5].s[2]\n" /*46810*wr5[2]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fmla v16.4s, v9.4s, %[wr3].s[0]\n" /*0246*wr3[0]*/ \ + "fmla v15.4s, v10.4s, %[wr3].s[1]\n" /*1357*wr3[1]*/ \ + "ld2 {v9.4s, v10.4s}, [%[din_ptr0]], #32\n" \ + "mov v13.s[3], v11.s[1]\n" /*3579*/ \ + "ldr d22, [%[din_ptr0]]\n" /*891011*/ \ + "fmla v16.4s, v12.4s, %[wr3].s[2]\n" /*2468*wr3[2]*/ \ + "mov v14.s[3], v11.s[2]\n" /*46810*/ \ + "fmla v15.4s, v13.4s, %[wr3].s[3]\n" /*3579*wr3[3]*/ \ + "ext v12.16b, v9.16b, v11.16b, #4\n" /*2468*/ \ + "fmla v16.4s, v14.4s, %[wr5].s[3]\n" /*46810*wr5[3]*/\ + "ext v13.16b, v10.16b, v11.16b, #4\n"/*3578*/ \ + "ext v14.16b, v9.16b, v11.16b, #8\n" /*4689*/ \ + "fadd v16.4s, v16.4s, v15.4s\n" +#define RESULT_S2 \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "st1 {v16.4s}, [%[dout_ptr]], #16\n" \ + "bne 1b" +#define RESULT_S2_RELU \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "fmax v16.4s, v16.4s, %[vzero].4s\n" \ + "st1 {v16.4s}, [%[dout_ptr]], #16\n" \ + "bne 1b" +#define RESULT_S2_RELU6 \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "fmax v16.4s, v16.4s, %[vzero].4s\n" \ + "fmin v16.4s, v16.4s, %[vsix].4s\n" \ + "st1 {v16.4s}, [%[dout_ptr]], #16\n" \ + "bne 1b" +#define RESULT_S2_LEAKY_RELU \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "fcmge v17.4s, v16.4s, %[vzero].4s\n" \ + "fmul v18.4s, v16.4s, %[vscale].4s\n" \ + "bif v16.16b, 
v18.16b, v17.16b\n" \ + "st1 {v16.4s}, [%[dout_ptr]], #16\n" \ + "bne 1b" +#define RESULT_S2_OUT2 \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "st1 {v16.4s}, [%[dout_ptr0]], #16\n" \ + "ld1 {v17.4s}, [%[bias]]\n" \ + "st1 {v18.4s}, [%[dout_ptr1]], #16\n" \ + "bne 1b" +#define RESULT_S2_RELU_OUT2 \ + "fmax v16.4s, v16.4s, %[vzero].4s\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "fmax v18.4s, v18.4s, %[vzero].4s\n" \ + "ld1 {v17.4s}, [%[bias]]\n" \ + "st1 {v16.4s}, [%[dout_ptr0]], #16\n" \ + "st1 {v18.4s}, [%[dout_ptr1]], #16\n" \ + "bne 1b" +#define RESULT_S2_RELU6_OUT2 \ + "fmax v16.4s, v16.4s, %[vzero].4s\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "fmax v18.4s, v18.4s, %[vzero].4s\n" \ + "ld1 {v17.4s}, [%[bias]]\n" \ + "fmin v16.4s, v16.4s, %[vsix].4s\n" \ + "fmin v18.4s, v18.4s, %[vsix].4s\n" \ + "st1 {v16.4s}, [%[dout_ptr0]], #16\n" \ + "st1 {v18.4s}, [%[dout_ptr1]], #16\n" \ + "bne 1b" +#define RESULT_S2_LEAKY_RELU_OUT2 \ + "fcmge v19.4s, v16.4s, %[vzero].4s\n" \ + "fmul v20.4s, v16.4s, %[vscale].4s\n" \ + "ld1 {v15.4s}, [%[bias]]\n" \ + "fcmge v21.4s, v18.4s, %[vzero].4s\n" \ + "fmul v22.4s, v18.4s, %[vscale].4s\n" \ + "ld1 {v17.4s}, [%[bias]]\n" \ + "bif v16.16b, v20.16b, v19.16b\n" \ + "bif v18.16b, v22.16b, v21.16b\n" \ + "st1 {v16.4s}, [%[dout_ptr0]], #16\n" \ + "st1 {v18.4s}, [%[dout_ptr1]], #16\n" \ + "bne 1b" #else -#define COMPUTE \ - /* fill with bias */ \ - "vld1.32 {d12-d13}, [%[bias]]\n" /* load bias */ /* load weights */ \ - "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ - "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \ - "vand.i32 q12, q6, q6\n" \ - "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \ - "vand.i32 q13, q6, q6\n" \ - "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \ - "vand.i32 q14, q6, q6\n" \ - "vand.i32 q15, q6, q6\n" \ - "vld1.32 {d12-d13}, [%[r0]]!\n" /* load input r0, 6*/ \ - "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ - "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-q10 */ \ - "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ - "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ - "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ - "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ - "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ - "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ - "vmla.f32 q13, q9, q4 @ w2 * inr6\n" \ - "vmla.f32 q14, q9, q6 @ w2 * inr4\n" \ - "vld1.32 {d0-d3}, [%[r0]]! \n" /* load r0, 7-8 */ \ - "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ - "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ - "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ - "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ - "vld1.32 {d4-d7}, [%[r0]] \n" /* load r0, 9-10 */ \ - "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ - "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ - "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ - "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ - "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \ - "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ - "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ - "vld1.32 {d4-d5}, [%[r1]]! @ load r1, 2\n" \ - "sub %[r0], %[r0], #16 @ r0 - 16 to nextline address\n" \ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ - "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ - "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ - "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ - "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 3, 4\n" \ - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ - "vld1.32 {d10-d13}, [%[r1]]! 
@ load r1, 5, 6\n" \ - "vmla.f32 q14, q7, q4 @ w0 * inr0\n" \ - "vmla.f32 q15, q7, q6 @ w0 * inr2\n" \ - "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ - "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ - "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 7, 8\n" \ - "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ - "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ - "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ - "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ - "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ - "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ - "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ - "vld1.32 {d4-d7}, [%[r1]] @ load r1, 9, 10\n" \ - "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ - "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ - "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ - "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ - "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ - "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ - "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ - "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ - "sub %[r1], %[r1], #16 @ r1 - 16 to nextline address\n" \ - "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \ - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ - "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ - "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ - "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \ - "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ - "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ - "vld1.32 {d12-d13}, [%[r2]]! @ load r2, 6 \n" \ - "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ - "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ - "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 7, 8\n" \ - "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ - "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ - "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ - "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ - "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ - "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ - "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ - "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ - "vld1.32 {d4-d7}, [%[r2]] @ load r2, 9, 10\n" \ - "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ - "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ - "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ - "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ - "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ - "sub %[r2], %[r2], #16 @ r1 - 16 to nextline address\n" \ - "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ - "vld1.32 {d0-d3}, [%[r3]]! @ load r3, 0, 1\n" \ - "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ - "vld1.32 {d4-d7}, [%[r3]]! @ load r3, 2, 3\n" \ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ - "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ - "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ - "vld1.32 {d8-d11}, [%[r3]]! @ load r3, 4, 5\n" \ - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ - "vld1.32 {d12-d13}, [%[r3]]! @ load r3, 6, \n" \ - "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ - "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ - "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ - "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ - "vld1.32 {d0-d3}, [%[r3]]! 
@ load r3, 7, 8\n" \ - "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ - "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ - "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ - "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ - "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ - "vld1.32 {d4-d7}, [%[r3]] @ load r3, 9, 10\n" \ - "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ - "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ - "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ - "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ - "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ - "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ - "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ - "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ - "sub %[r3], %[r3], #16 @ r1 - 16 to nextline address\n" \ - "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ - "vld1.32 {d0-d3}, [%[r4]]! @ load r4, 0, 1\n" \ - "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ - "vld1.32 {d4-d7}, [%[r4]]! @ load r4, 2, 3\n" \ - "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ - "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ - "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ - "vld1.32 {d8-d11}, [%[r4]]! @ load r3, 4, 5\n" \ - "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ - "vld1.32 {d12-d13}, [%[r4]]! @ load r3, 6, \n" \ - "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ - "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ - "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ - "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ - "vld1.32 {d0-d3}, [%[r4]]! @ load r3, 7, 8\n" \ - "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ - "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ - "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ - "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ - "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ - "vld1.32 {d4-d7}, [%[r4]] @ load r3, 9, 10\n" \ - "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ - "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ - "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ - "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ - "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ - "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ - "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ - "sub %[wc0], %[wc0], #400 @ wc0 - 400 to start address\n" \ - "sub %[r4], %[r4], #16 @ r1 - 16 to nextline address\n" \ - "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ - "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ - "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \ - "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \ - "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \ - "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ - -#define RELU /* relu */ \ - "vmov.u32 q0, #0\n" \ - "vld1.32 {d2-d3}, [%[six_ptr]]\n" \ - "vmax.f32 q12, q12, q0\n" \ - "vmax.f32 q13, q13, q0\n" \ - "vmax.f32 q14, q14, q0\n" \ - "vmax.f32 q15, q15, q0\n" -#define RELU6 /* relu6 */ \ - "vmin.f32 q12, q12, q1\n" \ - "vmin.f32 q13, q13, q1\n" \ - "vmin.f32 q14, q14, q1\n" \ - "vmin.f32 q15, q15, q1\n" -#define LEAKY_RELU /* LeakyRelu */ \ - "vmov.u32 q0, #0\n" \ - "vld1.32 {d2-d3}, [%[scale_ptr]]\n" \ - "vcge.f32 q2, q12, q0 @ q0 > 0 \n" \ - "vcge.f32 q4, q13, q0 @ q0 > 0 \n" \ - "vcge.f32 q6, q14, q0 @ q0 > 0 \n" \ - "vcge.f32 q8, q15, q0 @ q0 > 0 \n" \ - "vmul.f32 q3, q12, q1 @ mul \n" \ - "vmul.f32 q5, q13, q1 @ mul \n" \ - "vmul.f32 q7, q14, q1 @ mul \n" \ - "vmul.f32 q9, q15, q1 @ mul \n" \ - "vbif q12, q3, q2 @ choose \n" \ - "vbif q13, q5, q4 @ choose \n" \ - "vbif q14, q7, q6 @ choose \n" \ - "vbif q15, q9, q8 @ choose \n" -#define STORE /* save result */ \ - "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \ - "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \ - "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \ - "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ +#define COMPUTE_ONE_LINE_S2_PRE \ + "vld2.f32 {d16-d19}, 
[%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_TWO_LINE_S2_PRE \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_THREE_LINE_S2_PRE \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ + "vmla.f32 q15, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, 
%f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_FOUR_LINE_S2_PRE \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ + "vmla.f32 q15, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr3]]!\n" \ + "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr3]]\n" /*810911*/\ + "vmla.f32 q14, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ + "vmla.f32 q15, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q14, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr3][1]\n" /*3579*wr3[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_FIVE_LINE_S2 \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" 
/*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr3]]!\n" \ + "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr3]]\n" /*810911*/\ + "vmla.f32 q14, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ + "vmla.f32 q15, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr4]]!\n" \ + "vmla.f32 q14, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr4]]\n" /*810911*/\ + "vmla.f32 q15, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr3][1]\n" /*3579*wr3[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr4][0]\n" /*0246*wr4[0]*/ \ + "vmla.f32 q14, q9, %e[wr4][1]\n" /*1357*wr4[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q15, q11, %f[wr4][0]\n" /*2468*wr4[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr4][1]\n" /*3579*wr4[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_FIVE_LINE_S2_OUT2 \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vadd.f32 q15, q15, q14\n" \ + "vld1.f32 {d28-d29}, [%[bias]]\n" \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" 
/*3579*wr0[3]*/\ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q14, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vmla.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ + "vmla.f32 q15, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vmla.f32 q14, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vmla.f32 q15, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vmla.f32 q14, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vmla.f32 q15, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vmla.f32 q14, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr3]]!\n" \ + "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr3]]\n" /*810911*/\ + "vmla.f32 q15, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ + "vmla.f32 q14, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vmla.f32 q15, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ + "vmla.f32 q14, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ + "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr4]]!\n" \ + "vmla.f32 q15, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ + "vmla.f32 q14, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr4]]\n" /*810911*/\ + "vmla.f32 q15, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vmla.f32 q14, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr3][1]\n" /*3579*wr3[3]*/\ + "vmla.f32 q14, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr4][0]\n" /*0246*wr4[0]*/ \ + "vmla.f32 q14, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vmla.f32 q15, q9, %e[wr4][1]\n" /*1357*wr4[1]*/ \ + "vmla.f32 q14, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr5]]!\n" \ + "vmla.f32 q15, q11, %f[wr4][0]\n" /*2468*wr4[2]*/\ + "vmla.f32 q14, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr5]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vmla.f32 q14, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr4][1]\n" /*3579*wr4[3]*/\ + "vmla.f32 q14, q12, %f[wr3][1]\n" /*3579*wr4[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr4][0]\n" /*0246*wr4[0]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vmla.f32 q14, q9, %e[wr4][1]\n" /*1357*wr4[1]*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q14, q11, %f[wr4][0]\n" /*2468*wr4[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" 
/*810911*/\ + "vmla.f32 q14, q13, %e[wr6][0]\n"/*46810*wr6[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vmla.f32 q14, q12, %f[wr4][1]\n" /*3579*wr4[3]*/\ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ +#define COMPUTE_ONE_LINE_S2_POST \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_TWO_LINE_S2_POST \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_THREE_LINE_S2_POST \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + 
"vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q14, q13, %f[wr5][0]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define COMPUTE_FOUR_LINE_S2_POST \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "1: \n" \ + "subs %[cnt], #1\n" \ + "vmla.f32 q15, q8, %e[wr0][0]\n" /*0246*wr0[0]*/ \ + "vmul.f32 q14, q9, %e[wr0][1]\n" /*1357*wr0[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr1]]!\n" \ + "vmla.f32 q15, q11, %f[wr0][0]\n" /*2468*wr0[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr1]]\n" /*810911*/\ + "vmla.f32 q14, q13, %e[wr5][0]\n"/*46810*wr5[0]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr0][1]\n" /*3579*wr0[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr1][0]\n" /*0246*wr1[0]*/ \ + "vmla.f32 q15, q9, %e[wr1][1]\n" /*1357*wr1[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr2]]!\n" \ + "vmla.f32 q14, q11, %f[wr1][0]\n" /*2468*wr1[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr2]]\n" /*810911*/\ + "vmla.f32 q15, q13, %e[wr5][1]\n"/*46810*wr5[1]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr1][1]\n" /*3579*wr1[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q15, q8, %e[wr2][0]\n" /*0246*wr2[0]*/ \ + "vmla.f32 q14, q9, %e[wr2][1]\n" /*1357*wr2[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr3]]!\n" \ + "vmla.f32 q15, q11, %f[wr2][0]\n" /*2468*wr2[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr3]]\n" /*810911*/\ + "vmla.f32 q14, q13, %f[wr5][0]\n"/*46810*wr5[2]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q15, q12, %f[wr2][1]\n" /*3579*wr2[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vmla.f32 q14, q8, %e[wr3][0]\n" /*0246*wr3[0]*/ \ + "vmla.f32 q15, q9, %e[wr3][1]\n" /*1357*wr3[1]*/ \ + "vext.f32 d24, d18, d19, #1\n" /*13-35*/ \ + "vld2.f32 {d16-d19}, [%[din_ptr0]]!\n" \ + "vmla.f32 q14, q11, %f[wr3][0]\n" /*2468*wr3[2]*/\ + "vld2.f32 {d20-d21}, [%[din_ptr0]]\n" /*810911*/\ + "vmla.f32 q15, q13, %f[wr5][1]\n"/*46810*wr5[3]*/\ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vmla.f32 q14, q12, %f[wr3][1]\n" /*3579*wr3[3]*/\ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vadd.f32 q14, q14, q15\n" +#define RESULT_S2 \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + 
"vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \ + "bne 1b" +#define RESULT_S2_RELU \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \ + "bne 1b" +#define RESULT_S2_RELU6 \ + "vld1.f32 {d26-d27}, [%[six_ptr]]\n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vmin.f32 q14, q14, q13\n" \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \ + "bne 1b" +#define RESULT_S2_LEAKY_RELU \ + "vld1.f32 {d26-d27}, [%[scale_ptr]]\n" \ + "vcge.f32 q11, q14, %q[vzero]\n" \ + "vmul.f32 q12, q14, q13\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vbif q14, q12, q11\n" \ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vst1.f32 {d28-d29}, [%[dout_ptr]]!\n" \ + "bne 1b" +#define RESULT_S2_OUT2 \ + "vst1.f32 {d30-d31}, [%[dout_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \ + "bne 1b" +#define RESULT_S2_RELU_OUT2 \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vst1.f32 {d30-d31}, [%[dout_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \ + "bne 1b" +#define RESULT_S2_RELU6_OUT2 \ + "vld1.f32 {d26-d27}, [%[six_ptr]]\n" \ + "vmax.f32 q15, q15, %q[vzero]\n" \ + "vmax.f32 q14, q14, %q[vzero]\n" \ + "vmin.f32 q15, q15, q13\n" \ + "vmin.f32 q14, q14, q13\n" \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vst1.f32 {d30-d31}, [%[dout_ptr0]]!\n" \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \ + "bne 1b" +#define RESULT_S2_LEAKY_RELU_OUT2 \ + "vld1.f32 {d26-d27}, [%[scale_ptr]]\n" \ + "vcge.f32 q11, q15, %q[vzero]\n" \ + "vmul.f32 q12, q15, q13\n" \ + "vbif q15, q12, q11\n" \ + "vcge.f32 q11, q14, %q[vzero]\n" \ + "vmul.f32 q12, q14, q13\n" \ + "vext.32 q13, q8, q10, #2\n" /*46810*/ \ + "vst1.f32 {d30-d31}, [%[dout_ptr0]]!\n" \ + "vbif q14, q12, q11\n" \ + "vext.32 q11, q8, q10, #1\n" /*2468*/ \ + "vext.32 d25, d19, d21, #1\n" /*57-79*/ \ + "vld1.f32 {d30-d31}, [%[bias]]\n" \ + "vst1.f32 {d28-d29}, [%[dout_ptr1]]!\n" \ + "bne 1b" #endif +// clang-format on +inline float compute_one_data_pre( + const float* data, float32x4_t wr, float bias_val, float wei_val, int num) { + float sum = bias_val; + int index = 4 - num; + for (int i = 0; i < num; i++) { + sum += data[i] * wr[index + i]; + } + sum += data[num] * wei_val; + return sum; +} -void act_switch_5x5s2(const float* inr0, - const float* inr1, - const float* inr2, - const float* inr3, - const float* inr4, - float* outc0, - float* outc1, - float* outc2, - float* outc3, - float32x4_t w0, - float32x4_t w1, - float32x4_t w2, - float32x4_t w3, - float32x4_t w4, - float32x4_t vbias, - const float* weight_c, - float* bias_local, - const operators::ActivationParam act_param) { - bool has_active = act_param.has_active; - if (has_active) { - float tmp = act_param.Relu_clipped_coef; - float ss = act_param.Leaky_relu_alpha; +inline float compute_one_data_post( + const float* data, float32x4_t wr, float bias_val, float wei_val, int num) { + float sum = bias_val; + for (int i = 0; i < num; i++) { + sum += data[i] * wr[i]; + } + sum += data[num] * wei_val; + return sum; +} + +inline void compute_all_padding_pre(float* dout, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + bool odds, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + int tmp_index = num - 1; + for (int i = 
pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[5][3 - k], + 4 - i); + } + *dout++ = sum; + } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } + // mid + // clang-format off + if (cnt > 0) { + switch (num) { + case 0: #ifdef __aarch64__ - float32x4_t vsix = vdupq_n_f32(tmp); - float32x4_t vscale = vdupq_n_f32(ss); + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); #else - float vsix[4] = {tmp, tmp, tmp, tmp}; - float vscale[4] = {ss, ss, ss, ss}; + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif - switch (act_param.active_type) { - case lite_api::ActivationType::kRelu: + break; + case 1: #ifdef __aarch64__ - asm volatile(COMPUTE RELU STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [inr4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [vbias] "w"(vbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - asm volatile(COMPUTE RELU STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [r4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [bias] "r"(bias_local), [six_ptr] "r"(vsix) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; - case lite_api::ActivationType::kRelu6: + case 2: #ifdef __aarch64__ - asm volatile(COMPUTE RELU RELU6 STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [inr4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [w0] 
"w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [vbias] "w"(vbias), - [vsix] "w"(vsix) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - asm volatile(COMPUTE RELU RELU6 STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [r4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [bias] "r"(bias_local), [six_ptr] "r"(vsix) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; - case lite_api::ActivationType::kLeakyRelu: + case 3: #ifdef __aarch64__ - asm volatile(COMPUTE LEAKY_RELU STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [inr4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [vbias] "w"(vbias), - [vscale] "w"(vscale) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - asm volatile(COMPUTE LEAKY_RELU STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [r4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] 
"w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); #endif break; default: - LOG(FATAL) << "this act_type: " - << static_cast(act_param.active_type) - << " fuse not support"; + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; } - } else { + din_ptr_arr[0] -= 8; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); + din_ptr_arr[num] += 2; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - i], + weights[3 - i], + 0.f, + weights[5][3 - i], + 4); + din_ptr_arr[tmp_index - i] += 2; + } + *dout++ = sum; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num] += 2; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[3 - k][3 - i], + 3 - i); + din_ptr_arr[tmp_index - k] += 2; + } + *dout++ = sum; + } +} +inline void compute_all_padding_mid(float* dout, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + bool odds, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum; + } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } + // clang-format off + // mid + if (cnt > 0) { #ifdef __aarch64__ - asm volatile(COMPUTE STORE - : [inr0] "+r"(inr0), - [inr1] "+r"(inr1), - [inr2] "+r"(inr2), - [inr3] "+r"(inr3), - [inr4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [w0] "w"(w0), - [w1] "w"(w1), - [w2] "w"(w2), - [w3] "w"(w3), - [w4] "w"(w4), - [vbias] "w"(vbias) + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) : "cc", "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", "v9", "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", 
+ "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 8; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[num] += 2; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i] += 2; + } + *dout++ = sum; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num] += 2; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[tmp - k] += 2; + } + *dout++ = sum; + } +} +inline void compute_all_padding_mid_out2(float* dout0, + float* dout1, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + bool odds, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + int tmp1 = num + 1; + int tmp = num - 1; + // left + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + float sum1 = compute_one_data_pre( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + sum1 += compute_one_data_pre(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout0++ = sum; + *dout1++ = sum1; + } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[num - k]++; + } + din_ptr_arr[0]++; + } + // clang-format off + // mid + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [din_ptr6] "+r"(din_ptr_arr[6]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", "v14", "v15", "v16", "v17", - "v18", - "v19", - "v20", - "v21", - "v22"); -#else - asm volatile(COMPUTE STORE - : [r0] "+r"(inr0), - [r1] "+r"(inr1), - [r2] "+r"(inr2), - [r3] "+r"(inr3), - [r4] "+r"(inr4), - [wc0] "+r"(weight_c), - [outc0] "+r"(outc0), - [outc1] "+r"(outc1), - [outc2] "+r"(outc2), - [outc3] "+r"(outc3) - : [bias] "r"(bias_local) + "v18"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [din_ptr6] "+r"(din_ptr_arr[6]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [bias] "r"(bias) : "cc", "memory", - "q0", - "q1", - "q2", - "q3", - 
"q4", - "q5", - "q6", - "q7", "q8", "q9", "q10", @@ -749,197 +1779,3404 @@ void act_switch_5x5s2(const float* inr0, "q14", "q15"); #endif + din_ptr_arr[0] -= 8; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[tmp1] += 2; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + sum1 += compute_one_data_post( + din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[num - i] += 2; + } + din_ptr_arr[0] += 2; + *dout0++ = sum; + *dout1++ = sum1; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1] += 2; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + sum1 += compute_one_data_post(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[num - k] += 2; + } + din_ptr_arr[0] += 2; + *dout0++ = sum; + *dout1++ = sum1; } } -void conv_depthwise_5x5s2_fp32(const float* i_data, - float* o_data, - int bs, - int oc, - int oh, - int ow, - int ic, - int ih, - int win, + +inline void compute_all_padding_post(float* dout, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + bool odds, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum; + } + if (odds) { // origin pad_left is odds, such as ori_pad_left=1 + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + din_ptr_arr[tmp_index - k]++; + } + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[3] -= 4; + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + 
[din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[2] -= 4; + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[1] -= 4; + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 8; + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[3] += 2; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[2 - i] += 2; + } + *dout++ = sum; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[3] += 2; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[2 - k] += 2; + } + *dout++ = sum; + } +} + +void conv_depthwise_5x5s2_bias(float* dout, + const float* din, const float* weights, const float* bias, - const operators::ConvParam& param, - const operators::ActivationParam act_param, + bool flag_bias, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int pad_top, + int pad_bottom, + int pad_left, + int pad_right, ARMContext* ctx) { - auto paddings = *param.paddings; - int threads = ctx->threads(); - const int pad_h = paddings[0]; - const int pad_w 
= paddings[2]; - const int out_c_block = 4; - const int out_h_kernel = 1; - const int out_w_kernel = 4; - const int win_ext = ow * 2 + 3; - const int ow_round = ROUNDUP(ow, 4); - const int win_round = ROUNDUP(win_ext, 4); - const int hin_round = oh * 2 + 3; - const int prein_size = win_round * hin_round * out_c_block; - auto workspace_size = threads * prein_size + win_round + ow_round; - ctx->ExtendWorkspace(sizeof(float) * workspace_size); + int in_size = win * hin; + int out_size = wout * hout; + int in_channel_size = chin * in_size; + int out_channel_size = chin * out_size; + int pad_left_new = (pad_left + 1) / 2; + int pad_right_new = pad_right / 2; + int pad_top_new = (pad_top + 1) / 2; + int pad_bottom_new = pad_bottom / 2; + int weights_size = 25; + int num_out = wout << 1; + int loop_w = wout - pad_left_new - pad_left_new; + int loop_h = hout - pad_top_new - pad_bottom_new; + bool odds_w = pad_left % 2; + bool odds_h = pad_top % 2; + int cnt = loop_w >> 2; + int remain = loop_w & 3; + for (int n = 0; n < num; n++) { + const float* din_batch = din + n * in_channel_size; + float* dout_batch = dout + n * out_channel_size; +#pragma omp parallel for + for (int c = 0; c < chin; c++) { + const float* din_ch = din_batch + c * in_size; + const float* weights_ch = weights + c * weights_size; + float* dout_ch = dout_batch + c * out_size; + float bias_val = flag_bias ? bias[c] : 0.f; + const float* din_ptr0 = din_ch; + const float* din_ptr1 = din_ptr0 + win; + const float* din_ptr2 = din_ptr1 + win; + const float* din_ptr3 = din_ptr2 + win; + const float* din_ptr4 = din_ptr3 + win; + const float* din_ptr5 = din_ptr4 + win; + const float* din_ptr6 = din_ptr5 + win; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + float* dout_ptr0 = dout_ch; + float* dout_ptr1 = dout_ch; + float32x4_t wr5; + float32x4_t wr6; + float32x4_t wr0 = vld1q_f32(weights_ch); + float32x4_t wr1 = vld1q_f32(weights_ch + 5); + float32x4_t wr2 = vld1q_f32(weights_ch + 10); + float32x4_t wr3 = vld1q_f32(weights_ch + 15); + float32x4_t wr4 = vld1q_f32(weights_ch + 20); + wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0); + wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1); + wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); + wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); + wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); + const float* din_ptr_arr[] = { + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5, din_ptr6}; + float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; + // top_h + for (int h = pad_top_new; h > 0; h--) { + compute_all_padding_pre(dout_ptr0, + din_ptr_arr, + vbias, + weights_vec, + odds_w, + pad_left, + pad_right, + cnt, + remain, + 4 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + if (odds_h) { + din_ptr_arr[0] = din_ptr1; + din_ptr_arr[1] = din_ptr2; + din_ptr_arr[2] = din_ptr3; + din_ptr_arr[3] = din_ptr4; + din_ptr_arr[4] = din_ptr5; + din_ptr_arr[5] = din_ptr6; + din_ptr_arr[6] = din_ptr6 + win; + } + dout_ptr1 = dout_ptr0 + wout; + // mid_h + for (int h = 0; h < loop_h - 1; h += 2) { + compute_all_padding_mid_out2(dout_ptr0, + dout_ptr1, + din_ptr_arr, + vbias, + weights_vec, + odds_w, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 += num_out; + dout_ptr1 += num_out; + din_ptr0 = din_ptr4; + din_ptr1 = din_ptr5; + din_ptr2 = din_ptr6; + din_ptr3 = din_ptr6 + win; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr4 = din_ptr3 + win; 
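+        // Two output rows are written per pass of this loop (dout_ptr0 and
+        // dout_ptr1), so the seven input row pointers din_ptr0..din_ptr6 are
+        // rotated forward by four source rows (2 output rows * stride 2)
+        // before the next pass.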
+ din_ptr_arr[2] = din_ptr2; + din_ptr5 = din_ptr4 + win; + din_ptr_arr[3] = din_ptr3; + din_ptr6 = din_ptr5 + win; + din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; + din_ptr_arr[6] = din_ptr6; + } + if (loop_h % 2 != 0) { + compute_all_padding_mid(dout_ptr0, + din_ptr_arr, + vbias, + weights_vec, + odds_w, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 = dout_ptr1; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr6; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + // bottom + for (int h = 0; h < pad_bottom; h++) { + compute_all_padding_post(dout_ptr0, + din_ptr_arr, + vbias, + weights_vec, + odds_w, + pad_left, + pad_right, + cnt, + remain, + 3 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + } + } +} - bool flag_bias = param.bias != nullptr; +inline void compute_all_padding_pre_relu(float* dout, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + int tmp_index = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[5][3 - k], + 4 - i); + } + *dout++ = sum > 0.f ? sum : 0.f; + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] 
"w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); + din_ptr_arr[num]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - i], + weights[3 - i], + 0.f, + weights[5][3 - i], + 4); + din_ptr_arr[tmp_index - i]++; + } + *dout++ = sum > 0.f ? sum : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[3 - k][3 - i], + 3 - i); + din_ptr_arr[tmp_index - k]++; + } + *dout++ = sum > 0.f ? sum : 0.f; + } +} +inline void compute_all_padding_mid_relu(float* dout, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum > 0.f ? 
sum : 0.f; + } + // clang-format off + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[num]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i]++; + } + *dout++ = sum > 0.f ? sum : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[tmp - k]++; + } + *dout++ = sum > 0.f ? sum : 0.f; + } +} +inline void compute_all_padding_mid_relu_out2(float* dout0, + float* dout1, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + // left + int tmp = num - 1; + int tmp1 = num + 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + float sum1 = compute_one_data_pre( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + sum1 += compute_one_data_pre(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout0++ = sum > 0.f ? sum : 0.f; + *dout1++ = sum1 > 0.f ? 
sum1 : 0.f; + } + // clang-format off + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[tmp1]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + sum1 += compute_one_data_post( + din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[num - i]++; + } + din_ptr_arr[0]++; + *dout0++ = sum > 0.f ? sum : 0.f; + *dout1++ = sum1 > 0.f ? sum1 : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + sum1 += compute_one_data_post(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[num - k]++; + } + din_ptr_arr[0]++; + *dout0++ = sum > 0.f ? sum : 0.f; + *dout1++ = sum1 > 0.f ? sum1 : 0.f; + } +} +inline void compute_all_padding_post_relu(float* dout, + const float** din_ptr_arr, + const float* bias, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum > 0.f ? 
sum : 0.f; + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[3] -= 4; + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[2] -= 4; + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[1] -= 4; + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [bias] "r"(bias) + 
: "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[3]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[2 - i]++; + } + *dout++ = sum > 0.f ? sum : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[3]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[2 - k]++; + } + *dout++ = sum > 0.f ? sum : 0.f; + } +} - /// get workspace - auto ptr_zero = ctx->workspace_data(); - memset(ptr_zero, 0, sizeof(float) * win_round); - float* ptr_write = ptr_zero + win_round; +void conv_depthwise_5x5s2_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int pad_top, + int pad_bottom, + int pad_left, + int pad_right, + ARMContext* ctx) { + int loop_w = wout - pad_left - pad_right; + int loop_h = hout - pad_top - pad_bottom; + int in_size = win * hin; + int out_size = wout * hout; + int cnt = loop_w >> 2; + int remain = loop_w & 3; + int in_channel_size = chin * in_size; + int out_channel_size = chin * out_size; + int weights_size = 25; + int num_out = wout << 1; + float32x4_t vzero = vdupq_n_f32(0.f); + for (int n = 0; n < num; n++) { + const float* din_batch = din + n * in_channel_size; + float* dout_batch = dout + n * out_channel_size; +#pragma omp parallel for + for (int c = 0; c < chin; c++) { + const float* din_ch = din_batch + c * in_size; + const float* weights_ch = weights + c * weights_size; + float* dout_ch = dout_batch + c * out_size; + float bias_val = flag_bias ? 
bias[c] : 0.f; + const float* din_ptr0 = din_ch; + const float* din_ptr1 = din_ptr0 + win; + const float* din_ptr2 = din_ptr1 + win; + const float* din_ptr3 = din_ptr2 + win; + const float* din_ptr4 = din_ptr3 + win; + const float* din_ptr5 = din_ptr4 + win; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + float* dout_ptr0 = dout_ch; + float* dout_ptr1 = dout_ch; + float32x4_t wr5; + float32x4_t wr6; + float32x4_t wr0 = vld1q_f32(weights_ch); + float32x4_t wr1 = vld1q_f32(weights_ch + 5); + float32x4_t wr2 = vld1q_f32(weights_ch + 10); + float32x4_t wr3 = vld1q_f32(weights_ch + 15); + float32x4_t wr4 = vld1q_f32(weights_ch + 20); + wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0); + wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1); + wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); + wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); + wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); + const float* din_ptr_arr[] = { + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; + float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; + // top_h + for (int h = pad_top; h > 0; h--) { + compute_all_padding_pre_relu(dout_ptr0, + din_ptr_arr, + vbias, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + dout_ptr1 = dout_ptr0 + wout; + // mid_h + for (int h = 0; h < loop_h - 1; h += 2) { + compute_all_padding_mid_relu_out2(dout_ptr0, + dout_ptr1, + din_ptr_arr, + vbias, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 += num_out; + dout_ptr1 += num_out; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr5 + win; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr5 = din_ptr4 + win; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; + } + if (loop_h % 2 != 0) { + compute_all_padding_mid_relu(dout_ptr0, + din_ptr_arr, + vbias, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 = dout_ptr1; + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr5; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + // bottom + for (int h = 0; h < pad_bottom; h++) { + compute_all_padding_post_relu(dout_ptr0, + din_ptr_arr, + vbias, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 3 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + } + } +} - int size_in_channel = win * ih; - int size_out_channel = ow * oh; +inline void compute_all_padding_pre_relu6(float* dout, + const float** din_ptr_arr, + const float* bias, + const float* six, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + int tmp_index = num - 1; + // left + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += 
compute_one_data_pre(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[5][3 - k], + 4 - i); + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_PRE 
RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); + din_ptr_arr[num]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - i], + weights[3 - i], + 0.f, + weights[5][3 - i], + 4); + din_ptr_arr[tmp_index - i]++; + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[3 - k][3 - i], + 3 - i); + din_ptr_arr[tmp_index - k]++; + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } +} +inline void compute_all_padding_mid_relu6(float* dout, + const float** din_ptr_arr, + const float* bias, + const float* six, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum > 0.f ? (sum < six[0] ? 
sum : six[0]) : 0.f; + } + // clang-format off + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[num]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i]++; + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[tmp - k]++; + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } +} - int ws = -pad_w; - int we = ws + win_round; - int hs = -pad_h; - int he = hs + hin_round; - int w_loop = ow_round / 4; - auto remain = w_loop * 4 - ow; - bool flag_remain = remain > 0; - remain = 4 - remain; - remain = remain > 0 ? remain : 0; - int row_len = win_round * out_c_block; +inline void compute_all_padding_mid_relu6_out2(float* dout0, + float* dout1, + const float** din_ptr_arr, + const float* bias, + const float* six, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + // left + int tmp = num - 1; + int tmp1 = num + 1; + // clang-format off + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + float sum1 = compute_one_data_pre( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + sum1 += compute_one_data_pre(din_ptr_arr[num -k], + weights[tmp -k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? 
sum1 : six[0]) : 0.f; + } + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU6_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_RELU6_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[tmp1]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + sum1 += compute_one_data_post( + din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[num - i]++; + } + din_ptr_arr[0]++; + *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? sum1 : six[0]) : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + sum1 += compute_one_data_post(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[num - k]++; + } + din_ptr_arr[0]++; + *dout0++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + *dout1++ = sum1 > 0.f ? (sum1 < six[0] ? 
sum1 : six[0]) : 0.f; + } +} +inline void compute_all_padding_post_relu6(float* dout, + const float** din_ptr_arr, + const float* bias, + const float* six, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[3] -= 4; + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[2] -= 4; + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[1] -= 4; + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2_RELU6 + 
: [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vsix] "w"(vsix), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2_RELU6 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [six_ptr] "r"(six), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[3]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[2 - i]++; + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[3]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[2 - k]++; + } + *dout++ = sum > 0.f ? (sum < six[0] ? sum : six[0]) : 0.f; + } +} +void conv_depthwise_5x5s2_bias_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int pad_top, + int pad_bottom, + int pad_left, + int pad_right, + ARMContext* ctx) { + int loop_w = wout - pad_left - pad_right; + int loop_h = hout - pad_top - pad_bottom; + int in_size = win * hin; + int out_size = wout * hout; + int cnt = loop_w >> 2; + int remain = loop_w & 3; + int in_channel_size = chin * in_size; + int out_channel_size = chin * out_size; + int weights_size = 25; + int num_out = wout << 1; float32x4_t vzero = vdupq_n_f32(0.f); + for (int n = 0; n < num; n++) { + const float* din_batch = din + n * in_channel_size; + float* dout_batch = dout + n * out_channel_size; +#pragma omp parallel for + for (int c = 0; c < chin; c++) { + const float* din_ch = din_batch + c * in_size; + const float* weights_ch = weights + c * weights_size; + float* dout_ch = dout_batch + c * out_size; + float bias_val = flag_bias ? 
bias[c] : 0.f; + const float* din_ptr0 = din_ch; + const float* din_ptr1 = din_ptr0 + win; + const float* din_ptr2 = din_ptr1 + win; + const float* din_ptr3 = din_ptr2 + win; + const float* din_ptr4 = din_ptr3 + win; + const float* din_ptr5 = din_ptr4 + win; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + float* dout_ptr0 = dout_ch; + float* dout_ptr1 = dout_ch; + float32x4_t wr5; + float32x4_t wr6; + float32x4_t wr0 = vld1q_f32(weights_ch); + float32x4_t wr1 = vld1q_f32(weights_ch + 5); + float32x4_t wr2 = vld1q_f32(weights_ch + 10); + float32x4_t wr3 = vld1q_f32(weights_ch + 15); + float32x4_t wr4 = vld1q_f32(weights_ch + 20); + wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0); + wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1); + wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); + wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); + wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); + const float* din_ptr_arr[] = { + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; + float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; + // top_h + for (int h = pad_top; h > 0; h--) { + compute_all_padding_pre_relu6(dout_ptr0, + din_ptr_arr, + vbias, + six, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + dout_ptr1 = dout_ptr0 + wout; + // mid_h + for (int h = 0; h < loop_h - 1; h += 2) { + compute_all_padding_mid_relu6_out2(dout_ptr0, + dout_ptr1, + din_ptr_arr, + vbias, + six, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 += num_out; + dout_ptr1 += num_out; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr5 + win; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr5 = din_ptr4 + win; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; + } + if (loop_h % 2 != 0) { + compute_all_padding_mid_relu6(dout_ptr0, + din_ptr_arr, + vbias, + six, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 = dout_ptr1; + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr5; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + // bottom + for (int h = 0; h < pad_bottom; h++) { + compute_all_padding_post_relu6(dout_ptr0, + din_ptr_arr, + vbias, + six, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 3 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + } + } +} + +inline void compute_all_padding_pre_leakyRelu(float* dout, + const float** din_ptr_arr, + const float* bias, + const float* scale, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vscale = vld1q_f32(scale); +#endif + int tmp_index = num - 1; + // left + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += 
compute_one_data_pre(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[5][3 - k], + 4 - i); + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[4]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[3]), + [wr1] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[2]), + [wr1] "w"(weights[3]), + [wr2] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", 
+ "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_PRE RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[1]), + [wr1] "w"(weights[2]), + [wr2] "w"(weights[3]), + [wr3] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[6][0], 4); + din_ptr_arr[num]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - i], + weights[3 - i], + 0.f, + weights[5][3 - i], + 4); + din_ptr_arr[tmp_index - i]++; + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[4], bias[0], weights[4][3 - i], 3 - i); + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp_index - k], + weights[3 - k], + 0.f, + weights[3 - k][3 - i], + 3 - i); + din_ptr_arr[tmp_index - k]++; + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } + for (int w = pad_right; w > 4; w--) { + *dout++ = bias[0] > 0.f ? bias[0] : bias[0] * scale[0]; + } - for (int n = 0; n < bs; ++n) { - const float* din_batch = i_data + n * ic * size_in_channel; - float* dout_batch = o_data + n * oc * size_out_channel; -#pragma omp parallel for num_threads(threads) - for (int c = 0; c < oc; c += out_c_block) { -#ifdef ARM_WITH_OMP - float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; -#else - float* pre_din = ptr_write + ow_round; -#endif - /// const array size - prepack_input_nxwc4_dw( - din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); - const float* weight_c = weights + c * 25; // kernel_w * kernel_h - float* dout_c00 = dout_batch + c * size_out_channel; - float bias_local[4] = {0, 0, 0, 0}; +} +inline void compute_all_padding_mid_leakyRelu(float* dout, + const float** din_ptr_arr, + const float* bias, + const float* scale, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vscale = vld1q_f32(scale); +#endif + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum > 0.f ? 
sum : sum * scale[0]; + } + // clang-format off + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2 RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[num]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[tmp - i]++; + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[num]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[tmp - k]++; + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } +} +inline void compute_all_padding_mid_leakyRelu_out2(float* dout0, + float* dout1, + const float** din_ptr_arr, + const float* bias, + const float* scale, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vscale = vld1q_f32(scale); +#endif + // left + int tmp = num - 1; + int tmp1 = num + 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4 - i); + float sum1 = compute_one_data_pre( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + sum1 += compute_one_data_pre(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout0++ = sum > 0.f ? sum : sum * scale[0]; + *dout1++ = sum1 > 0.f ? 
sum1 : sum1 * scale[0]; + } + // clang-format off + if (cnt > 0) { +#ifdef __aarch64__ + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_LEAKY_RELU_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20"); +#else + asm volatile(COMPUTE_FIVE_LINE_S2_OUT2 RESULT_S2_LEAKY_RELU_OUT2 + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [din_ptr4] "+r"(din_ptr_arr[4]), + [din_ptr5] "+r"(din_ptr_arr[5]), + [dout_ptr0] "+r"(dout0), + [dout_ptr1] "+r"(dout1) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr4] "w"(weights[4]), + [wr5] "w"(weights[5]), + [wr6] "w"(weights[6]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[6][0], 4); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[6][0], 4); + din_ptr_arr[tmp1]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[tmp - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + sum1 += compute_one_data_post( + din_ptr_arr[num - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[num - i]++; + } + din_ptr_arr[0]++; + *dout0++ = sum > 0.f ? sum : sum * scale[0]; + *dout1++ = sum1 > 0.f ? sum1 : sum1 * scale[0]; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[num], weights[num], bias[0], weights[num][3 - i], 3 - i); + float sum1 = compute_one_data_post( + din_ptr_arr[tmp1], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[tmp1]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[tmp - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + sum1 += compute_one_data_post(din_ptr_arr[num - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[num - k]++; + } + din_ptr_arr[0]++; + *dout0++ = sum > 0.f ? sum : sum * scale[0]; + *dout1++ = sum1 > 0.f ? 
sum1 : sum1 * scale[0]; + } +} +inline void compute_all_padding_post_leakyRelu(float* dout, + const float** din_ptr_arr, + const float* bias, + const float* scale, + float32x4_t* weights, + float32x4_t vzero, + int win, + int wout, + int pad_left, + int pad_right, + int cnt, + int remain, + int num) { +#ifdef __aarch64__ + float32x4_t vscale = vld1q_f32(scale); +#endif + // left + int tmp = num - 1; + for (int i = pad_left; i > 0; i--) { + float sum = compute_one_data_pre( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4 - i); + for (int k = 0; k < num; k++) { + sum += compute_one_data_pre(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[5][tmp - k], + 4 - i); + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } + // clang-format off + // mid + if (cnt > 0) { + switch (num) { + case 0: +#ifdef __aarch64__ + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_ONE_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[3] -= 4; + break; + case 1: +#ifdef __aarch64__ + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_TWO_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[2]), + [din_ptr1] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[2] -= 4; + break; + case 2: +#ifdef __aarch64__ + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_THREE_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[1]), + [din_ptr1] "+r"(din_ptr_arr[2]), + [din_ptr2] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + 
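+          // Note: the vector loop above appears to leave the first input-row
+          // pointer one 4-float block past what the scalar tail expects, so it
+          // is rewound here before the remain/right-padding code runs (the same
+          // adjustment follows every vector loop in this file).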
din_ptr_arr[1] -= 4; + break; + case 3: +#ifdef __aarch64__ + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [vscale] "w"(vscale), + [bias] "r"(bias) + : "cc", + "memory", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18"); +#else + asm volatile(COMPUTE_FOUR_LINE_S2_POST RESULT_S2_LEAKY_RELU + : [cnt] "+r"(cnt), + [din_ptr0] "+r"(din_ptr_arr[0]), + [din_ptr1] "+r"(din_ptr_arr[1]), + [din_ptr2] "+r"(din_ptr_arr[2]), + [din_ptr3] "+r"(din_ptr_arr[3]), + [dout_ptr] "+r"(dout) + : [wr0] "w"(weights[0]), + [wr1] "w"(weights[1]), + [wr2] "w"(weights[2]), + [wr3] "w"(weights[3]), + [wr5] "w"(weights[5]), + [vzero] "w"(vzero), + [scale_ptr] "r"(scale), + [bias] "r"(bias) + : "cc", + "memory", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + din_ptr_arr[0] -= 4; + break; + default: + LOG(FATAL) << "This num: " << (num + 1) << "does not support"; + } + } + // clang-format on + // remain + for (int w = 0; w < remain; w++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[5][num], 4); + din_ptr_arr[3]++; + for (int i = 0; i < num; i++) { + sum += compute_one_data_post( + din_ptr_arr[2 - i], weights[tmp - i], 0.f, weights[5][tmp - i], 4); + din_ptr_arr[2 - i]++; + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } + // right + for (int i = 0; i < pad_right; i++) { + float sum = compute_one_data_post( + din_ptr_arr[3], weights[num], bias[0], weights[num][3 - i], 3 - i); + din_ptr_arr[3]++; + for (int k = 0; k < num; k++) { + sum += compute_one_data_post(din_ptr_arr[2 - k], + weights[tmp - k], + 0.f, + weights[tmp - k][3 - i], + 3 - i); + din_ptr_arr[2 - k]++; + } + *dout++ = sum > 0.f ? sum : sum * scale[0]; + } +} - if (flag_bias) { - bias_local[0] = bias[c]; - bias_local[1] = bias[c + 1]; - bias_local[2] = bias[c + 2]; - bias_local[3] = bias[c + 3]; +void conv_depthwise_5x5s2_bias_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + int num, + int chin, + int hin, + int win, + int hout, + int wout, + int pad_top, + int pad_bottom, + int pad_left, + int pad_right, + ARMContext* ctx) { + int loop_w = wout - pad_left - pad_right; + int loop_h = hout - pad_top - pad_bottom; + int in_size = win * hin; + int out_size = wout * hout; + int cnt = loop_w >> 2; + int remain = loop_w & 3; + int in_channel_size = chin * in_size; + int out_channel_size = chin * out_size; + int weights_size = 25; + int num_out = wout << 1; + float32x4_t vzero = vdupq_n_f32(0.f); + for (int n = 0; n < num; n++) { + const float* din_batch = din + n * in_channel_size; + float* dout_batch = dout + n * out_channel_size; +#pragma omp parallel for + for (int c = 0; c < chin; c++) { + const float* din_ch = din_batch + c * in_size; + const float* weights_ch = weights + c * weights_size; + float* dout_ch = dout_batch + c * out_size; + float bias_val = flag_bias ? 
bias[c] : 0.f; + const float* din_ptr0 = din_ch; + const float* din_ptr1 = din_ptr0 + win; + const float* din_ptr2 = din_ptr1 + win; + const float* din_ptr3 = din_ptr2 + win; + const float* din_ptr4 = din_ptr3 + win; + const float* din_ptr5 = din_ptr4 + win; + float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; + float* dout_ptr0 = dout_ch; + float* dout_ptr1 = dout_ch; + float32x4_t wr5; + float32x4_t wr6; + float32x4_t wr0 = vld1q_f32(weights_ch); + float32x4_t wr1 = vld1q_f32(weights_ch + 5); + float32x4_t wr2 = vld1q_f32(weights_ch + 10); + float32x4_t wr3 = vld1q_f32(weights_ch + 15); + float32x4_t wr4 = vld1q_f32(weights_ch + 20); + wr5 = vsetq_lane_f32(weights_ch[4], wr5, 0); + wr5 = vsetq_lane_f32(weights_ch[9], wr5, 1); + wr5 = vsetq_lane_f32(weights_ch[14], wr5, 2); + wr5 = vsetq_lane_f32(weights_ch[19], wr5, 3); + wr6 = vsetq_lane_f32(weights_ch[24], wr6, 0); + const float* din_ptr_arr[] = { + din_ptr0, din_ptr1, din_ptr2, din_ptr3, din_ptr4, din_ptr5}; + float32x4_t weights_vec[] = {wr0, wr1, wr2, wr3, wr4, wr5, wr6}; + // top_h + for (int h = pad_top; h > 0; h--) { + compute_all_padding_pre_leakyRelu(dout_ptr0, + din_ptr_arr, + vbias, + scale, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; } -#ifdef __aarch64__ - float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 - float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 - float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 - float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 - float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 - float32x4_t vbias = vdupq_n_f32(0.f); - if (flag_bias) { - vbias = vld1q_f32(&bias[c]); // v28 + dout_ptr1 = dout_ptr0 + wout; + // mid_h + for (int h = 0; h < loop_h - 1; h += 2) { + compute_all_padding_mid_leakyRelu_out2(dout_ptr0, + dout_ptr1, + din_ptr_arr, + vbias, + scale, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 += num_out; + dout_ptr1 += num_out; + din_ptr0 = din_ptr2; + din_ptr1 = din_ptr3; + din_ptr2 = din_ptr4; + din_ptr3 = din_ptr5; + din_ptr4 = din_ptr5 + win; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr5 = din_ptr4 + win; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + din_ptr_arr[5] = din_ptr5; } - weight_c += 20; -#endif - for (int h = 0; h < oh; h += out_h_kernel) { - float* outc0 = dout_c00 + h * ow; - float* outc1 = outc0 + size_out_channel; - float* outc2 = outc1 + size_out_channel; - float* outc3 = outc2 + size_out_channel; - const float* inr0 = pre_din + h * 2 * row_len; - const float* inr1 = inr0 + row_len; - const float* inr2 = inr1 + row_len; - const float* inr3 = inr2 + row_len; - const float* inr4 = inr3 + row_len; - - if (c + out_c_block > oc) { - switch (c + out_c_block - oc) { - case 3: - outc1 = ptr_write; - case 2: - outc2 = ptr_write; - case 1: - outc3 = ptr_write; - default: - break; - } - } - auto c0 = outc0; - auto c1 = outc1; - auto c2 = outc2; - auto c3 = outc3; - float pre_out[16]; - for (int w = 0; w < w_loop; ++w) { - bool flag_mask = (w == w_loop - 1) && flag_remain; - if (flag_mask) { - c0 = outc0; - c1 = outc1; - c2 = outc2; - c3 = outc3; - outc0 = pre_out; - outc1 = pre_out + 4; - outc2 = pre_out + 8; - outc3 = pre_out + 12; - } -#ifdef __aarch64__ - act_switch_5x5s2(inr0, - inr1, - inr2, - inr3, - inr4, - outc0, - outc1, - 
outc2, - outc3, - w0, - w1, - w2, - w3, - w4, - vbias, - weight_c, - bias_local, - act_param); -#else - act_switch_5x5s2(inr0, - inr1, - inr2, - inr3, - inr4, - outc0, - outc1, - outc2, - outc3, - vzero, - vzero, - vzero, - vzero, - vzero, - vzero, - weight_c, - bias_local, - act_param); -#endif - if (flag_mask) { - for (int i = 0; i < remain; ++i) { - c0[i] = pre_out[i]; - c1[i] = pre_out[i + 4]; - c2[i] = pre_out[i + 8]; - c3[i] = pre_out[i + 12]; - } - } - inr0 += 32; - inr1 += 32; - inr2 += 32; - inr3 += 32; - inr4 += 32; - outc0 += 4; - outc1 += 4; - outc2 += 4; - outc3 += 4; - } + if (loop_h % 2 != 0) { + compute_all_padding_mid_leakyRelu(dout_ptr0, + din_ptr_arr, + vbias, + scale, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 4); + dout_ptr0 = dout_ptr1; + din_ptr0 = din_ptr1; + din_ptr1 = din_ptr2; + din_ptr2 = din_ptr3; + din_ptr3 = din_ptr4; + din_ptr4 = din_ptr5; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; + } + // bottom + for (int h = 0; h < pad_bottom; h++) { + compute_all_padding_post_leakyRelu(dout_ptr0, + din_ptr_arr, + vbias, + scale, + weights_vec, + vzero, + win, + wout, + pad_left, + pad_right, + cnt, + remain, + 3 - h); + dout_ptr0 += wout; + din_ptr_arr[0] = din_ptr0; + din_ptr_arr[1] = din_ptr1; + din_ptr_arr[2] = din_ptr2; + din_ptr_arr[3] = din_ptr3; + din_ptr_arr[4] = din_ptr4; } } } } - } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv5x5s2_depthwise_fp32_c4.cc b/lite/backends/arm/math/conv5x5s2_depthwise_fp32_c4.cc new file mode 100644 index 0000000000000000000000000000000000000000..a72b7553e0c8fddcb9028b0e6125281a07e65387 --- /dev/null +++ b/lite/backends/arm/math/conv5x5s2_depthwise_fp32_c4.cc @@ -0,0 +1,946 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
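+// 5x5 stride-2 depthwise convolution, c4 layout: the input is pre-packed into
+// blocks of 4 channels (prepack_input_nxwc4_dw) and each NEON loop computes 4
+// output channels at a time, with bias and the optional relu / relu6 /
+// leaky_relu activations fused before the store.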
+ +#include +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_depthwise.h" +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#ifdef ARM_WITH_OMP +#include +#endif + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +#ifdef __aarch64__ +#define COMPUTE \ + "ldp q0, q1, [%[inr0]], #32\n" /* load r0, 0-1 */ \ + "and v19.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q2, q3, [%[inr0]], #32\n" /* load r0, 2-3 */ \ + "and v20.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q4, q5, [%[inr0]], #32\n" /* load r0, 4-5 */ \ + "and v21.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q6, q7, [%[inr0]], #32\n" /* load r0, 6-7 */ \ + "and v22.16b, %[vbias].16b, %[vbias].16b\n" \ + "ldp q8, q9, [%[inr0]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldr q10, [%[inr0]] \n" /* load r0, 10 */ \ + "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "sub %[inr0], %[inr0], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr1]], #32\n" /* load r1, 0-1 */ \ + "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , %[w3].4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w3].4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w3].4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w3].4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr1]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , %[w4].4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , %[w4].4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , %[w4].4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , %[w4].4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr1]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr1]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr1]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr1]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr1], %[inr1], #32\n" /* inr1 -= 32 */ \ + "ldp q0, q1, [%[inr2]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" 
/* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr2]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr2]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr2]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr2]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr2]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr2], %[inr2], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr3]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr3]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr3]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr3]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr3]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr3]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr3], %[inr3], #32\n" /* inr0 -= 32 */ \ + "ldp q0, q1, [%[inr4]], #32\n" /* load r1, 0-1 */ \ + "ldp q14, q15, [%[wc0]], #32\n" /* load w0-1, to q14-15*/ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* 
outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q16, q17, [%[wc0]], #32\n" /* load w2-3, to q16-17*/ \ + "ldp q2, q3, [%[inr4]], #32\n" /* load r1, 2-3 */ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "ldp q4, q5, [%[inr4]], #32\n" /* load r1, 4-5 */ \ + "ldr q18, [%[wc0]], #16\n" /* load w4, to q18*/ \ + "ldp q6, q7, [%[inr4]], #32\n" /* load r0, 6-7 */ \ + "fmla v19.4s , v14.4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \ + "fmla v20.4s , v14.4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \ + "fmla v21.4s , v14.4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \ + "fmla v22.4s , v14.4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \ + "ldp q8, q9, [%[inr4]], #32\n" /* load r0, 8-9 */ \ + "fmla v19.4s , v15.4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \ + "fmla v20.4s , v15.4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \ + "fmla v21.4s , v15.4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \ + "fmla v22.4s , v15.4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \ + "ldr q10, [%[inr4]] \n" /* load r0, 10 */ \ + "fmla v19.4s , v16.4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \ + "fmla v20.4s , v16.4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \ + "fmla v21.4s , v16.4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \ + "fmla v22.4s , v16.4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \ + "sub %[inr4], %[inr4], #32\n" /* inr0 -= 32 */ \ + "fmla v19.4s , v17.4s, v3.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v17.4s, v5.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v17.4s, v7.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v17.4s, v9.4s\n" /* outr3 = w3 * r1, 6*/ \ + "fmla v19.4s , v18.4s, v4.4s\n" /* outr0 = w3 * r1, 0*/ \ + "fmla v20.4s , v18.4s, v6.4s\n" /* outr1 = w3 * r1, 2*/ \ + "fmla v21.4s , v18.4s, v8.4s\n" /* outr2 = w3 * r1, 4*/ \ + "fmla v22.4s , v18.4s, v10.4s\n" /* outr3 = w3 * r1, 6*/ \ + "sub %[wc0], %[wc0], #320\n" /* weight -= 320 */ \ + "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \ + "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \ + "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \ + "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \ + "trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \ + "trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \ + "trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \ + "trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/ +#define RELU /* relu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fmax v19.4s, v19.4s, v0.4s\n" \ + "fmax v20.4s, v20.4s, v0.4s\n" \ + "fmax v21.4s, v21.4s, v0.4s\n" \ + "fmax v22.4s, v22.4s, v0.4s\n" +#define RELU6 /* relu6 */ \ + "fmin v19.4s, v19.4s, %[vsix].4s\n" \ + "fmin v20.4s, v20.4s, %[vsix].4s\n" \ + "fmin v21.4s, v21.4s, %[vsix].4s\n" \ + "fmin v22.4s, v22.4s, %[vsix].4s\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "movi v0.4s, #0\n" /* for relu */ \ + "fcmge v1.4s, v19.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v3.4s, v20.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v5.4s, v21.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \ + "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_f32 */ \ + "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ + "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ + "bif v20.16b, v4.16b, v3.16b \n" /* choose*/ \ + "bif v21.16b, v6.16b, v5.16b \n" /* choose*/ \ + "bif v22.16b, v8.16b, v7.16b \n" /* 
choose*/ +#define STORE /* save result */ \ + "str q19, [%[outc0]], #16\n" \ + "str q20, [%[outc1]], #16\n" \ + "str q21, [%[outc2]], #16\n" \ + "str q22, [%[outc3]], #16\n" + +#else +#define COMPUTE \ + /* fill with bias */ \ + "vld1.32 {d12-d13}, [%[bias]]\n" /* load bias */ /* load weights */ \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \ + "vand.i32 q12, q6, q6\n" \ + "vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \ + "vand.i32 q13, q6, q6\n" \ + "vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \ + "vand.i32 q14, q6, q6\n" \ + "vand.i32 q15, q6, q6\n" \ + "vld1.32 {d12-d13}, [%[r0]]!\n" /* load input r0, 6*/ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-q10 */ \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr6\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr4\n" \ + "vld1.32 {d0-d3}, [%[r0]]! \n" /* load r0, 7-8 */ \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vld1.32 {d4-d7}, [%[r0]] \n" /* load r0, 9-10 */ \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d4-d5}, [%[r1]]! @ load r1, 2\n" \ + "sub %[r0], %[r0], #16 @ r0 - 16 to nextline address\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d6-d9}, [%[r1]]! @ load r1, 3, 4\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4, to q11 */ \ + "vld1.32 {d10-d13}, [%[r1]]! @ load r1, 5, 6\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr0\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr2\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 7, 8\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.32 {d4-d7}, [%[r1]] @ load r1, 9, 10\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "sub %[r1], %[r1], #16 @ r1 - 16 to nextline address\n" \ + "vld1.32 {d4-d7}, [%[r2]]! 
@ load r2, 2, 3\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vld1.32 {d12-d13}, [%[r2]]! @ load r2, 6 \n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vld1.32 {d4-d7}, [%[r2]] @ load r2, 9, 10\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[r2], %[r2], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r3]]! @ load r3, 0, 1\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d4-d7}, [%[r3]]! @ load r3, 2, 3\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r3]]! @ load r3, 4, 5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.32 {d12-d13}, [%[r3]]! @ load r3, 6, \n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r3]]! @ load r3, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.32 {d4-d7}, [%[r3]] @ load r3, 9, 10\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "vld1.32 {d14-d17}, [%[wc0]]!\n" /* load w0-1, to q7-8 */ \ + "sub %[r3], %[r3], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vld1.32 {d0-d3}, [%[r4]]! @ load r4, 0, 1\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vld1.32 {d4-d7}, [%[r4]]! @ load r4, 2, 3\n" \ + "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w2-3, to q9-10 */ \ + "vmla.f32 q12, q7, q0 @ w0 * inr0\n" \ + "vmla.f32 q13, q7, q2 @ w0 * inr2\n" \ + "vld1.32 {d8-d11}, [%[r4]]! @ load r3, 4, 5\n" \ + "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w4 to q11 */ \ + "vld1.32 {d12-d13}, [%[r4]]! @ load r3, 6, \n" \ + "vmla.f32 q12, q8, q1 @ w1 * inr1\n" \ + "vmla.f32 q13, q8, q3 @ w1 * inr3\n" \ + "vmla.f32 q14, q7, q4 @ w0 * inr4\n" \ + "vmla.f32 q15, q7, q6 @ w0 * inr6\n" \ + "vld1.32 {d0-d3}, [%[r4]]! 
@ load r3, 7, 8\n" \ + "vmla.f32 q12, q9, q2 @ w2 * inr2\n" \ + "vmla.f32 q13, q9, q4 @ w2 * inr4\n" \ + "vmla.f32 q14, q8, q5 @ w1 * inr5\n" \ + "vmla.f32 q15, q8, q0 @ w1 * inr7\n" \ + "vmla.f32 q12, q10, q3 @ w3 * inr3\n" \ + "vld1.32 {d4-d7}, [%[r4]] @ load r3, 9, 10\n" \ + "vmla.f32 q13, q10, q5 @ w3 * inr5\n" \ + "vmla.f32 q14, q9, q6 @ w2 * inr6\n" \ + "vmla.f32 q15, q9, q1 @ w2 * inr8\n" \ + "vmla.f32 q12, q11, q4 @ w4 * inr4\n" \ + "vmla.f32 q13, q11, q6 @ w4 * inr6\n" \ + "vmla.f32 q14, q10, q0 @ w3 * inr7\n" \ + "vmla.f32 q15, q10, q2 @ w3 * inr9\n" \ + "sub %[wc0], %[wc0], #400 @ wc0 - 400 to start address\n" \ + "sub %[r4], %[r4], #16 @ r1 - 16 to nextline address\n" \ + "vmla.f32 q14, q11, q1 @ w4 * inr8\n" \ + "vmla.f32 q15, q11, q3 @ w4 * inr10\n" \ + "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \ + "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \ + "vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \ + "vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/ + +#define RELU /* relu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[six_ptr]]\n" \ + "vmax.f32 q12, q12, q0\n" \ + "vmax.f32 q13, q13, q0\n" \ + "vmax.f32 q14, q14, q0\n" \ + "vmax.f32 q15, q15, q0\n" +#define RELU6 /* relu6 */ \ + "vmin.f32 q12, q12, q1\n" \ + "vmin.f32 q13, q13, q1\n" \ + "vmin.f32 q14, q14, q1\n" \ + "vmin.f32 q15, q15, q1\n" +#define LEAKY_RELU /* LeakyRelu */ \ + "vmov.u32 q0, #0\n" \ + "vld1.32 {d2-d3}, [%[scale_ptr]]\n" \ + "vcge.f32 q2, q12, q0 @ q0 > 0 \n" \ + "vcge.f32 q4, q13, q0 @ q0 > 0 \n" \ + "vcge.f32 q6, q14, q0 @ q0 > 0 \n" \ + "vcge.f32 q8, q15, q0 @ q0 > 0 \n" \ + "vmul.f32 q3, q12, q1 @ mul \n" \ + "vmul.f32 q5, q13, q1 @ mul \n" \ + "vmul.f32 q7, q14, q1 @ mul \n" \ + "vmul.f32 q9, q15, q1 @ mul \n" \ + "vbif q12, q3, q2 @ choose \n" \ + "vbif q13, q5, q4 @ choose \n" \ + "vbif q14, q7, q6 @ choose \n" \ + "vbif q15, q9, q8 @ choose \n" +#define STORE /* save result */ \ + "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \ + "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \ + "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \ + "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/ + +#endif + +void act_switch_5x5s2(const float* inr0, + const float* inr1, + const float* inr2, + const float* inr3, + const float* inr4, + float* outc0, + float* outc1, + float* outc2, + float* outc3, + float32x4_t w0, + float32x4_t w1, + float32x4_t w2, + float32x4_t w3, + float32x4_t w4, + float32x4_t vbias, + const float* weight_c, + float* bias_local, + const operators::ActivationParam act_param) { + bool has_active = act_param.has_active; + if (has_active) { + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; +#ifdef __aarch64__ + float32x4_t vsix = vdupq_n_f32(tmp); + float32x4_t vscale = vdupq_n_f32(ss); +#else + float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; +#endif + switch (act_param.active_type) { + case lite_api::ActivationType::kRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU STORE + : [r0] 
"+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kRelu6: +#ifdef __aarch64__ + asm volatile(COMPUTE RELU RELU6 STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias), + [vsix] "w"(vsix) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE RELU RELU6 STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [six_ptr] "r"(vsix) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + case lite_api::ActivationType::kLeakyRelu: +#ifdef __aarch64__ + asm volatile(COMPUTE LEAKY_RELU STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias), + [vscale] "w"(vscale) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE LEAKY_RELU STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local), [scale_ptr] "r"(vscale) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + break; + default: + LOG(FATAL) << "this act_type: " + << static_cast(act_param.active_type) + << " fuse not support"; + } + } else { +#ifdef __aarch64__ + asm volatile(COMPUTE STORE + : [inr0] "+r"(inr0), + [inr1] "+r"(inr1), + [inr2] "+r"(inr2), + [inr3] "+r"(inr3), + [inr4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + [outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [w0] "w"(w0), + [w1] "w"(w1), + [w2] "w"(w2), + [w3] "w"(w3), + [w4] "w"(w4), + [vbias] "w"(vbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); +#else + asm volatile(COMPUTE STORE + : [r0] "+r"(inr0), + [r1] "+r"(inr1), + [r2] "+r"(inr2), + [r3] "+r"(inr3), + [r4] "+r"(inr4), + [wc0] "+r"(weight_c), + [outc0] "+r"(outc0), + 
[outc1] "+r"(outc1), + [outc2] "+r"(outc2), + [outc3] "+r"(outc3) + : [bias] "r"(bias_local) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + } +} +void conv_depthwise_5x5s2_fp32(const float* i_data, + float* o_data, + int bs, + int oc, + int oh, + int ow, + int ic, + int ih, + int win, + const float* weights, + const float* bias, + const operators::ConvParam& param, + const operators::ActivationParam act_param, + ARMContext* ctx) { + auto paddings = *param.paddings; + int threads = ctx->threads(); + const int pad_h = paddings[0]; + const int pad_w = paddings[2]; + const int out_c_block = 4; + const int out_h_kernel = 1; + const int out_w_kernel = 4; + const int win_ext = ow * 2 + 3; + const int ow_round = ROUNDUP(ow, 4); + const int win_round = ROUNDUP(win_ext, 4); + const int hin_round = oh * 2 + 3; + const int prein_size = win_round * hin_round * out_c_block; + auto workspace_size = threads * prein_size + win_round + ow_round; + ctx->ExtendWorkspace(sizeof(float) * workspace_size); + + bool flag_bias = param.bias != nullptr; + + /// get workspace + auto ptr_zero = ctx->workspace_data(); + memset(ptr_zero, 0, sizeof(float) * win_round); + float* ptr_write = ptr_zero + win_round; + + int size_in_channel = win * ih; + int size_out_channel = ow * oh; + + int ws = -pad_w; + int we = ws + win_round; + int hs = -pad_h; + int he = hs + hin_round; + int w_loop = ow_round / 4; + auto remain = w_loop * 4 - ow; + bool flag_remain = remain > 0; + remain = 4 - remain; + remain = remain > 0 ? remain : 0; + int row_len = win_round * out_c_block; + + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < bs; ++n) { + const float* din_batch = i_data + n * ic * size_in_channel; + float* dout_batch = o_data + n * oc * size_out_channel; +#pragma omp parallel for num_threads(threads) + for (int c = 0; c < oc; c += out_c_block) { +#ifdef ARM_WITH_OMP + float* pre_din = ptr_write + ow_round + omp_get_thread_num() * prein_size; +#else + float* pre_din = ptr_write + ow_round; +#endif + /// const array size + prepack_input_nxwc4_dw( + din_batch, pre_din, c, hs, he, ws, we, ic, win, ih, ptr_zero); + const float* weight_c = weights + c * 25; // kernel_w * kernel_h + float* dout_c00 = dout_batch + c * size_out_channel; + float bias_local[4] = {0, 0, 0, 0}; + + if (flag_bias) { + bias_local[0] = bias[c]; + bias_local[1] = bias[c + 1]; + bias_local[2] = bias[c + 2]; + bias_local[3] = bias[c + 3]; + } +#ifdef __aarch64__ + float32x4_t w0 = vld1q_f32(weight_c); // w0, v23 + float32x4_t w1 = vld1q_f32(weight_c + 4); // w1, v24 + float32x4_t w2 = vld1q_f32(weight_c + 8); // w2, v25 + float32x4_t w3 = vld1q_f32(weight_c + 12); // w3, v26 + float32x4_t w4 = vld1q_f32(weight_c + 16); // w4, v27 + float32x4_t vbias = vdupq_n_f32(0.f); + if (flag_bias) { + vbias = vld1q_f32(&bias[c]); // v28 + } + weight_c += 20; +#endif + for (int h = 0; h < oh; h += out_h_kernel) { + float* outc0 = dout_c00 + h * ow; + float* outc1 = outc0 + size_out_channel; + float* outc2 = outc1 + size_out_channel; + float* outc3 = outc2 + size_out_channel; + const float* inr0 = pre_din + h * 2 * row_len; + const float* inr1 = inr0 + row_len; + const float* inr2 = inr1 + row_len; + const float* inr3 = inr2 + row_len; + const float* inr4 = inr3 + row_len; + + if (c + out_c_block > oc) { + switch (c + out_c_block - oc) { + case 3: + outc1 = ptr_write; + case 2: + outc2 = ptr_write; + case 1: + outc3 = ptr_write; + default: + break; + 
+          }
+        }
+        auto c0 = outc0;
+        auto c1 = outc1;
+        auto c2 = outc2;
+        auto c3 = outc3;
+        float pre_out[16];
+        for (int w = 0; w < w_loop; ++w) {
+          bool flag_mask = (w == w_loop - 1) && flag_remain;
+          if (flag_mask) {
+            c0 = outc0;
+            c1 = outc1;
+            c2 = outc2;
+            c3 = outc3;
+            outc0 = pre_out;
+            outc1 = pre_out + 4;
+            outc2 = pre_out + 8;
+            outc3 = pre_out + 12;
+          }
+#ifdef __aarch64__
+          act_switch_5x5s2(inr0,
+                           inr1,
+                           inr2,
+                           inr3,
+                           inr4,
+                           outc0,
+                           outc1,
+                           outc2,
+                           outc3,
+                           w0,
+                           w1,
+                           w2,
+                           w3,
+                           w4,
+                           vbias,
+                           weight_c,
+                           bias_local,
+                           act_param);
+#else
+          act_switch_5x5s2(inr0,
+                           inr1,
+                           inr2,
+                           inr3,
+                           inr4,
+                           outc0,
+                           outc1,
+                           outc2,
+                           outc3,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           vzero,
+                           weight_c,
+                           bias_local,
+                           act_param);
+#endif
+          if (flag_mask) {
+            for (int i = 0; i < remain; ++i) {
+              c0[i] = pre_out[i];
+              c1[i] = pre_out[i + 4];
+              c2[i] = pre_out[i + 8];
+              c3[i] = pre_out[i + 12];
+            }
+          }
+          inr0 += 32;
+          inr1 += 32;
+          inr2 += 32;
+          inr3 += 32;
+          inr4 += 32;
+          outc0 += 4;
+          outc1 += 4;
+          outc2 += 4;
+          outc3 += 4;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h
index c0f7aaa75a0f6f52acc3afd043c752c81d9e646f..f76036803b79227bfc7524eaeaaa45c0d8122bc6 100644
--- a/lite/backends/arm/math/conv_depthwise.h
+++ b/lite/backends/arm/math/conv_depthwise.h
@@ -193,6 +193,25 @@ void conv_depthwise_5x5s2_fp32(const float* din,
                                const operators::ActivationParam act_param,
                                ARMContext* ctx);
 
+void conv_depthwise_5x5s2_fp32(float* dout,
+                               const float* din,
+                               const float* weights,
+                               const float* bias,
+                               bool flag_bias,
+                               bool flag_relu,
+                               int num,
+                               int chin,
+                               int hin,
+                               int win,
+                               int hout,
+                               int wout,
+                               int pad_top,
+                               int pad_bottom,
+                               int pad_left,
+                               int pad_right,
+                               const operators::ActivationParam& act_param,
+                               ARMContext* ctx);
+
 void conv_depthwise_5x5s2p2_fp32(const float* din,
                                  float* dout,
                                  int num,
diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc
index 2bad1f997f457429c013c11a1dce35eb43dc26da..dfbb5b3983c5f70abd970567c5ffdae5ae6bf36d 100644
--- a/lite/backends/arm/math/conv_impl.cc
+++ b/lite/backends/arm/math/conv_impl.cc
@@ -734,22 +734,45 @@ void conv_depthwise_5x5_fp32(const void* din,
   int stride = param.strides[1];
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  bool ch_four = ch_in > 4 * w_in;
+  bool pads_five = (pad_h < 5) || (pad_w < 5);
   ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
   if (stride == 2) {
-    conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
-                              reinterpret_cast<float*>(dout),
-                              num,
-                              ch_out,
-                              h_out,
-                              w_out,
-                              ch_in,
-                              h_in,
-                              w_in,
-                              reinterpret_cast<const float*>(weights),
-                              bias,
-                              param,
-                              act_param,
-                              ctx);
+    if (ch_four || !pads_five || h_in < 5 || w_in < 10) {
+      conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
+                                reinterpret_cast<float*>(dout),
+                                num,
+                                ch_out,
+                                h_out,
+                                w_out,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                param,
+                                act_param,
+                                ctx);
+    } else {
+      conv_depthwise_5x5s2_fp32(reinterpret_cast<float*>(dout),
+                                reinterpret_cast<const float*>(din),
+                                reinterpret_cast<const float*>(weights),
+                                bias,
+                                flag_bias,
+                                flag_relu,
+                                num,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                h_out,
+                                w_out,
+                                paddings[0],
+                                paddings[1],
+                                paddings[2],
+                                paddings[3],
+                                act_param,
+                                ctx);
+    }
   } else if (stride == 1) {
     conv_depthwise_5x5s1_fp32(reinterpret_cast<float*>(dout),
                               reinterpret_cast<const float*>(din),
diff --git a/lite/kernels/arm/conv_depthwise.cc b/lite/kernels/arm/conv_depthwise.cc
index 3558eb22fbd4863771bf2b6b2e62e51b75a1227e..46d89db291dd6f16aecde8e70d0de9a6e51a9b56 100644
--- a/lite/kernels/arm/conv_depthwise.cc
+++ b/lite/kernels/arm/conv_depthwise.cc
@@ -28,7 +28,11 @@ void DepthwiseConv::PrepareForRun() {
   auto& ctx = this->ctx_->template As<ARMContext>();
   auto w_dims = param.filter->dims();
   auto kw = w_dims[3];
+  auto channel = w_dims[0];
+  auto hin = param.x->dims()[2];
+  auto win = param.x->dims()[3];
   auto paddings = *param.paddings;
+  bool ch_four = channel <= 4 * win;
   // select dw conv kernel
   if (kw == 3) {
     bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
@@ -54,7 +58,15 @@ void DepthwiseConv::PrepareForRun() {
 #endif
   } else if (kw == 5) {
     auto strides = param.strides;
-    if ((strides[0] == 1 && strides[1] == 1) ||
+    bool pads_five = (paddings[0] < 5) || (paddings[2] < 5);
+    if (ch_four && pads_five && win >= 2 * kw && hin >= kw &&
+        (strides[0] == 2 && strides[1] == 2)) {
+      flag_trans_weights_ = false;
+      impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
+#ifdef LITE_WITH_PROFILE
+      kernel_func_name_ = "conv_depthwise_5x5_fp32";
+#endif
+    } else if ((strides[0] == 1 && strides[1] == 1) ||
        (strides[0] == 2 && strides[1] == 2)) {
       // trans weights
       constexpr int cblock = 4;
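
For reference, below is a minimal standalone sketch of how the new c4 5x5s2 path sizes the scratch buffer it requests from the context, mirroring the ROUNDUP arithmetic in conv_depthwise_5x5s2_fp32 above. It is not part of the patch; the helper name dw5x5s2_c4_workspace_floats and the locally defined ROUNDUP macro are illustrative assumptions.

// Standalone illustration (hypothetical helper, not in the patch): how many
// floats of workspace the c4 5x5s2 depthwise kernel needs.
#include <cstddef>

#define ROUNDUP(a, b) (((a) + (b)-1) / (b) * (b))

size_t dw5x5s2_c4_workspace_floats(int oh, int ow, int threads) {
  const int out_c_block = 4;            // channels are packed 4-wide (c4 layout)
  const int win_ext = ow * 2 + 3;       // input columns covered by one output row
  const int ow_round = ROUNDUP(ow, 4);  // output width padded to a multiple of 4
  const int win_round = ROUNDUP(win_ext, 4);
  const int hin_round = oh * 2 + 3;     // input rows covered by all output rows
  const int prein_size = win_round * hin_round * out_c_block;
  // per-thread packed input + one zero row + one scratch output row
  return static_cast<size_t>(threads) * prein_size + win_round + ow_round;
}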