From f99e82953e0dc4305473a52aab788d8e8327abbf Mon Sep 17 00:00:00 2001
From: chenjiaoAngel
Date: Wed, 5 Aug 2020 18:21:49 +0800
Subject: [PATCH] update conv_dw_3x3s2

---
 .../arm/math/conv3x3s1p01_depthwise_fp32.cc   |   22 +-
 .../arm/math/conv3x3s1p01_depthwise_fp32_1.cc | 3989 -----------------
 .../math/conv3x3s1p01_depthwise_fp32_relu1.cc | 2983 ------------
 .../math/conv3x3s2p01_depthwise_fp32_new.cc   | 2912 ++++++++++++
 .../conv3x3s2p01_depthwise_fp32_relu_new.cc   | 2213 +++++++++
 lite/backends/arm/math/conv_depthwise.h       |  151 +-
 6 files changed, 5239 insertions(+), 7031 deletions(-)
 delete mode 100644 lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_1.cc
 delete mode 100644 lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_relu1.cc
 create mode 100644 lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_new.cc
 create mode 100644 lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_relu_new.cc

diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc
index 1ae6611a77..667c2ca1f8 100644
--- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc
+++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc
@@ -203,7 +203,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                          w_out,
                                          ctx);
       } else {
-        conv_depthwise_3x3s1p0_bias_s_relu(dout,
+        conv_depthwise_3x3s1p1_bias_s_relu(dout,
                                            din,
                                            weights,
                                            bias,
@@ -267,7 +267,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                           w_out,
                                           ctx);
       } else {
-        conv_depthwise_3x3s1p0_bias_s_relu6(dout,
+        conv_depthwise_3x3s1p1_bias_s_relu6(dout,
                                             din,
                                             weights,
                                             bias,
@@ -331,7 +331,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                                               w_out,
                                               ctx);
       } else {
-        conv_depthwise_3x3s1p0_bias_s_leakyRelu(dout,
+        conv_depthwise_3x3s1p1_bias_s_leakyRelu(dout,
                                                 din,
                                                 weights,
                                                 bias,
@@ -2225,7 +2225,7 @@ void conv_depthwise_3x3s1p1_bias_relu6(float *dout,
   float32x4_t vzero = vdupq_n_f32(0.f);

 #ifdef __aarch64__
-  float32x4_t vsix = vdupq_n_f32(six);
+  float32x4_t vsix = vld1q_f32(six);
 #endif
   for (int n = 0; n < num; ++n) {
     const float *din_batch = din + n * ch_in * size_in_channel;
@@ -2523,7 +2523,7 @@ void conv_depthwise_3x3s1p1_bias_leakyRelu(float *dout,
   float32x4_t vzero = vdupq_n_f32(0.f);

 #ifdef __aarch64__
-  float32x4_t vscale = vdupq_n_f32(scale);
+  float32x4_t vscale = vld1q_f32(scale);
 #endif
   for (int n = 0; n < num; ++n) {
     const float *din_batch = din + n * ch_in * size_in_channel;
@@ -2786,7 +2786,7 @@ void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout,
   int size_in_channel = w_in * h_in;
   int size_out_channel = w_out * h_out;
 #ifdef __aarch64__
-  float32x4_t vsix = vdupq_n_f32(six);
+  float32x4_t vsix = vld1q_f32(six);
 #endif
   for (int n = 0; n < num; ++n) {
     const float *din_batch = din + n * ch_in * size_in_channel;
@@ -2947,7 +2947,7 @@ void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout,
   int size_in_channel = w_in * h_in;
   int size_out_channel = w_out * h_out;
 #ifdef __aarch64__
-  float32x4_t vscale = vdupq_n_f32(scale);
+  float32x4_t vscale = vld1q_f32(scale);
 #endif
   for (int n = 0; n < num; ++n) {
     const float *din_batch = din + n * ch_in * size_in_channel;
@@ -3119,7 +3119,7 @@ void conv_depthwise_3x3s1p0_bias_relu6(float *dout,
   const int remian_idx[4] = {0, 1, 2, 3};

 #ifdef __aarch64__
-  float32x4_t vsix = vdupq_n_f32(six);
+  float32x4_t vsix = vld1q_f32(six);
 #endif
   if (remain == 0 && size_pad_right == 6) {
     // w_in == w_out and w_out % 4 == 0
@@ -3402,7 +3402,7 @@ void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout,
       vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
 #ifdef
__aarch64__ - float32x4_t vsix = vdupq_n_f32(six); + float32x4_t vsix = vld1q_f32(six); #endif unsigned int vmask[8]; @@ -3569,7 +3569,7 @@ void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout, const int remian_idx[4] = {0, 1, 2, 3}; #ifdef __aarch64__ - float32x4_t vscale = vdupq_n_f32(scale); + float32x4_t vscale = vld1q_f32(scale); #endif if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0 @@ -3853,7 +3853,7 @@ void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout, vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); #ifdef __aarch64__ - float32x4_t vscale = vdupq_n_f32(scale); + float32x4_t vscale = vld1q_f32(scale); #endif unsigned int vmask[8]; diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_1.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_1.cc deleted file mode 100644 index f084b477f9..0000000000 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_1.cc +++ /dev/null @@ -1,3989 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "lite/backends/arm/math/conv_depthwise.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p1_bias_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p0_bias_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p1_bias_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout, - 
const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx); - -void conv_depthwise_3x3s1_fp32(const float *din, - float *dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float *weights, - const float *bias, - int pad, - bool flag_bias, - const operators::ActivationParam act_param, - ARMContext *ctx) { - bool has_active = act_param.has_active; - auto act_type = act_param.active_type; - float tmp = act_param.Relu_clipped_coef; - float ss = act_param.Leaky_relu_alpha; - float vsix[4] = {tmp, tmp, tmp, tmp}; - float vscale[4] = {ss, ss, ss, ss}; - if (has_active) { - switch (act_type) { - case lite_api::ActivationType::kRelu: - if (pad == 0) { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias_relu(dout, - din, - weights, - bias, - flag_bias, - true, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - true, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - if (pad == 1) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias_relu(dout, - din, - weights, - bias, - flag_bias, - true, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - true, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - break; - case lite_api::ActivationType::kRelu6: - if (pad == 0) { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias_relu6(dout, - din, - weights, - bias, - vsix, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_relu6(dout, - din, - weights, - bias, - vsix, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - if (pad == 1) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias_relu6(dout, - din, - weights, - bias, - vsix, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_relu6(dout, - din, - weights, - bias, - vsix, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - break; - case lite_api::ActivationType::kLeakyRelu: - if (pad == 0) { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias_leakyRelu(dout, - din, - weights, - bias, - vscale, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_leakyRelu(dout, - din, - weights, - bias, - vscale, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - if (pad == 1) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias_leakyRelu(dout, - din, - weights, - bias, - vscale, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_leakyRelu(dout, - din, - weights, - bias, - vscale, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - break; - default: - LOG(FATAL) << "this act_type: " << static_cast(act_type) - << " fuse not support"; - } - } else { - if (pad == 0) { - if (w_in > 5) { - 
conv_depthwise_3x3s1p0_bias_no_relu(dout, - din, - weights, - bias, - flag_bias, - false, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_no_relu(dout, - din, - weights, - bias, - flag_bias, - false, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - if (pad == 1) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias_no_relu(dout, - din, - weights, - bias, - flag_bias, - false, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_no_relu(dout, - din, - weights, - bias, - flag_bias, - false, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} - -#ifdef __aarch64__ -#define INIT_S1 \ - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" \ - "movi v21.4s, #0x0\n" /* out0 = 0 */ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - -#define LEFT_COMPUTE_S1 \ - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ /* r0 */ \ - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * w0[1]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * w0[0]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * w0[2]*/ \ - \ - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * w1[1]*/ \ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16=1234 */ \ - "ext 
v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ \ - \ - /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ - -#define LEFT_RESULT_S1 \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ /* r5 */ \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - 
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "blt 3f \n" - -#define MID_COMPUTE_S1 \ - "1: \n" /* r0 */ \ - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - -#define MID_RESULT_S1 \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, 
%[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define RIGHT_COMPUTE_S1 \ - "3: \n" \ - "movi v20.4s, #0 \n" \ - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ - "ld1 {v22.4s}, [%[doutr0]] \n" \ - "ld1 {v23.4s}, [%[doutr1]] \n" \ - "ld1 {v24.4s}, [%[doutr2]] \n" \ - "ld1 {v25.4s}, [%[doutr3]] \n" \ - \ - "bif v0.16b, v20.16b, v18.16b \n" \ - "bif v1.16b, v20.16b, v19.16b \n" \ - "bif v2.16b, v20.16b, v18.16b \n" \ - "bif v3.16b, v20.16b, v19.16b \n" \ - \ - "bif v4.16b, v20.16b, v18.16b \n" \ - "bif v5.16b, v20.16b, v19.16b \n" \ - "bif v6.16b, v20.16b, v18.16b \n" \ - "bif v7.16b, v20.16b, v19.16b \n" \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v8.16b, v20.16b, v18.16b \n" \ - "bif v9.16b, v20.16b, v19.16b \n" \ - "bif v10.16b, v20.16b, v18.16b \n" \ - "bif v11.16b, v20.16b, v19.16b \n" \ - \ - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v18.4s}, [%[rmask]] \n" \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* 
outr00 += din0_0123 * w0[0]*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ - -#define RIGHT_RESULT_S1 \ - "bif v12.16b, v22.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define LEFT_RESULT_S1_RELU \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ - 
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "blt 3f \n" - -#define LEFT_RESULT_S1_RELU6 \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ - "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ - \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "blt 3f \n" - -#define LEFT_RESULT_S1_LEAKY_RELU \ - "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ - "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ - "bif v13.16b, v21.16b, v19.16b \n" /* choose*/ \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" 
/*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ - "fcmge v18.4s, v14.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "fcmge v18.4s, v15.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ - "cmp %w[cnt], #1 \n" \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "blt 3f \n" - -#define MID_RESULT_S1_RELU \ - "movi v20.4s, #0 \n" \ - "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v14.4s, 
v14.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define MID_RESULT_S1_RELU6 \ - "movi v20.4s, #0 \n" \ - "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define MID_RESULT_S1_LEAKY_RELU \ - "movi v21.4s, #0 \n" \ - "fcmge v18.4s, 
v12.4s, v21.4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v12.16b, v20.16b, v18.16b \n" /* choose*/ \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fcmge v18.4s, v13.4s, v21.4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v13.4s, %[vscale].4s \n" /* mul */ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "bif v13.16b, v20.16b, v18.16b \n" /* choose*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "fcmge v18.4s, v14.4s, v21.4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v14.4s, %[vscale].4s \n" /* mul */ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v14.16b, v20.16b, v18.16b \n" /* choose*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fcmge v18.4s, v15.4s, v21.4s \n" /* vcgeq_f32 */ \ - "fmul v20.4s, v15.4s, %[vscale].4s \n" /* mul */ \ - \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "bif v15.16b, v20.16b, v18.16b \n" /* choose*/ \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define RIGHT_RESULT_S1_RELU \ - "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v12.16b, v22.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, 
v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ - \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define RIGHT_RESULT_S1_RELU6 \ - "fmax v12.4s, v12.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmin v12.4s, v12.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "bif v12.16b, v22.16b, v18.16b \n" \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmax v13.4s, v13.4s, v20.4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmin v13.4s, v13.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmax v14.4s, v14.4s, v20.4s \n" /*relu*/ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmin v14.4s, v14.4s, %[vsix].4s \n" /*relu6*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - "fmax v15.4s, v15.4s, v20.4s \n" /*relu*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmin v15.4s, v15.4s, %[vsix].4s \n" /*relu6*/ \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define 
RIGHT_RESULT_S1_LEAKY_RELU \ - "movi v1.4s, #0 \n" \ - "fcmge v20.4s, v12.4s, v1.4s \n" /* vcgeq_f32 */ \ - "fmul v21.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v12.16b, v21.16b, v20.16b \n" /* choose*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "bif v12.16b, v22.16b, v18.16b \n" \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fcmge v20.4s, v13.4s, v1.4s \n" /* vcgeq_f32 */ \ - "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v13.16b, v21.16b, v20.16b \n" \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fcmge v20.4s, v14.4s, v1.4s \n" /* vcgeq_f32 */ \ - "fmul v21.4s, v14.4s, %[vscale].4s \n" /* mul */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v14.16b, v21.16b, v20.16b \n" \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fcmge v20.4s, v15.4s, v1.4s \n" /* vcgeq_f32 */ \ - "fmul v21.4s, v15.4s, %[vscale].4s \n" /* mul */ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - "bif v15.16b, v21.16b, v20.16b \n" \ - "bif v15.16b, v25.16b, v18.16b \n" \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define COMPUTE_S_S1 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s}, [%[din0]], #16\n" \ - "ld1 {v1.4s}, [%[din1]], #16\n" \ - "ld1 {v2.4s}, [%[din2]], #16\n" \ - "ld1 {v3.4s}, [%[din3]], #16\n" \ - \ - "bif v0.16b, %[vzero].16b, %[mask].16b\n" \ - "bif v1.16b, %[vzero].16b, %[mask].16b\n" \ - "bif v2.16b, %[vzero].16b, %[mask].16b\n" \ - "bif v3.16b, %[vzero].16b, %[mask].16b\n" \ - \ - "ext v4.16b, %[vzero].16b, v0.16b, #12\n" \ - "ext v5.16b, %[vzero].16b, v1.16b, #12\n" \ - "ext v6.16b, %[vzero].16b, v2.16b, #12\n" \ - "ext v7.16b, %[vzero].16b, v3.16b, #12\n" \ - \ - "ext v8.16b, v0.16b, %[vzero].16b, #4\n" \ - "ext v9.16b, v1.16b, %[vzero].16b, #4\n" \ - "ext v10.16b, v2.16b, %[vzero].16b, #4\n" \ - "ext v11.16b, v3.16b, %[vzero].16b, #4\n" \ - \ - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ - \ - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ - \ - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ - \ - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ - 
"fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ - \ - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ - \ - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ - \ - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ - \ - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v14.4s\n" \ - "fadd v12.4s, v12.4s, v16.4s\n" \ - \ - "fadd v13.4s, v13.4s, v15.4s\n" \ - "fadd v13.4s, v13.4s, v17.4s\n" \ - \ - "fadd v12.4s, v12.4s, %[bias].4s\n" \ - "fadd v13.4s, v13.4s, %[bias].4s\n" - -#define RESULT_S_S1 \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define RESULT_S_S1_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fmax v12.4s, v12.4s, %[vzero].4s\n" \ - "fmax v13.4s, v13.4s, %[vzero].4s\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define RESULT_S_S1_RELU6 \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fmax v12.4s, v12.4s, %[vzero].4s\n" \ - "fmax v13.4s, v13.4s, %[vzero].4s\n" \ - \ - "fmin v12.4s, v12.4s, %[vsix].4s\n" \ - "fmin v13.4s, v13.4s, %[vsix].4s\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define RESULT_S_S1_LEAKY_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fcmge v18.4s, v12.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fcmge v19.4s, v13.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ - "fmul v20.4s, v12.4s, %[vscale].4s \n" /* mul */ \ - "fmul v21.4s, v13.4s, %[vscale].4s \n" /* mul */ \ - \ - "bif v12.16b, v20.16b, v18.16b \n" \ - "bif v13.16b, v21.16b, v19.16b \n" \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" -#define COMPUTE_S_S1_P0 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" \ - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" \ - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ - \ - "bif v0.16b, %[vzero].16b, %[mask1].16b\n" \ - "bif v1.16b, %[vzero].16b, %[mask2].16b\n" \ - \ - "bif v2.16b, %[vzero].16b, %[mask1].16b\n" \ - "bif v3.16b, %[vzero].16b, %[mask2].16b\n" \ - \ - "bif v4.16b, %[vzero].16b, %[mask1].16b\n" \ - "bif v5.16b, %[vzero].16b, %[mask2].16b\n" \ - \ - "bif v6.16b, %[vzero].16b, %[mask1].16b\n" \ - "bif v7.16b, %[vzero].16b, %[mask2].16b\n" \ - \ - "ext v8.16b, v0.16b, v1.16b, #4\n" \ - "ext v9.16b, v0.16b, v1.16b, #8\n" \ - \ - "and v12.16b, %[vbias].16b, %[vbias].16b \n" \ - "and v13.16b, %[vbias].16b, %[vbias].16b \n" /* r0 */ \ - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" \ - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" \ - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "ext v8.16b, v2.16b, v3.16b, #4\n" \ - "ext v9.16b, v2.16b, v3.16b, #8\n" /* r1 */ \ - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" \ - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" \ - \ - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" \ - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" \ - \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" \ - \ - "ext v8.16b, v4.16b, v5.16b, #4\n" \ - "ext v9.16b, v4.16b, v5.16b, #8\n" /* r2 */ \ - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" \ - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" \ - \ - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" \ - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" \ - \ - "fmla v13.4s, 
v9.4s, %[wr1].s[2]\n" \ - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" \ - \ - "ext v8.16b, v6.16b, v7.16b, #4\n" \ - "ext v9.16b, v6.16b, v7.16b, #8\n" \ - \ - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" \ - \ - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" \ - \ - "fadd v12.4s, v12.4s, v10.4s\n" \ - \ - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v11.4s\n" \ - "fadd v13.4s, v13.4s, v14.4s\n" \ - "fadd v13.4s, v13.4s, v15.4s\n" // \ - // "prfm pldl1keep, [%[out1]]\n" \ - // "prfm pldl1keep, [%[out2]]\n" \ - // \ - // "st1 {v12.4s}, [%[out1]]\n" \ - // "st1 {v13.4s}, [%[out2]]\n" \ - -#else -#define INIT_S1 \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" \ - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" \ - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" \ - "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" \ - \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vdup.32 q5, %[bias_val] @ and \n" - -#define LEFT_COMPUTE_S1 \ - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" \ - "vext.32 q7, q8, q9, #1 @ 1234\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" \ - "vext.32 q7, q10, q11, #1 @ 1234\n" \ - \ - /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" \ - "vext.32 q7, q12, q13, #1 @ 1234\n" \ - \ - /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" \ - "vext.32 q7, q14, q15, #1 @ 1234\n" - -#define LEFT_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define MID_COMPUTE_S1 \ - "1: @ right pad entry\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" - -#define MID_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define RIGHT_COMPUTE_S1 \ - "3: @ right pad entry\n" \ - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ - \ - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" \ - \ - "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" \ - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" - -#define RIGHT_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define LEFT_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define LEFT_RESULT_S1_RELU6 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.f32 {d28-d29}, [%[six_ptr]] @ load six \n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vmin.f32 q4, q4, q14 @ relu6 \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vmin.f32 q5, q5, q14 @ relu6 \n" \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define LEFT_RESULT_S1_LEAKY_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - "vld1.f32 {d28-d29}, [%[scale_ptr]] @ load scale \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q6, q4, q14 \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vbif q4, q6, q15 @ choose \n" \ - "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q6, q5, q14 \n" \ - \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vbif q5, q6, q7 @ choose \n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define MID_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define MID_RESULT_S1_RELU6 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmin.f32 q4, q4, q14 @ relu6 \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmin.f32 q5, q5, q14 @ relu6 \n" \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define MID_RESULT_S1_LEAKY_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[scale_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q6, q4, q14 \n" \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vbif q4, q6, q15 @ choose \n" \ - "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q6, q5, q14 \n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - \ - "vbif q5, q6, q7 @ choose \n" \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define RIGHT_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define RIGHT_RESULT_S1_RELU6 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[six_ptr]] @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmin.f32 q4, q4, q14 @ relu6 \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ - \ - "vmin.f32 q5, q5, q14 @ relu6 \n" \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define RIGHT_RESULT_S1_LEAKY_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[scale_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vcge.f32 q15, q4, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q6, q4, q14 \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vbif q4, q6, q15 @ choose \n" \ - \ - "vcge.f32 q7, q5, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q6, q5, q14 \n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - "vbif q5, q6, q7 @ choose \n" \ - \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define COMPUTE_S_S1 \ - "pld [%[din0]]\n" \ - "pld [%[din1]]\n" \ - "pld [%[din2]]\n" \ - "pld [%[din3]]\n" \ - \ - "vld1.32 {d12-d13}, [%[din0]]!\n" \ - "vld1.32 {d14-d15}, [%[din1]]!\n" \ - "vld1.32 {d16-d17}, [%[din2]]!\n" \ - "vld1.32 {d18-d19}, [%[din3]]!\n" \ - \ - "vbif q6, %q[vzero], %q[mask]\n" \ - "vbif q7, %q[vzero], %q[mask]\n" \ - "vbif q8, %q[vzero], %q[mask]\n" \ - "vbif q9, %q[vzero], %q[mask]\n" \ - \ - "vmul.f32 q14, q6, %e[wr0][1]\n" \ - "vmul.f32 q15, q7, %e[wr0][1]\n" \ - \ - "vmla.f32 q14, q7, %e[wr1][1]\n" \ - "vmla.f32 q15, q8, %e[wr1][1]\n" \ - \ - "vmla.f32 q14, q8, %e[wr2][1]\n" \ - "vmla.f32 q15, q9, %e[wr2][1]\n" \ - \ - "vext.32 q10, %q[vzero], q6, #3\n" \ - "vext.32 q11, %q[vzero], q7, #3\n" \ - "vext.32 q12, %q[vzero], q8, #3\n" \ - "vext.32 q13, %q[vzero], q9, #3\n" \ - \ - "vmla.f32 q14, q10, %e[wr0][0]\n" \ - "vmla.f32 q15, q11, %e[wr0][0]\n" \ - \ - "vmla.f32 q14, q11, %e[wr1][0]\n" \ - "vmla.f32 q15, q12, %e[wr1][0]\n" \ - \ - "vmla.f32 q14, q12, %e[wr2][0]\n" \ - "vmla.f32 q15, q13, %e[wr2][0]\n" \ - \ - "vext.32 q10, q6, %q[vzero], #1\n" \ - "vext.32 q11, q7, %q[vzero], #1\n" \ - "vext.32 q12, q8, %q[vzero], #1\n" \ - "vext.32 q13, q9, %q[vzero], #1\n" \ - \ - "vmla.f32 q14, q10, %f[wr0][0]\n" \ - "vmla.f32 q15, q11, %f[wr0][0]\n" \ - \ - "vmla.f32 q14, q11, %f[wr1][0]\n" \ - "vmla.f32 q15, q12, %f[wr1][0]\n" \ - \ - "vmla.f32 q14, q12, %f[wr2][0]\n" \ - "vmla.f32 q15, q13, %f[wr2][0]\n" \ - \ - "vadd.f32 q14, q14, %q[bias]\n" \ - "vadd.f32 q15, q15, %q[bias]\n" - -#define RESULT_S_S1 \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define RESULT_S_S1_RELU \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vmax.f32 q14, q14, %q[vzero]\n" \ - "vmax.f32 q15, q15, %q[vzero]\n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define RESULT_S_S1_RELU6 \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vld1.32 {d20-d21}, [%[six_ptr]] \n" \ - "vmax.f32 q14, q14, %q[vzero]\n" \ - "vmax.f32 q15, q15, %q[vzero]\n" \ - \ - "vmin.f32 q14, q14, q10 \n" \ - "vmin.f32 q15, q15, q10 \n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define RESULT_S_S1_LEAKY_RELU \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vld1.32 {d18-d19}, [%[scale_ptr]] 
\n" \ - "vcge.f32 q10, q14, %q[vzero] @ q0 > 0 \n" \ - "vcge.f32 q11, q15, %q[vzero] @ q0 > 0 \n" \ - "vmul.f32 q12, q14, q9 \n" \ - "vmul.f32 q13, q15, q9 \n" \ - \ - "vbif q14, q12, q10 \n" \ - "vbif q15, q13, q11 \n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define COMPUTE_S_S1_P0 \ - "pld [%[din0]]\n" \ - "pld [%[din1]]\n" \ - "pld [%[din2]]\n" \ - "pld [%[din3]]\n" \ - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" \ - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" \ - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" \ - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" \ - \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ - \ - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ - \ - "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ - \ - "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ - \ - "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ - "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" \ - \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vadd.f32 q14, q4, q11 @ q4 += q10 \n" \ - \ - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" \ - "vadd.f32 q15, q5, q9 @ q4 += q10 \n" - -#endif - -#ifdef __aarch64__ -void conv_depthwise_3x3s1p1_bias_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int 
num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - int cnt_col = tile_w - 1; - - unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in); - const unsigned int remian_idx[4] = {0, 1, 2, 3}; - - if (remain == 0 && size_pad_right == 5) { - size_pad_right = 1; - cnt_col -= 1; - remain = 4; - } else if (remain == 0 && size_pad_right == 6) { - size_pad_right = 2; - cnt_col -= 1; - remain = 4; - } - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); -#ifdef __aarch64__ - float32x4_t vsix = vdupq_n_f32(six); -#endif - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 - MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [vsix] "w"(vsix), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 - MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [six_ptr] "r"(six), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -void conv_depthwise_3x3s1p1_bias_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - int cnt_col = tile_w - 1; - - unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in); - const unsigned int remian_idx[4] = {0, 1, 2, 3}; - - if (remain == 0 && size_pad_right == 5) { - size_pad_right = 1; - cnt_col -= 1; - remain = 4; - } else if (remain == 0 && size_pad_right == 6) { - size_pad_right = 2; - cnt_col -= 1; - remain = 4; - } - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); -#ifdef __aarch64__ - float32x4_t vscale = vdupq_n_f32(scale); -#endif - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU - MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_LEAKY_RELU - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [vscale] "w"(vscale), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU - MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_LEAKY_RELU - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [six_ptr] "r"(six), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! 
end of processing mid rows -#endif - } - } -} - -void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; -#ifdef __aarch64__ - float32x4_t vsix = vdupq_n_f32(six); -#endif - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - const float *dr0 = din_channel; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0_ptr = dr0; - const float *dr1_ptr = dr1; - const float *dr2_ptr = dr2; - const float *dr3_ptr = dr3; - if (j == 0) { - dr0_ptr = zero; - dr1_ptr = dr0; - dr2_ptr = dr1; - dr3_ptr = dr2; - dr0 = dr1; - dr1 = dr2; - } else { - dr0 = dr2; - dr1 = dr3; - } - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (j + 3 > h_in) { - switch (j + 3 - h_in) { - case 3: - dr1_ptr = zero; - case 2: - dr2_ptr = zero; - case 1: - dr3_ptr = zero; - default: - break; - } - } - //! 
process bottom remain - if (j + 2 > h_out) { - doutr1 = trash_buf; - } -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [vsix] "w"(vsix), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [six_ptr] "r"(six), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; -#ifdef __aarch64__ - float32x4_t vscale = vdupq_n_f32(scale); -#endif - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - const float *dr0 = din_channel; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0_ptr = dr0; - const float *dr1_ptr = dr1; - const float *dr2_ptr = dr2; - const float *dr3_ptr = dr3; - if (j == 0) { - dr0_ptr = zero; - dr1_ptr = dr0; - dr2_ptr = dr1; - dr3_ptr = dr2; - dr0 = dr1; - dr1 = dr2; - } else { - dr0 = dr2; - dr1 = dr3; - } - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (j + 3 > h_in) { - switch (j + 3 - h_in) { - case 3: - dr1_ptr = zero; - case 2: - dr2_ptr = zero; - case 1: - dr3_ptr = zero; - default: - break; - } - } - //! 
process bottom remain - if (j + 2 > h_out) { - doutr1 = trash_buf; - } -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [vsix] "w"(vsix), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [six_ptr] "r"(six), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -void conv_depthwise_3x3s1p0_bias_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - -#ifdef __aarch64__ - float32x4_t vsix = vdupq_n_f32(six); -#endif - - if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0 - tile_w -= 1; - remain = 4; - size_pad_right = 2; - } - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? 
bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 4: - din_ptr1 = zero_ptr; - case 3: - din_ptr2 = zero_ptr; - case 2: - din_ptr3 = zero_ptr; - case 1: - din_ptr4 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - MID_COMPUTE_S1 MID_RESULT_S1_RELU6 - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU6 "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [vsix] "w"(vsix), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1_RELU6 - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU6 "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [six_ptr] "r"(six), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *six, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - -#ifdef __aarch64__ - float32x4_t vsix = vdupq_n_f32(six); -#endif - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t wbias; - float bias_val = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_val = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0 = din_channel + j * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 4 > h_in) { - switch (j + 4 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - default: - break; - } - } - if (j + 2 > h_out) { - doutr1 = trash_buf; - } - unsigned int *vmask_ptr = vmask; -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3) - 
: [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [vzero] "w"(vzero), - [vsix] "w"(vsix), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [six_ptr] "r"(six), - [bias_val] "r"(bias_val), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - -#ifdef __aarch64__ - float32x4_t vscale = vdupq_n_f32(scale); -#endif - - if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0 - tile_w -= 1; - remain = 4; - size_pad_right = 2; - } - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? 
bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 4: - din_ptr1 = zero_ptr; - case 3: - din_ptr2 = zero_ptr; - case 2: - din_ptr3 = zero_ptr; - case 1: - din_ptr4 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_LEAKY_RELU "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [vscale] "w"(vscale), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1_LEAKY_RELU - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_LEAKY_RELU - "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [scale_ptr] "r"(scale), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout, - const float *din, - const float *weights, - const float *bias, - const float *scale, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - -#ifdef __aarch64__ - float32x4_t vscale = vdupq_n_f32(scale); -#endif - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t wbias; - float bias_val = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_val = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0 = din_channel + j * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 4 > h_in) { - switch (j + 4 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - default: - break; - } - } - if (j + 2 > h_out) { - doutr1 = trash_buf; - } - unsigned int *vmask_ptr = vmask; -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] 
"+r"(din_ptr2), - [din3] "+r"(din_ptr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [vzero] "w"(vzero), - [vscale] "w"(vscale), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU - : [din0] "+r"(din_ptr0), - [din1] "+r"(din_ptr1), - [din2] "+r"(din_ptr2), - [din3] "+r"(din_ptr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [scale_ptr] "r"(scale), - [bias_val] "r"(bias_val), - [out1] "r"(doutr0), - [out2] "r"(doutr1) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} \ No newline at end of file diff --git a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_relu1.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_relu1.cc deleted file mode 100644 index d2a3b1925f..0000000000 --- a/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32_relu1.cc +++ /dev/null @@ -1,2983 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include "lite/backends/arm/math/conv_depthwise.h" - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -#ifdef __aarch64__ -#define INIT_S1 \ - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" \ - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" \ - "movi v21.4s, #0x0\n" /* out0 = 0 */ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - -#define LEFT_COMPUTE_S1 \ - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ /* r0 */ \ - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * w0[1]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * w0[0]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * w0[2]*/ \ - \ - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * w1[1]*/ \ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ \ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16=1234 */ \ - "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ \ - \ - /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += 
din2_0123 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ - -#define LEFT_RESULT_S1 \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ /* r5 */ \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "blt 3f \n" - -#define MID_COMPUTE_S1 \ - "1: \n" /* r0 */ \ - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * 
w0[1]*/ \ - \ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - -#define MID_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, 
%[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define RIGHT_COMPUTE_S1 \ - "3: \n" \ - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \ - "ld1 {v22.4s}, [%[doutr0]] \n" \ - "ld1 {v23.4s}, [%[doutr1]] \n" \ - "ld1 {v24.4s}, [%[doutr2]] \n" \ - "ld1 {v25.4s}, [%[doutr3]] \n" \ - \ - "bif v0.16b, %[vzero].16b, v18.16b \n" \ - "bif v1.16b, %[vzero].16b, v19.16b \n" \ - "bif v2.16b, %[vzero].16b, v18.16b \n" \ - "bif v3.16b, %[vzero].16b, v19.16b \n" \ - \ - "bif v4.16b, %[vzero].16b, v18.16b \n" \ - "bif v5.16b, %[vzero].16b, v19.16b \n" \ - "bif v6.16b, %[vzero].16b, v18.16b \n" \ - "bif v7.16b, %[vzero].16b, v19.16b \n" \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \ - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v8.16b, %[vzero].16b, v18.16b \n" \ - "bif v9.16b, %[vzero].16b, v19.16b \n" \ - "bif v10.16b, %[vzero].16b, v18.16b \n" \ - "bif v11.16b, %[vzero].16b, v19.16b \n" \ - \ - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "ld1 {v18.4s}, [%[rmask]] \n" \ - \ - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \ - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \ - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, 
v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - -#define RIGHT_RESULT_S1 \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v12.16b, v22.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define LEFT_RESULT_S1_RELU \ - /* r4 */ \ - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \ - \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \ - \ - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \ - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \ - \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \ 
- \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \ - \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \ - "cmp %w[cnt], #1 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - "blt 3f \n" - -#define MID_RESULT_S1_RELU \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - \ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" \ - \ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \ - \ - "subs %w[cnt], %w[cnt], #1 \n" \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" \ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \ - \ - "bne 1b \n" - -#define RIGHT_RESULT_S1_RELU \ - /* r3 */ \ - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla 
v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v12.16b, v22.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \ - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "st1 {v12.4s}, [%[doutr0]], #16 \n" \ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v13.16b, v23.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \ - \ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \ - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \ - \ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \ - \ - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \ - \ - "bif v14.16b, v24.16b, v18.16b \n" \ - \ - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \ - \ - "st1 {v14.4s}, [%[doutr2]], #16 \n" \ - \ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \ - \ - "bif v15.16b, v25.16b, v18.16b \n" \ - \ - "st1 {v15.4s}, [%[doutr3]], #16 \n" - -#define COMPUTE_S_S1 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s}, [%[din0]], #16\n" \ - "ld1 {v1.4s}, [%[din1]], #16\n" \ - "ld1 {v2.4s}, [%[din2]], #16\n" \ - "ld1 {v3.4s}, [%[din3]], #16\n" \ - \ - "bif v0.16b, %[zero].16b, %[mask].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask].16b\n" \ - "bif v2.16b, %[zero].16b, %[mask].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask].16b\n" \ - \ - "ext v4.16b, %[zero].16b, v0.16b, #12\n" \ - "ext v5.16b, %[zero].16b, v1.16b, #12\n" \ - "ext v6.16b, %[zero].16b, v2.16b, #12\n" \ - "ext v7.16b, %[zero].16b, v3.16b, #12\n" \ - \ - "ext v8.16b, v0.16b, %[zero].16b, #4\n" \ - "ext v9.16b, v1.16b, %[zero].16b, #4\n" \ - "ext v10.16b, v2.16b, %[zero].16b, #4\n" \ - "ext v11.16b, v3.16b, %[zero].16b, #4\n" \ - \ - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" \ - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" \ - \ - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" \ - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" \ - \ - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" \ - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" \ - \ - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" \ - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" \ - \ - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" \ - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" \ - \ - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" \ - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" \ - \ - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "fmla 
v14.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" \ - \ - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" \ - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v14.4s\n" \ - "fadd v12.4s, v12.4s, v16.4s\n" \ - \ - "fadd v13.4s, v13.4s, v15.4s\n" \ - "fadd v13.4s, v13.4s, v17.4s\n" \ - \ - "fadd v12.4s, v12.4s, %[bias].4s\n" \ - "fadd v13.4s, v13.4s, %[bias].4s\n" - -#define RESULT_S_S1 \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define RESULT_S_S1_RELU \ - "prfm pldl1keep, [%[out1]]\n" \ - "prfm pldl1keep, [%[out2]]\n" \ - \ - "fmax v12.4s, v12.4s, %[zero].4s\n" \ - "fmax v13.4s, v13.4s, %[zero].4s\n" \ - \ - "st1 {v12.4s}, [%[out1]]\n" \ - "st1 {v13.4s}, [%[out2]]\n" - -#define COMPUTE_S_S1_P0 \ - "prfm pldl1keep, [%[din0]]\n" \ - "prfm pldl1keep, [%[din1]]\n" \ - "prfm pldl1keep, [%[din2]]\n" \ - "prfm pldl1keep, [%[din3]]\n" \ - \ - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" \ - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" \ - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" \ - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" \ - \ - "bif v0.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v1.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "bif v2.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v3.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "bif v4.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v5.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "bif v6.16b, %[zero].16b, %[mask1].16b\n" \ - "bif v7.16b, %[zero].16b, %[mask2].16b\n" \ - \ - "ext v8.16b, v0.16b, v1.16b, #4\n" \ - "ext v9.16b, v0.16b, v1.16b, #8\n" \ - \ - "and v12.16b, %[vbias].16b, %[vbias].16b \n" \ - "and v13.16b, %[vbias].16b, %[vbias].16b \n" /* r0 */ \ - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" \ - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" \ - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" \ - \ - "ext v8.16b, v2.16b, v3.16b, #4\n" \ - "ext v9.16b, v2.16b, v3.16b, #8\n" /* r1 */ \ - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" \ - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" \ - \ - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" \ - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" \ - \ - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" \ - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" \ - \ - "ext v8.16b, v4.16b, v5.16b, #4\n" \ - "ext v9.16b, v4.16b, v5.16b, #8\n" /* r2 */ \ - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" \ - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" \ - \ - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" \ - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" \ - \ - "fmla v13.4s, v9.4s, %[wr1].s[2]\n" \ - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" \ - \ - "ext v8.16b, v6.16b, v7.16b, #4\n" \ - "ext v9.16b, v6.16b, v7.16b, #8\n" \ - \ - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" \ - \ - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" \ - \ - "fadd v12.4s, v12.4s, v10.4s\n" \ - \ - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" \ - \ - "fadd v12.4s, v12.4s, v11.4s\n" \ - "fadd v13.4s, v13.4s, v14.4s\n" \ - "fadd v13.4s, v13.4s, v15.4s\n" // \ - // "prfm pldl1keep, [%[out1]]\n" \ - // "prfm pldl1keep, [%[out2]]\n" \ - // \ - // "st1 {v12.4s}, [%[out1]]\n" \ - // "st1 {v13.4s}, [%[out2]]\n" \ - - -#else -#define INIT_S1 \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" \ - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" \ - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" \ - "vld1.32 {d28-d30}, [%[din3_ptr]]! 
@ load din r3\n" \ - \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vdup.32 q5, %[bias_val] @ and \n" - -#define LEFT_COMPUTE_S1 \ - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" \ - "vext.32 q7, q8, q9, #1 @ 1234\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" \ - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" \ - "vext.32 q7, q10, q11, #1 @ 1234\n" \ - \ - /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \ - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" \ - "vext.32 q7, q12, q13, #1 @ 1234\n" \ - \ - /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" \ - "vext.32 q7, q14, q15, #1 @ 1234\n" - -#define LEFT_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define MID_COMPUTE_S1 \ - "1: @ right pad entry\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "pld [%[din0_ptr]] @ preload data\n" \ - "pld [%[din1_ptr]] @ preload data\n" \ - "pld [%[din2_ptr]] @ preload data\n" \ - "pld [%[din3_ptr]] @ preload data\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[din0_ptr]]! 
@ load din r0\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" - -#define MID_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define RIGHT_COMPUTE_S1 \ - "3: @ right pad entry\n" \ - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ - \ - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d31}, [%[vmask]]! 
@ load din r0\n" \ - \ - "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" \ - \ - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" \ - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" \ - \ - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" - -#define RIGHT_RESULT_S1 \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" - -#define LEFT_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! 
@ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "cmp %[cnt], #1 @ check whether has mid cols\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - "blt 3f @ jump to main loop start point\n" - -#define MID_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" \ - "vdup.32 q4, %[bias_val] @ and \n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \ - \ - "subs %[cnt], #1 @ loop count minus 1\n" \ - \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "bne 1b @ jump to main loop start point\n" - -#define RIGHT_RESULT_S1_RELU \ - /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vmax.f32 q4, q4, %q[vzero] @ relu \n" \ - \ - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d8, d16, d19 @ bit select, deal with right pad\n" \ - "vbif d9, d17, d23 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \ - \ - "vmax.f32 q5, q5, %q[vzero] @ relu \n" \ - \ - "vbif d10, d20, d19 @ bit select, deal with right pad\n" \ - "vbif d11, d21, d23 @ bit select, deal with right pad\n" \ - \ - "vst1.32 {d10-d11}, [%[dout_ptr2]]! 
@ store result, add pointer\n" - -#define COMPUTE_S_S1 \ - "pld [%[din0]]\n" \ - "pld [%[din1]]\n" \ - "pld [%[din2]]\n" \ - "pld [%[din3]]\n" \ - \ - "vld1.32 {d12-d13}, [%[din0]]!\n" \ - "vld1.32 {d14-d15}, [%[din1]]!\n" \ - "vld1.32 {d16-d17}, [%[din2]]!\n" \ - "vld1.32 {d18-d19}, [%[din3]]!\n" \ - \ - "vbif q6, %q[vzero], %q[mask]\n" \ - "vbif q7, %q[vzero], %q[mask]\n" \ - "vbif q8, %q[vzero], %q[mask]\n" \ - "vbif q9, %q[vzero], %q[mask]\n" \ - \ - "vmul.f32 q14, q6, %e[wr0][1]\n" \ - "vmul.f32 q15, q7, %e[wr0][1]\n" \ - \ - "vmla.f32 q14, q7, %e[wr1][1]\n" \ - "vmla.f32 q15, q8, %e[wr1][1]\n" \ - \ - "vmla.f32 q14, q8, %e[wr2][1]\n" \ - "vmla.f32 q15, q9, %e[wr2][1]\n" \ - \ - "vext.32 q10, %q[vzero], q6, #3\n" \ - "vext.32 q11, %q[vzero], q7, #3\n" \ - "vext.32 q12, %q[vzero], q8, #3\n" \ - "vext.32 q13, %q[vzero], q9, #3\n" \ - \ - "vmla.f32 q14, q10, %e[wr0][0]\n" \ - "vmla.f32 q15, q11, %e[wr0][0]\n" \ - \ - "vmla.f32 q14, q11, %e[wr1][0]\n" \ - "vmla.f32 q15, q12, %e[wr1][0]\n" \ - \ - "vmla.f32 q14, q12, %e[wr2][0]\n" \ - "vmla.f32 q15, q13, %e[wr2][0]\n" \ - \ - "vext.32 q10, q6, %q[vzero], #1\n" \ - "vext.32 q11, q7, %q[vzero], #1\n" \ - "vext.32 q12, q8, %q[vzero], #1\n" \ - "vext.32 q13, q9, %q[vzero], #1\n" \ - \ - "vmla.f32 q14, q10, %f[wr0][0]\n" \ - "vmla.f32 q15, q11, %f[wr0][0]\n" \ - \ - "vmla.f32 q14, q11, %f[wr1][0]\n" \ - "vmla.f32 q15, q12, %f[wr1][0]\n" \ - \ - "vmla.f32 q14, q12, %f[wr2][0]\n" \ - "vmla.f32 q15, q13, %f[wr2][0]\n" \ - \ - "vadd.f32 q14, q14, %q[bias]\n" \ - "vadd.f32 q15, q15, %q[bias]\n" - -#define RESULT_S_S1 \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define RESULT_S_S1_RELU \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vmax.f32 q14, q14, %q[vzero]\n" \ - "vmax.f32 q15, q15, %q[vzero]\n" \ - \ - "vst1.32 {d28-d29}, [%[out1]]\n" \ - "vst1.32 {d30-d31}, [%[out2]]\n" - -#define COMPUTE_S_S1_P0 \ - "pld [%[din0]]\n" \ - "pld [%[din1]]\n" \ - "pld [%[din2]]\n" \ - "pld [%[din3]]\n" \ - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" \ - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" \ - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" \ - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" \ - \ - "vdup.32 q4, %[bias_val] @ and \n" \ - "vdup.32 q5, %[bias_val] @ and \n" \ - \ - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \ - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \ - \ - "vld1.32 {d27}, [%[vmask]]! 
@ load din r0\n" \ - \ - "vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \ - \ - "vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \ - \ - "vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \ - "vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vext.32 q6, q8, q9, #1 @ 1234\n" \ - "vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \ - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \ - "vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \ - "vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \ - \ - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \ - \ - "vext.32 q6, q10, q11, #1 @ 1234\n" \ - "vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \ - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \ - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \ - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q12, q13, #1 @ 1234\n" \ - "vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \ - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \ - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - \ - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \ - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \ - \ - "vext.32 q6, q14, q15, #1 @ 1234\n" \ - "vext.32 q7, q14, q15, #2 @ 2345\n" /* r3 */ \ - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \ - \ - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \ - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" \ - \ - "pld [%[out1]]\n" \ - "pld [%[out2]]\n" \ - \ - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \ - "vadd.f32 q14, q4, q11 @ q4 += q10 \n" \ - \ - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" \ - "vadd.f32 q15, q5, q9 @ q4 += q10 \n" - -#endif -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ - void conv_depthwise_3x3s1p1_bias_no_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - int cnt_col = tile_w - 1; - - unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in); - const unsigned int remian_idx[4] = {0, 1, 2, 3}; - - if (remain == 0 && size_pad_right == 5) { - size_pad_right = 1; - cnt_col -= 1; - remain = 4; - } else if (remain == 0 && size_pad_right == 6) { - size_pad_right = 2; - cnt_col -= 1; - remain = 4; - } - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1 - MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1 - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -void conv_depthwise_3x3s1p1_bias_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - int cnt_col = tile_w - 1; - - unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in); - const unsigned int remian_idx[4] = {0, 1, 2, 3}; - - if (remain == 0 && size_pad_right == 5) { - size_pad_right = 1; - cnt_col -= 1; - remain = 4; - } else if (remain == 0 && size_pad_right == 6) { - size_pad_right = 2; - cnt_col -= 1; - remain = 4; - } - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile( - INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 - MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p1_bias_s_no_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float *dr0 = din_channel + hs * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile(COMPUTE_S_S1 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -void conv_depthwise_3x3s1p1_bias_s_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float *dr0 = din_channel + hs * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -void conv_depthwise_3x3s1p0_bias_no_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 4: - din_ptr1 = zero_ptr; - case 3: - din_ptr2 = zero_ptr; - case 2: - din_ptr3 = zero_ptr; - case 1: - din_ptr4 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - MID_COMPUTE_S1 MID_RESULT_S1 - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1 "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - case 0: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1 - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1 "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} - -void conv_depthwise_3x3s1p0_bias_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! 
for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float *zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float *write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int c = 0; c < ch_in; c++) { - float *dout_ptr = dout_batch + c * size_out_channel; - - const float *din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float *wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float *doutr0 = dout_ptr; - float *doutr1 = doutr0 + w_out; - float *doutr2 = doutr1 + w_out; - float *doutr3 = doutr2 + w_out; - - const float *dr0 = din_ch_ptr; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - const float *dr4 = dr3 + w_in; - const float *dr5 = dr4 + w_in; - - const float *din_ptr0 = dr0; - const float *din_ptr1 = dr1; - const float *din_ptr2 = dr2; - const float *din_ptr3 = dr3; - const float *din_ptr4 = dr4; - const float *din_ptr5 = dr5; - - float *ptr_zero = const_cast(zero); -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 4: - din_ptr1 = zero_ptr; - case 3: - din_ptr2 = zero_ptr; - case 2: - din_ptr3 = zero_ptr; - case 1: - din_ptr4 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! 
process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - INIT_S1 - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - MID_COMPUTE_S1 MID_RESULT_S1_RELU - "cmp %w[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } -#else - for (int i = 0; i < h_out; i += 2) { - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - - doutr0 = dout_ptr; - doutr1 = dout_ptr + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din_ptr1 = zero_ptr; - case 2: - din_ptr2 = zero_ptr; - case 1: - din_ptr3 = zero_ptr; - case 0: - din_ptr3 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int *rmask_ptr = rmask; - unsigned int *vmask_ptr = vmask; - asm volatile(INIT_S1 - "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" - "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" - "vext.32 q6, q8, q9, #1 @ 0012\n" - "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 - MID_RESULT_S1_RELU - "cmp %[remain], #1 \n" - "blt 0f \n" RIGHT_COMPUTE_S1 - RIGHT_RESULT_S1_RELU "0: \n" - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din_ptr0), - [din1_ptr] "+r"(din_ptr1), - [din2_ptr] "+r"(din_ptr2), - [din3_ptr] "+r"(din_ptr3), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_ptr += 2 * w_out; - } //! end of processing mid rows -#endif - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p0_bias_s_no_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! 
3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - -#ifdef __aarch64__ - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } -#endif // __aarch64__ - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0 = din_channel + j * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - if (j + 2 > h_out) { - doutr1 = trash_buf; - } - default: - break; - } - } -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - unsigned int *vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -void conv_depthwise_3x3s1p0_bias_s_relu(float *dout, - const float *din, - const float *weights, - const float *bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext *ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! 
for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float *din_batch = din + n * ch_in * size_in_channel; - float *dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float *dout_channel = dout_batch + i * size_out_channel; - const float *din_channel = din_batch + i * size_in_channel; - const float *weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - -#ifdef __aarch64__ - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } -#endif // __aarch64__ - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float *doutr0 = dout_channel; - float *doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float *dr0 = din_channel + j * w_in; - const float *dr1 = dr0 + w_in; - const float *dr2 = dr1 + w_in; - const float *dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - if (j + 2 > h_out) { - doutr1 = trash_buf; - } - default: - break; - } - } -#ifdef __aarch64__ - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - unsigned int *vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_new.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_new.cc new file mode 100644 index 0000000000..ea7d010f06 --- /dev/null +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_new.cc @@ -0,0 +1,2912 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_depthwise.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void conv_depthwise_3x3s2p0_bias_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_s_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_s_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_rlu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_s_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_s_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2_fp32(const float* din, + float* dout, + int num, + int ch_out, + int h_out, + int w_out, + int ch_in, + int h_in, + int w_in, + const float* weights, + const float* bias, + int pad, + bool flag_bias, + const operators::ActivationParam act_param, + ARMContext* ctx) { + bool has_active = act_param.has_active; + auto act_type = act_param.active_type; + float tmp = act_param.Relu_clipped_coef; + float ss = act_param.Leaky_relu_alpha; 
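+  // The relu6 clip value (tmp) and the leaky-relu slope (ss) are broadcast
+  // into 4-element arrays so the kernels below can load a full vector of
+  // the constant through the six/scale pointer they receive. The dispatch
+  // that follows selects the pad == 0 or pad == 1 variant and, within each,
+  // the wide path (w_in > 8 for p0, w_in > 7 for p1) or the small-width
+  // "_s_" fallback kernels.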
+ float vsix[4] = {tmp, tmp, tmp, tmp}; + float vscale[4] = {ss, ss, ss, ss}; + if (has_active) { + switch (act_type) { + case lite_api::ActivationType::kRelu: + if (pad == 0) { + if (w_in > 8) { + conv_depthwise_3x3s2p0_bias_relu(dout, + din, + weights, + bias, + flag_bias, + true, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s_relu(dout, + din, + weights, + bias, + flag_bias, + true, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + if (pad == 1) { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias_relu(dout, + din, + weights, + bias, + flag_bias, + true, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_relu(dout, + din, + weights, + bias, + flag_bias, + true, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + break; + case lite_api::ActivationType::kRelu6: + if (pad == 0) { + if (w_in > 8) { + conv_depthwise_3x3s2p0_bias_relu6(dout, + din, + weights, + bias, + vsix, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s_relu6(dout, + din, + weights, + bias, + vsix, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + if (pad == 1) { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias_relu6(dout, + din, + weights, + bias, + vsix, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_relu6(dout, + din, + weights, + bias, + vsix, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + break; + case lite_api::ActivationType::kLeakyRelu: + if (pad == 0) { + if (w_in >8) { + conv_depthwise_3x3s2p0_bias_leakyRelu(dout, + din, + weights, + bias, + vscale, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s_leakyRelu(dout, + din, + weights, + bias, + vscale, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + if (pad == 1) { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias_leakyRelu(dout, + din, + weights, + bias, + vscale, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_leakyRelu(dout, + din, + weights, + bias, + vscale, + flag_bias, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + break; + default: + LOG(FATAL) << "this act_type: " << static_cast(act_type) + << " fuse not support"; + } + } else { + if (pad == 0) { + if (w_in > 8) { + conv_depthwise_3x3s2p0_bias_no_relu(dout, + din, + weights, + bias, + flag_bias, + false, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p0_bias_s_no_relu(dout, + din, + weights, + bias, + flag_bias, + false, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + if (pad == 1) { + if (w_in > 7) { + conv_depthwise_3x3s2p1_bias_no_relu(dout, + din, + weights, + bias, + flag_bias, + false, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } else { + conv_depthwise_3x3s2p1_bias_s_no_relu(dout, + din, + weights, + bias, + flag_bias, + false, + num, + ch_in, + h_in, + w_in, + h_out, + w_out, + ctx); + } + } + } +} +#ifdef __aarch64__ +#define INIT_S2 \ + "prfm pldl1keep, [%[inptr0]] \n" \ + "prfm pldl1keep, [%[inptr1]] \n" \ + "prfm pldl1keep, [%[inptr2]] \n" \ + "prfm pldl1keep, [%[inptr3]] \n" \ + "prfm pldl1keep, [%[inptr4]] \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 
{v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" + +#define LEFT_COMPUTE_S2 \ + "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[1] \n" /* {0,2,4,6} * w01 */ \ + "fmul v12.4s, v1.4s, %[w0].s[2] \n" /* {1,3,5,7} * w02 */ \ + "fmla v16.4s, v10.4s, %[w0].s[0] \n" /* {0,1,3,5} * w00*/ \ + \ + "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" /* v10 = {0,1,3,5} */ \ + \ + "sub %[inptr0], %[inptr0], #4 \n" \ + "sub %[inptr1], %[inptr1], #4 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[1] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" \ + \ + "sub %[inptr2], %[inptr2], #4 \n" \ + "sub %[inptr3], %[inptr3], #4 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[1] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[1] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[2] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[2] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[0] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" \ + \ + "sub %[inptr4], %[inptr4], #4 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[1] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" + +#define LEFT_RESULT_S2 \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define MID_COMPUTE_S2 \ + "2: \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, v18.16b, #4 \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, v19.16b, #4 \n" \ + \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, v20.16b, #4 \n" \ + \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, v21.16b, #4 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + \ + "fadd 
v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" + +#define MID_RESULT_S2 \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_COMPUTE_S2 \ + "1: \n" \ + "cmp %w[remain], #1 \n" \ + "blt 4f \n" \ + "3: \n" \ + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v1.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" \ + \ + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" \ + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" \ + "ld1 {v0.4s}, [%[outptr0]] \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" \ + "ld1 {v1.4s}, [%[outptr1]] \n" /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" + +#define RIGHT_RESULT_S2 \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define LEFT_RESULT_S2_RELU \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "fmax 
v17.4s, v17.4s, %[vzero].4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" +#define LEFT_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" \ + "ld1 {v22.4s}, [%[six_ptr]] \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define LEFT_RESULT_S2_LEAKY_RELU \ + "ld1 {v22.4s}, [%[scale_ptr]] \n" \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_f32 */ \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define MID_RESULT_S2_RELU \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define MID_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + "and 
v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define MID_RESULT_S2_LEAKY_RELU \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define RIGHT_RESULT_S2_RELU6 \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "fmin v16.4s, v16.4s, v22.4s \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "fmin v17.4s, v17.4s, v22.4s \n" \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define RIGHT_RESULT_S2_LEAKY_RELU \ + "fcmge v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v16.4s, v22.4s \n" \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ + "fmul v12.4s, v17.4s, v22.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + "bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define COMPUTE_S_S2 \ + "movi v9.4s, #0 \n" \ + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ + \ + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \ + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \ + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \ + \ + "bif v10.16b, v9.16b, v6.16b \n" \ + "bif v11.16b, v9.16b, v7.16b \n" \ + "bif v12.16b, v9.16b, v6.16b \n" \ + "bif v13.16b, v9.16b, v7.16b \n" \ + "bif v14.16b, v9.16b, v6.16b \n" \ + "bif v15.16b, v9.16b, v7.16b \n" \ + \ + "ext v6.16b, v9.16b, v11.16b, #12 \n" \ + "ext v7.16b, v9.16b, v13.16b, #12 \n" \ + "ext v8.16b, v9.16b, v15.16b, #12 \n" \ + \ + "fmul v4.4s, v10.4s, %[wr0].s[1] \n" \ + "fmul v5.4s, v11.4s, %[wr0].s[2] \n" \ + "fmul v6.4s, v6.4s, %[wr0].s[0] \n" \ + \ + "fmla v4.4s, v12.4s, %[wr1].s[1] \n" \ + "fmla v5.4s, v13.4s, %[wr1].s[2] \n" \ + "fmla v6.4s, v7.4s, %[wr1].s[0] \n" \ + \ + "fmla v4.4s, v14.4s, %[wr2].s[1] \n" \ + "fmla v5.4s, v15.4s, %[wr2].s[2] \n" \ + "fmla v6.4s, v8.4s, %[wr2].s[0] \n" \ + \ + "fadd v4.4s, v4.4s, v5.4s \n" \ + 
"fadd v4.4s, v4.4s, v6.4s \n" + +#define RESULT_S_S2 \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" + +#define RESULT_S_S2_RELU \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" +#define RESULT_S_S2_RELU6 \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + "fmin v4.4s, v4.4s, %[vsix].4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" +#define RESULT_S_S2_LEAKY_RELU \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + "fcmge v11.4s, v4.4s, %[vzero].4s \n"/* vcgeq_u32 */\ + "fmul v12.4s, v4.4s, %[vscale].4s \n"\ + "bif v4.16b, v12.16b, v11.16b \n" /* choose*/ \ + "st1 {v4.4s}, [%[out]] \n" +#define COMPUTE_S_S2_P0 \ + "movi v9.4s, #0 \n" \ + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ + \ + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \ + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \ + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \ + "and v4.16b, %[bias].16b, %[bias].16b \n" \ + \ + "bif v10.16b, v9.16b, v6.16b \n" \ + "bif v11.16b, v9.16b, v7.16b \n" \ + "bif v12.16b, v9.16b, v6.16b \n" \ + "bif v13.16b, v9.16b, v7.16b \n" \ + "bif v14.16b, v9.16b, v6.16b \n" \ + "bif v15.16b, v9.16b, v7.16b \n" \ + \ + "ext v6.16b, v10.16b, v9.16b, #4 \n" \ + "ext v7.16b, v12.16b, v9.16b, #4 \n" \ + "ext v8.16b, v14.16b, v9.16b, #4 \n" \ + \ + "fmla v4.4s, v10.4s, %[wr0].s[0] \n" \ + "fmul v5.4s, v11.4s, %[wr0].s[1] \n" \ + "fmul v16.4s, v6.4s, %[wr0].s[2] \n" \ + \ + "fmla v4.4s, v12.4s, %[wr1].s[0] \n" \ + "fmla v5.4s, v13.4s, %[wr1].s[1] \n" \ + "fmla v16.4s, v7.4s, %[wr1].s[2] \n" \ + \ + "fmla v4.4s, v14.4s, %[wr2].s[0] \n" \ + "fmla v5.4s, v15.4s, %[wr2].s[1] \n" \ + "fmla v16.4s, v8.4s, %[wr2].s[2] \n" \ + \ + "fadd v4.4s, v4.4s, v5.4s \n" \ + "fadd v4.4s, v4.4s, v16.4s \n" + +#define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n" +#define RESULT_S_S2_P0_RELU \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + "st1 {v4.4s}, [%[out]] \n" +#define RESULT_S_S2_P0_RELU6 \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + "fmin v4.4s, v4.4s, %[vsix].4s \n" \ + "st1 {v4.4s}, [%[out]] \n" +#define RESULT_S_S2_P0_LEAKY_RELU \ + "fcmge v11.4s, v4.4s, %[vzero].4s \n"/* vcgeq_u32 */\ + "fmul v12.4s, v4.4s, %[vscale].4s \n"\ + "bif v4.16b, v12.16b, v11.16b \n" /* choose*/ \ + "st1 {v4.4s}, [%[out]] \n" + +#else +#define INIT_S2 \ + "vmov.u32 q9, #0 \n" \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + \ + "vdup.32 q3, %[bias] @ and \n" + +#define LEFT_COMPUTE_S2 \ + "vext.32 q6, q9, q11, #3 @ shift right 1 data\n" \ + "vext.32 q7, q9, q13, #3 @ shift right 1 data\n" \ + "vext.32 q8, q9, q15, #3 @ shift right 1 data\n" \ + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, out0\n" \ + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, out0\n" \ + \ + "sub %[din0_ptr], #4 @ inpitr0 - 1\n" \ + "sub %[din1_ptr], #4 @ inpitr1 - 1\n" \ + "sub %[din2_ptr], #4 @ inpitr2 - 1\n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \ + \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! 
@ load din r1\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, out1\n" \ + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \ + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \ + \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" + +#define LEFT_RESULT_S2 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "cmp %[cnt], #1 \n" \ + "blt 1f \n" + +#define MID_COMPUTE_S2 \ + "2: \n" \ + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" \ + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" \ + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" \ + \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" + +#define MID_RESULT_S2 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "subs %[cnt], #1 \n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "bne 2b \n" + +#define RIGHT_COMPUTE_S2 \ + "1: \n" \ + "cmp %[remain], #1 \n" \ + "blt 3f \n" \ + \ + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" \ + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" \ + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" + +#define RIGHT_RESULT_S2 \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" \ + "vbif.f32 q3, q10, q11 @ write mask\n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "3: \n" + +#define LEFT_RESULT_S2_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 \n"\ + "cmp %[cnt], #1 \n"\ + "vst1.32 {d6-d7}, [%[outptr]]! 
\n"\ + "blt 1f \n" +#define LEFT_RESULT_S2_RELU6 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 @ relu \n"\ + "cmp %[cnt], #1 \n"\ + "vmin.f32 q3, q3, q6 @ relu \n"\ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "blt 1f \n" +#define LEFT_RESULT_S2_LEAKY_RELU \ + "vadd.f32 q3, q3, q4 \n"\ + "vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\ + "vadd.f32 q3, q3, q5 \n"\ + "vcge.f32 q7, q3, q9 \n"\ + "vmul.f32 q8, q3, q6 \n"\ + "cmp %[cnt], #1 \n"\ + "vbif q3, q8, q7 @ choose \n"\ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "blt 1f \n" +#define MID_RESULT_S2_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "subs %[cnt], #1 \n"\ + "vmax.f32 q3, q3, q9 @ relu \n"\ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "bne 2b \n" + +#define MID_RESULT_S2_RELU6 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 @ relu \n"\ + "subs %[cnt], #1 \n"\ + "vmin.f32 q3, q3, q6 @ relu \n"\ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "bne 2b \n" +#define MID_RESULT_S2_LEAKY_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vcge.f32 q7, q3, q9 \n"\ + "vmul.f32 q8, q3, q6 \n"\ + "subs %[cnt], #1 \n"\ + "vbif q3, q8, q7 @ choose \n"\ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 @ relu\n"\ + "vbif.f32 q3, q10, q11 @ write mask\n"\ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "3: \n" + +#define RIGHT_RESULT_S2_RELU6 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 @ relu\n"\ + "vmin.f32 q3, q3, q6 @ relu \n"\ + \ + "vbif.f32 q3, q10, q11 @ write mask\n"\ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "3: \n" +#define RIGHT_RESULT_S2_LEAKY_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vcge.f32 q7, q3, q9 \n"\ + "vmul.f32 q8, q3, q6 \n"\ + "vbif q3, q8, q7 @ choose \n"\ + "vbif.f32 q3, q10, q11 @ write mask\n"\ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n"\ + "3: \n" +#define COMPUTE_S_S2 \ + "vmov.u32 q9, #0 \n" \ + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q9, q11, #3 @ shift left 1 \n" \ + "vext.32 q7, q9, q13, #3 @ shift left 1 \n" \ + "vext.32 q8, q9, q15, #3 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, out0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n" + +#define RESULT_S_S2 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vst1.32 {d6-d7}, [%[out]] \n" +#define RESULT_S_S2_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 @ relu\n"\ + \ + "vst1.32 {d6-d7}, [%[out]] \n" + +#define RESULT_S_S2_RELU6 \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vmax.f32 q3, q3, q9 @ relu\n"\ + "vmin.f32 q3, q3, q6 @ relu\n"\ + \ + "vst1.32 {d6-d7}, [%[out]] \n" +#define RESULT_S_S2_LEAKY_RELU \ + "vadd.f32 q3, q3, q4 @ add \n"\ + "vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\ + "vadd.f32 q3, q3, q5 @ add \n"\ + "vcge.f32 q7, q3, q9 \n"\ + "vmul.f32 q8, q3, q6 \n"\ + "vbif q3, q8, q7 @ choose \n"\ + \ + "vst1.32 {d6-d7}, [%[out]] \n" +#define COMPUTE_S_S2_P0 \ + "vmov.u32 q9, #0 \n" \ + "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" \ + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" \ + "vext.32 q8, q14, q9, #1 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n" + +#define RESULT_S_S2_P0 \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" \ + "vst1.32 {d6-d7}, [%[out]] \n" +#define RESULT_S_S2_P0_RELU \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vst1.32 {d6-d7}, [%[out]] \n" +#define RESULT_S_S2_P0_RELU6 \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n" \ + "vadd.f32 q3, q3, q5 @ add \n" \ + "vmax.f32 q3, q3, q9 @ relu\n" \ + "vmin.f32 q3, q3, q6 @ relu\n" \ + "vst1.32 {d6-d7}, [%[out]] \n" +#define RESULT_S_S2_P0_LEAKY_RELU \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vld1.f32 {d12-d13}, [%[scale_ptr]] @ load scale \n" \ + "vadd.f32 q3, q3, q5 @ add \n" \ + "vcge.f32 q7, q3, q9 \n" \ + "vmul.f32 q8, q3, q6 \n" \ + "vbif q3, q8, q7 @ choose \n" \ + "vst1.32 {d6-d7}, [%[out]] \n" +#endif + +/** + * \brief depthwise convolution kernel 3x3, stride 2 + * w_in > 7 + */ +void conv_depthwise_3x3s2p1_bias_rlu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + int size_pad_bottom = h_out * 2 - h_in; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + int cnt_col = tile_w - 1; + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data<float>(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * 
size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + din4_ptr = dr3; + dr0 = dr3; + dr1 = dr4; + } else { + dr0 = dr4; + dr1 = dr0 + w_in; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i * 2 + 4 > h_in) { + switch (i * 2 + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = cnt_col; + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(six), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! 
process bottom pad + if (i * 2 + 2 > h_in) { + switch (i * 2 + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2_RELU6 RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [six_ptr] "r"(six), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} +void conv_depthwise_3x3s2p1_bias_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + int size_pad_bottom = h_out * 2 - h_in; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + int cnt_col = tile_w - 1; + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = 
dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + din4_ptr = dr3; + dr0 = dr3; + dr1 = dr4; + } else { + dr0 = dr4; + dr1 = dr0 + w_in; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i * 2 + 4 > h_in) { + switch (i * 2 + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = cnt_col; + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU MID_COMPUTE_S2 + MID_RESULT_S2_LEAKY_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [scale_ptr] "r"(scale), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! 
process bottom pad + if (i * 2 + 2 > h_in) { + switch (i * 2 + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU MID_COMPUTE_S2 + MID_RESULT_S2_LEAKY_RELU RIGHT_COMPUTE_S2_LEAKY_RELU RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [scale_ptr] "r"(scale), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ +void conv_depthwise_3x3s2p1_bias_s_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU6 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [vsix] "w"(vsix), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU6 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [six_ptr] "r"(six), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + 
"q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} +void conv_depthwise_3x3s2p1_bias_s_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); +#ifdef __aarch64__ + float32x4_t vscale = vld1q_f32(scale); +#endif + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2 RESULT_S_S2_LEAKY_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [vscale] "w"(vscale), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(COMPUTE_S_S2 RESULT_S_S2_LEAKY_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [scale_ptr] "r"(scale), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} +/** + * \brief depthwise convolution kernel 3x3, stride 2 + */ +// w_in > 7 +void conv_depthwise_3x3s2p0_bias_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + 
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i * 2 + 5 > h_in) { + switch (i * 2 + 5 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! 
process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2 + MID_RESULT_S2_RELU6 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU6 + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [six_ptr] "r"(six), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! process bottom pad + if (i * 2 + 3 > h_in) { + switch (i * 2 + 3 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU6 + RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [six_ptr] "r"(six), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +void conv_depthwise_3x3s2p0_bias_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* 
din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i * 2 + 5 > h_in) { + switch (i * 2 + 5 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2 + MID_RESULT_S2_LEAKY_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_LEAKY_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [scale_ptr] "r"(scale), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21", + "v22"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! 
process bottom pad + if (i * 2 + 3 > h_in) { + switch (i * 2 + 3 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU + RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [scale_ptr] "r"(scale), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ +void conv_depthwise_3x3s2p0_bias_s_relu6(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* six, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); +#ifdef __aarch64__ + float32x4_t vsix = vld1q_f32(six); +#endif + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; j++) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + if (j * 2 + 2 >= h_in) { + switch (j + 2 - h_in) { + case 1: + din1_ptr = zero_ptr; + case 0: + din2_ptr = zero_ptr; + default: + break; + } + } + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU6 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [vsix] "w"(vsix), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + +#else + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU6 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] 
"r"(bias_c), + [out] "r"(out_buf), + [six_ptr] "r"(six), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} +void conv_depthwise_3x3s2p0_bias_s_leakyRelu(float* dout, + const float* din, + const float* weights, + const float* bias, + const float* scale, + bool flag_bias, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); +#ifdef __aarch64__ + float32x4_t vscale = vld1q_f32(scale); +#endif + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; j++) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + if (j * 2 + 2 >= h_in) { + switch (j + 2 - h_in) { + case 1: + din1_ptr = zero_ptr; + case 0: + din2_ptr = zero_ptr; + default: + break; + } + } + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_LEAKY_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [vscale] "w"(vscale), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); + +#else + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_LEAKY_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [scale_ptr] "r"(scale), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_relu_new.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_relu_new.cc new file mode 100644 
index 0000000000..e2da19c631 --- /dev/null +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32_relu_new.cc @@ -0,0 +1,2213 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/arm/math/conv_depthwise.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#ifdef __aarch64__ +#define INIT_S2 \ + "prfm pldl1keep, [%[inptr0]] \n" \ + "prfm pldl1keep, [%[inptr1]] \n" \ + "prfm pldl1keep, [%[inptr2]] \n" \ + "prfm pldl1keep, [%[inptr3]] \n" \ + "prfm pldl1keep, [%[inptr4]] \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" + +#define LEFT_COMPUTE_S2 \ + "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[1] \n" /* {0,2,4,6} * w01 */ \ + "fmul v12.4s, v1.4s, %[w0].s[2] \n" /* {1,3,5,7} * w02 */ \ + "fmla v16.4s, v10.4s, %[w0].s[0] \n" /* {0,1,3,5} * w00*/ \ + \ + "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" /* v10 = {0,1,3,5} */ \ + \ + "sub %[inptr0], %[inptr0], #4 \n" \ + "sub %[inptr1], %[inptr1], #4 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[1] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" \ + \ + "sub %[inptr2], %[inptr2], #4 \n" \ + "sub %[inptr3], %[inptr3], #4 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[1] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[1] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[2] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[2] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[0] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" \ + \ + "sub %[inptr4], %[inptr4], #4 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[1] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[0] \n" \ + \ + "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" + +#define LEFT_RESULT_S2 \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] 
\n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define MID_COMPUTE_S2 \ + "2: \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, v18.16b, #4 \n" \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, v19.16b, #4 \n" \ + \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, v20.16b, #4 \n" \ + \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, v21.16b, #4 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" + +#define MID_RESULT_S2 \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_COMPUTE_S2 \ + "1: \n" \ + "cmp %w[remain], #1 \n" \ + "blt 4f \n" \ + "3: \n" \ + "bif v0.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v1.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "bif v2.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v3.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "bif v4.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v5.16b, %[vzero].16b, %[mask2].16b \n" \ + \ + "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" \ + \ + "bif v6.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v7.16b, %[vzero].16b, %[mask2].16b \n" /* r0 */ \ + "fmul v11.4s, v0.4s, %[w0].s[0] \n" \ + "fmul v12.4s, v1.4s, %[w0].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w0].s[2] \n" \ + \ + "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" \ + "bif v8.16b, %[vzero].16b, %[mask1].16b \n" \ + "bif v9.16b, %[vzero].16b, %[mask2].16b \n" /* r1 */ \ + "fmla v11.4s, v2.4s, %[w1].s[0] \n" \ + "fmla v12.4s, v3.4s, %[w1].s[1] \n" \ + "fmla v16.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" /* r2 */ \ + "fmul v13.4s, v4.4s, %[w0].s[0] \n" \ + "fmla v11.4s, v4.4s, %[w2].s[0] \n" \ + \ + "fmul v14.4s, v5.4s, %[w0].s[1] \n" \ + "fmla v12.4s, v5.4s, %[w2].s[1] \n" \ + \ + "fmla v17.4s, v10.4s, %[w0].s[2] \n" \ + "fmla v16.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" /* r3 */ \ + "fmla v13.4s, v6.4s, %[w1].s[0] \n" \ + "fmla v14.4s, v7.4s, %[w1].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w1].s[2] \n" \ + \ + "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" \ + 
"ld1 {v0.4s}, [%[outptr0]] \n" \ + \ + "fadd v16.4s, v16.4s, v11.4s \n" \ + "fadd v16.4s, v16.4s, v12.4s \n" \ + "ld1 {v1.4s}, [%[outptr1]] \n" + +#define RIGHT_RESULT_S2 \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define LEFT_RESULT_S2_RELU \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[1] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[2] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[0] \n" \ + \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" \ + \ + "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \ + "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \ + "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" \ + \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "cmp %w[cnt], #1 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "blt 1f \n" + +#define MID_RESULT_S2_RELU \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \ + "ld1 {v15.4s}, [%[inptr0]] \n" \ + "ld1 {v18.4s}, [%[inptr1]] \n" \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "ld1 {v19.4s}, [%[inptr2]] \n" \ + "ld1 {v20.4s}, [%[inptr3]] \n" \ + "ld1 {v21.4s}, [%[inptr4]] \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "ext v10.16b, v0.16b, v15.16b, #4 \n" \ + "and v16.16b, %[vbias].16b, %[vbias].16b \n" \ + "subs %w[cnt], %w[cnt], #1 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + \ + "and v17.16b, %[vbias].16b, %[vbias].16b \n" \ + \ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ + /* r4 */ \ + "fmla v13.4s, v8.4s, %[w2].s[0] \n" \ + "fmla v14.4s, v9.4s, %[w2].s[1] \n" \ + "fmla v17.4s, v10.4s, %[w2].s[2] \n" \ + \ + "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \ + \ + "fadd v17.4s, v17.4s, v13.4s \n" \ + \ + "bif v16.16b, v0.16b, %[wmask].16b \n" \ + \ + "fadd v17.4s, v17.4s, v14.4s \n" \ + \ + "st1 {v16.4s}, [%[outptr0]], #16 \n" \ + \ + "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \ + \ + "bif v17.16b, v1.16b, %[wmask].16b \n" \ + \ + "st1 {v17.4s}, [%[outptr1]], #16 \n" \ + "4: \n" + +#define COMPUTE_S_S2 \ + "movi v9.4s, #0 \n" \ + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ + \ + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \ + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \ + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \ + \ + "bif v10.16b, v9.16b, v6.16b \n" \ + "bif v11.16b, v9.16b, v7.16b \n" \ + "bif v12.16b, v9.16b, v6.16b \n" \ + "bif v13.16b, v9.16b, v7.16b \n" \ + "bif v14.16b, v9.16b, v6.16b \n" \ + "bif v15.16b, v9.16b, v7.16b \n" \ + \ + 
"ext v6.16b, v9.16b, v11.16b, #12 \n" \ + "ext v7.16b, v9.16b, v13.16b, #12 \n" \ + "ext v8.16b, v9.16b, v15.16b, #12 \n" \ + \ + "fmul v4.4s, v10.4s, %[wr0].s[1] \n" \ + "fmul v5.4s, v11.4s, %[wr0].s[2] \n" \ + "fmul v6.4s, v6.4s, %[wr0].s[0] \n" \ + \ + "fmla v4.4s, v12.4s, %[wr1].s[1] \n" \ + "fmla v5.4s, v13.4s, %[wr1].s[2] \n" \ + "fmla v6.4s, v7.4s, %[wr1].s[0] \n" \ + \ + "fmla v4.4s, v14.4s, %[wr2].s[1] \n" \ + "fmla v5.4s, v15.4s, %[wr2].s[2] \n" \ + "fmla v6.4s, v8.4s, %[wr2].s[0] \n" \ + \ + "fadd v4.4s, v4.4s, v5.4s \n" \ + "fadd v4.4s, v4.4s, v6.4s \n" + +#define RESULT_S_S2 \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" + +#define RESULT_S_S2_RELU \ + "fadd v4.4s, v4.4s, %[bias].4s \n" \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + \ + "st1 {v4.4s}, [%[out]] \n" + +#define COMPUTE_S_S2_P0 \ + "movi v9.4s, #0 \n" \ + "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \ + \ + "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \ + "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \ + "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \ + "and v4.16b, %[bias].16b, %[bias].16b \n" \ + \ + "bif v10.16b, v9.16b, v6.16b \n" \ + "bif v11.16b, v9.16b, v7.16b \n" \ + "bif v12.16b, v9.16b, v6.16b \n" \ + "bif v13.16b, v9.16b, v7.16b \n" \ + "bif v14.16b, v9.16b, v6.16b \n" \ + "bif v15.16b, v9.16b, v7.16b \n" \ + \ + "ext v6.16b, v10.16b, v9.16b, #4 \n" \ + "ext v7.16b, v12.16b, v9.16b, #4 \n" \ + "ext v8.16b, v14.16b, v9.16b, #4 \n" \ + \ + "fmla v4.4s, v10.4s, %[wr0].s[0] \n" \ + "fmul v5.4s, v11.4s, %[wr0].s[1] \n" \ + "fmul v16.4s, v6.4s, %[wr0].s[2] \n" \ + \ + "fmla v4.4s, v12.4s, %[wr1].s[0] \n" \ + "fmla v5.4s, v13.4s, %[wr1].s[1] \n" \ + "fmla v16.4s, v7.4s, %[wr1].s[2] \n" \ + \ + "fmla v4.4s, v14.4s, %[wr2].s[0] \n" \ + "fmla v5.4s, v15.4s, %[wr2].s[1] \n" \ + "fmla v16.4s, v8.4s, %[wr2].s[2] \n" \ + \ + "fadd v4.4s, v4.4s, v5.4s \n" \ + "fadd v4.4s, v4.4s, v16.4s \n" + +#define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n" + +#define RESULT_S_S2_P0_RELU \ + "fmax v4.4s, v4.4s, v9.4s \n" \ + "st1 {v4.4s}, [%[out]] \n" + +#else +#define INIT_S2 \ + "vmov.u32 q9, #0 \n" \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \ + "pld [%[din0_ptr]] @ preload data\n" \ + "pld [%[din1_ptr]] @ preload data\n" \ + "pld [%[din2_ptr]] @ preload data\n" \ + \ + "vdup.32 q3, %[bias] @ and \n" + +#define LEFT_COMPUTE_S2 \ + "vext.32 q6, q9, q11, #3 @ shift right 1 data\n" \ + "vext.32 q7, q9, q13, #3 @ shift right 1 data\n" \ + "vext.32 q8, q9, q15, #3 @ shift right 1 data\n" \ + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, out0\n" \ + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, out0\n" \ + \ + "sub %[din0_ptr], #4 @ inpitr0 - 1\n" \ + "sub %[din1_ptr], #4 @ inpitr1 - 1\n" \ + "sub %[din2_ptr], #4 @ inpitr2 - 1\n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \ + \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, out1\n" \ + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \ + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \ + \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r1\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define LEFT_RESULT_S2 \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "cmp %[cnt], #1 \n" \ + "blt 1f \n" + +#define MID_COMPUTE_S2 \ + "2: \n" \ + "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + "vext.32 q6, q10, q8, #1 @ shift left 1 \n" \ + "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vext.32 q7, q12, q8, #1 @ shift left 1 \n" \ + "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vext.32 q6, q14, q8, #1 @ shift left 1 \n" \ + \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define MID_RESULT_S2 \ + "subs %[cnt], #1 \n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "bne 2b \n" + +#define RIGHT_COMPUTE_S2 \ + "1: \n" \ + "cmp %[remain], #1 \n" \ + "blt 3f \n" \ + \ + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" \ + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vext.32 q6, q14, q9, #1 @ shift left 1 \n" \ + "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define RIGHT_RESULT_S2 \ + "vbif.f32 q3, q10, q11 @ write mask\n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "3: \n" + +#define LEFT_RESULT_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "cmp %[cnt], #1 \n" \ + "blt 1f \n" + +#define MID_RESULT_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "subs %[cnt], #1 \n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! \n" \ + "bne 2b \n" + +#define RIGHT_RESULT_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vbif.f32 q3, q10, q11 @ write mask\n" \ + \ + "vst1.32 {d6-d7}, [%[outptr]]! 
\n" \ + "3: \n" + +#define COMPUTE_S_S2 \ + "vmov.u32 q9, #0 \n" \ + "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q9, q11, #3 @ shift left 1 \n" \ + "vext.32 q7, q9, q13, #3 @ shift left 1 \n" \ + "vext.32 q8, q9, q15, #3 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, out0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n" + +#define RESULT_S_S2_RELU \ + "vmax.f32 q3, q3, q9 @ relu\n" \ + \ + "vst1.32 {d6-d7}, [%[out]] \n" + +#define COMPUTE_S_S2_P0 \ + "vmov.u32 q9, #0 \n" \ + "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \ + "vdup.32 q3, %[bias] @ and \n" \ + \ + "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \ + "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \ + "vld2.32 {d28-d31}, [%[din2_ptr]]! 
@ load din r2\n" \ + \ + "vbif q10, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q11, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q12, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q13, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + "vbif q14, q9, q6 @ bit select, deal with " \ + "right pad\n" \ + "vbif q15, q9, q7 @ bit select, deal with " \ + "right pad\n" \ + \ + "vext.32 q6, q10, q9, #1 @ shift left 1 \n" \ + "vext.32 q7, q12, q9, #1 @ shift left 1 \n" \ + "vext.32 q8, q14, q9, #1 @ shift left 1 \n" \ + \ + "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \ + "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \ + "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \ + \ + "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \ + "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \ + "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \ + \ + "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \ + "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \ + "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n" \ + \ + "vadd.f32 q3, q3, q4 @ add \n" \ + "vadd.f32 q3, q3, q5 @ add \n" + +#define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n" + +#define RESULT_S_S2_P0_RELU \ + "vmax.f32 q3, q3, q9 @ relu \n" \ + "vst1.32 {d6-d7}, [%[out]] \n" + +#endif + +/** + * \brief depthwise convolution kernel 3x3, stride 2 + * w_in > 7 + */ +void conv_depthwise_3x3s2p1_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + int size_pad_bottom = h_out * 2 - h_in; + + int cnt_col = (w_out >> 2) - 2; + int size_right_remain = w_in - (7 + cnt_col * 8); + if (size_right_remain >= 9) { + cnt_col++; + size_right_remain -= 8; + } + int cnt_remain = (size_right_remain == 8) ? 
4 : (w_out % 4); // + + int size_right_pad = w_out * 2 - w_in; + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_in; i += 4) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + din4_ptr = dr3; + dr0 = dr3; + dr1 = dr4; + } else { + dr0 = dr4; + dr1 = dr0 + w_in; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i + 4 > h_in) { + switch (i + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + default: + break; + } + } + //! 
process output pad + if (i / 2 + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = cnt_col; + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 + MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_in; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + unsigned int* mask_ptr = dmask; + asm volatile( + INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 + MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +void conv_depthwise_3x3s2p1_bias_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + int size_pad_bottom = h_out * 2 - h_in; + + int cnt_col = (w_out >> 2) - 2; + int size_right_remain = w_in - (7 + cnt_col * 8); + if (size_right_remain >= 9) { + cnt_col++; + size_right_remain -= 8; + } + int cnt_remain = (size_right_remain == 8) ? 
4 : (w_out % 4); // + + int size_right_pad = w_out * 2 - w_in; + + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_in; i += 4) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + din3_ptr = dr2; + din4_ptr = dr3; + dr0 = dr3; + dr1 = dr4; + } else { + dr0 = dr4; + dr1 = dr0 + w_in; + } + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i + 4 > h_in) { + switch (i + 4 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + default: + break; + } + } + //! 
process output pad + if (i / 2 + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = cnt_col; + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_in; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + if (i == 0) { + din0_ptr = zero_ptr; + din1_ptr = dr0; + din2_ptr = dr1; + dr0 = dr1; + dr1 = dr2; + dr2 = dr1 + w_in; + } else { + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + } + + //! process bottom pad + if (i + 2 > h_in) { + switch (i + 2 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = cnt_col; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2 + MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ +void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + 
const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} +void conv_depthwise_3x3s2p1_bias_s_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + int hs = -1; + int he = 2; + float out_buf[4]; + for (int j = 0; j < h_out; ++j) { + const float* dr0 = din_channel + hs * w_in; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + if (hs == -1) { + dr0 = zeros; + } + if (he > h_in) { + dr2 = zeros; + } + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(COMPUTE_S_S2 RESULT_S_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] 
"+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + hs += 2; + he += 2; + } + } + } +} + +/** + * \brief depthwise convolution kernel 3x3, stride 2 + */ +// w_in > 7 +void conv_depthwise_3x3s2p0_bias_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! 
process bottom pad + if (i * 2 + 5 > h_in) { + switch (i * 2 + 5 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2_RELU + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2_RELU + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! process bottom pad + if (i * 2 + 3 > h_in) { + switch (i * 2 + 3 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU + RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} +void conv_depthwise_3x3s2p0_bias_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + + int tile_w = w_out >> 2; + int cnt_remain = w_out % 4; + + unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in); + size_right_remain = 8 - size_right_remain; + + if (cnt_remain == 0 && size_right_remain == 0) { + cnt_remain = 4; + tile_w -= 1; + size_right_remain = 8; + } + uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), + vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + uint32x4_t wmask = + vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + float* zero_ptr = ctx->workspace_data(); + memset(zero_ptr, 0, w_in * sizeof(float)); + float* write_ptr = zero_ptr + w_in; + + 
unsigned int dmask[12]; + + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + vst1q_u32(dmask + 8, wmask); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float32x4_t vzero = vdupq_n_f32(0.f); + +#ifdef __aarch64__ + float32x4_t wbias; + if (flag_bias) { + wbias = vdupq_n_f32(bias[i]); + } else { + wbias = vdupq_n_f32(0.f); + } +#else + float bias_c = 0.f; + if (flag_bias) { + bias_c = bias[i]; + } +#endif // __aarch64__ + + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + const float* dr3 = dr2 + w_in; + const float* dr4 = dr3 + w_in; + + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + const float* din3_ptr = dr3; + const float* din4_ptr = dr4; + + float* doutr0 = dout_channel; + float* doutr0_ptr = nullptr; + float* doutr1_ptr = nullptr; + +#ifdef __aarch64__ + for (int i = 0; i < h_out; i += 2) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + din3_ptr = dr3; + din4_ptr = dr4; + + doutr0_ptr = doutr0; + doutr1_ptr = doutr0 + w_out; + + dr0 = dr4; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + dr3 = dr2 + w_in; + dr4 = dr3 + w_in; + + //! process bottom pad + if (i * 2 + 5 > h_in) { + switch (i * 2 + 5 - h_in) { + case 4: + din1_ptr = zero_ptr; + case 3: + din2_ptr = zero_ptr; + case 2: + din3_ptr = zero_ptr; + case 1: + din4_ptr = zero_ptr; + case 0: + din4_ptr = zero_ptr; + default: + break; + } + } + //! process output pad + if (i + 2 > h_out) { + doutr1_ptr = write_ptr; + } + int cnt = tile_w; + asm volatile( + INIT_S2 + "ld1 {v15.4s}, [%[inptr0]] \n" + "ld1 {v18.4s}, [%[inptr1]] \n" + "ld1 {v19.4s}, [%[inptr2]] \n" + "ld1 {v20.4s}, [%[inptr3]] \n" + "ld1 {v21.4s}, [%[inptr4]] \n" + "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} + MID_COMPUTE_S2 MID_RESULT_S2 + "cmp %w[remain], #1 \n" + "blt 4f \n" RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 + "4: \n" + : [inptr0] "+r"(din0_ptr), + [inptr1] "+r"(din1_ptr), + [inptr2] "+r"(din2_ptr), + [inptr3] "+r"(din3_ptr), + [inptr4] "+r"(din4_ptr), + [outptr0] "+r"(doutr0_ptr), + [outptr1] "+r"(doutr1_ptr), + [cnt] "+r"(cnt) + : [vzero] "w"(vzero), + [w0] "w"(wr0), + [w1] "w"(wr1), + [w2] "w"(wr2), + [remain] "r"(cnt_remain), + [mask1] "w"(vmask_rp1), + [mask2] "w"(vmask_rp2), + [wmask] "w"(wmask), + [vbias] "w"(wbias) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16", + "v17", + "v18", + "v19", + "v20", + "v21"); + doutr0 = doutr0 + 2 * w_out; + } +#else + for (int i = 0; i < h_out; i++) { + din0_ptr = dr0; + din1_ptr = dr1; + din2_ptr = dr2; + + doutr0_ptr = doutr0; + + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + //! 
process bottom pad + if (i * 2 + 3 > h_in) { + switch (i * 2 + 3 - h_in) { + case 2: + din1_ptr = zero_ptr; + case 1: + din2_ptr = zero_ptr; + default: + break; + } + } + int cnt = tile_w; + unsigned int* mask_ptr = dmask; + asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2 + RIGHT_RESULT_S2 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [outptr] "+r"(doutr0_ptr), + [cnt] "+r"(cnt), + [mask_ptr] "+r"(mask_ptr) + : [remain] "r"(cnt_remain), + [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); + doutr0 = doutr0 + w_out; + } +#endif + } + } +} +/** + * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 + */ +void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; j++) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + if (j * 2 + 2 >= h_in) { + switch (j + 2 - h_in) { + case 1: + din1_ptr = zero_ptr; + case 0: + din2_ptr = zero_ptr; + default: + break; + } + } + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", 
+ "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} +void conv_depthwise_3x3s2p0_bias_s_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx) { + int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + int out_pad_idx[4] = {0, 1, 2, 3}; + float zeros[8] = {0.0f}; + const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; + + uint32x4_t vmask_rp1 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 + uint32x4_t vmask_rp2 = + vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 + + int size_in_channel = w_in * h_in; + int size_out_channel = w_out * h_out; + + unsigned int dmask[8]; + vst1q_u32(dmask, vmask_rp1); + vst1q_u32(dmask + 4, vmask_rp2); + + for (int n = 0; n < num; ++n) { + const float* din_batch = din + n * ch_in * size_in_channel; + float* dout_batch = dout + n * ch_in * size_out_channel; +#pragma omp parallel for + for (int i = 0; i < ch_in; ++i) { + const float* din_channel = din_batch + i * size_in_channel; + float* dout_channel = dout_batch + i * size_out_channel; + + const float* weight_ptr = weights + i * 9; + float32x4_t wr0 = vld1q_f32(weight_ptr); + float32x4_t wr1 = vld1q_f32(weight_ptr + 3); + float32x4_t wr2 = vld1q_f32(weight_ptr + 6); + + float bias_c = 0.f; + + if (flag_bias) { + bias_c = bias[i]; + } + float32x4_t vbias = vdupq_n_f32(bias_c); + float out_buf[4]; + const float* dr0 = din_channel; + const float* dr1 = dr0 + w_in; + const float* dr2 = dr1 + w_in; + for (int j = 0; j < h_out; j++) { + const float* din0_ptr = dr0; + const float* din1_ptr = dr1; + const float* din2_ptr = dr2; + if (j * 2 + 2 >= h_in) { + switch (j + 2 - h_in) { + case 1: + din1_ptr = zero_ptr; + case 0: + din2_ptr = zero_ptr; + default: + break; + } + } + dr0 = dr2; + dr1 = dr0 + w_in; + dr2 = dr1 + w_in; + + unsigned int* mask_ptr = dmask; +#ifdef __aarch64__ + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [mask_ptr] "+r"(mask_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "w"(vbias), + [out] "r"(out_buf) + : "cc", + "memory", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15", + "v16"); +#else + asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr) + : [wr0] "w"(wr0), + [wr1] "w"(wr1), + [wr2] "w"(wr2), + [bias] "r"(bias_c), + [out] "r"(out_buf), + [mask_ptr] "r"(dmask) + : "cc", + "memory", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15"); +#endif + for (int w = 0; w < w_out; ++w) { + *dout_channel++ = out_buf[w]; + } + } + } + } +} +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise.h b/lite/backends/arm/math/conv_depthwise.h index ff239e325f..58e0543170 100644 --- a/lite/backends/arm/math/conv_depthwise.h +++ b/lite/backends/arm/math/conv_depthwise.h @@ -324,61 +324,116 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, ARMContext* ctx); void conv_depthwise_3x3s1p0_bias_no_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - const 
int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); void conv_depthwise_3x3s1p0_bias_s_no_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); void conv_depthwise_3x3s1p1_bias_no_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); void conv_depthwise_3x3s1p1_bias_s_no_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - bool flag_relu, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p0_bias_s_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); + +void conv_depthwise_3x3s2p1_bias_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); +void conv_depthwise_3x3s2p1_bias_s_no_relu(float* dout, + const float* din, + const float* weights, + const float* bias, + bool flag_bias, + bool flag_relu, + const int num, + const int ch_in, + const int h_in, + const int w_in, + const int h_out, + const int w_out, + ARMContext* ctx); } // namespace math } // namespace arm -- GitLab
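
A note on the masking scheme shared by the new stride-2 kernels in this patch: each wide-input kernel builds two lane masks with vcgtq_s32 against right_pad_idx = {0, 2, 4, 6, 1, 3, 5, 7}, gating the even- and odd-indexed columns produced by the de-interleaving vld2 loads, and then uses vbif to zero the out-of-range lanes before the multiply-accumulates (the small-width "_s" variants compare against w_in instead of size_right_remain). The scalar sketch below is illustrative only and is not part of the patch; the helper name is hypothetical.

// Scalar equivalent of the right-pad mask construction used by the
// conv_depthwise_3x3s2p* wide-input kernels above.
// vcgtq_s32(a, b) computes a lane-wise (a > b) mask of all-ones / all-zeros.
#include <cstdint>

static inline void build_right_pad_masks(int size_right_remain,
                                         uint32_t mask_even[4],   // vmask_rp1
                                         uint32_t mask_odd[4]) {  // vmask_rp2
  const int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  for (int i = 0; i < 4; ++i) {
    // Lane i of vmask_rp1 keeps even input column 2*i while it is still
    // inside the valid right remainder; vmask_rp2 does the same for the
    // odd columns 2*i + 1.
    mask_even[i] = (size_right_remain > right_pad_idx[i]) ? 0xFFFFFFFFu : 0u;
    mask_odd[i]  = (size_right_remain > right_pad_idx[i + 4]) ? 0xFFFFFFFFu : 0u;
  }
  // The kernels store these masks into dmask[] and use vbif so that the
  // masked-out lanes of the vld2 (even/odd de-interleaved) loads read as
  // zero before the multiply-add chain.
}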