diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index cea1bf04677a3cc19ded5a20311518155760de69..a38afd5503843f95ef9716a6183539017a188ccc 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -57,8 +57,8 @@ endif() if (NOT HAS_ARM_MATH_LIB_DIR) # TODO(xxx): seperate them and do not deps proto, eigen3 - cc_library(math_arm SRCS - funcs.cc + cc_library(math_arm SRCS + funcs.cc packed_sgemm.cc packed_sgemm_c4.cc sgemm.cc @@ -68,8 +68,10 @@ if (NOT HAS_ARM_MATH_LIB_DIR) gemv_arm_int8.cc conv3x3s1_direct_fp32.cc conv3x3s2_direct_fp32.cc - conv3x3s1_depthwise_fp32.cc - conv3x3s2_depthwise_fp32.cc + conv3x3s1p01_depthwise_fp32.cc + conv3x3s2p01_depthwise_fp32.cc + conv3x3s1px_depthwise_fp32.cc + conv3x3s2px_depthwise_fp32.cc conv3x3s1_direct_int8.cc conv3x3s2_direct_int8.cc conv3x3s1_depthwise_int8.cc @@ -77,16 +79,13 @@ if (NOT HAS_ARM_MATH_LIB_DIR) conv5x5s1_depthwise_int8.cc conv5x5s1_depthwise_fp32.cc conv5x5s2_depthwise_fp32.cc - conv_depthwise_3x3p0.cc - conv_depthwise_3x3p1.cc - conv_depthwise_3x3s1.cc - conv_depthwise_3x3s2.cc conv_winograd_3x3.cc conv_impl.cc - softmax.cc + softmax.cc scale.cc pooling.cc elementwise.cc + layout.cc lrn.cc decode_bboxes.cc concat.cc @@ -122,4 +121,3 @@ if (NOT HAS_ARM_MATH_LIB_DIR) anchor_generator.cc DEPS ${lite_kernel_deps} context tensor) endif() - diff --git a/lite/backends/arm/math/conv_depthwise_3x3s1.cc b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc similarity index 100% rename from lite/backends/arm/math/conv_depthwise_3x3s1.cc rename to lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc index 8d0ebb58ad1b7e325bae3649b13914641021038f..e4c9fb99ef9a6b5d3987a1efd5a644f322ea043c 100644 --- a/lite/backends/arm/math/conv_depthwise_3x3s1.cc +++ b/lite/backends/arm/math/conv3x3s1p01_depthwise_fp32.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/backends/arm/math/conv_depthwise.h" #include +#include "lite/backends/arm/math/conv_depthwise.h" namespace paddle { namespace lite { diff --git a/lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc similarity index 100% rename from lite/backends/arm/math/conv3x3s1_depthwise_fp32.cc rename to lite/backends/arm/math/conv3x3s1px_depthwise_fp32.cc diff --git a/lite/backends/arm/math/conv_depthwise_3x3s2.cc b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc similarity index 100% rename from lite/backends/arm/math/conv_depthwise_3x3s2.cc rename to lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc index ec039af98cb7e4fb037475dd4e5ee29204252165..455781e37e0747950e6740f6db45c1ce8c0e96c8 100644 --- a/lite/backends/arm/math/conv_depthwise_3x3s2.cc +++ b/lite/backends/arm/math/conv3x3s2p01_depthwise_fp32.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/backends/arm/math/conv_depthwise.h" #include +#include "lite/backends/arm/math/conv_depthwise.h" namespace paddle { namespace lite { diff --git a/lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc b/lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc similarity index 100% rename from lite/backends/arm/math/conv3x3s2_depthwise_fp32.cc rename to lite/backends/arm/math/conv3x3s2px_depthwise_fp32.cc diff --git a/lite/backends/arm/math/conv_depthwise_3x3p0.cc b/lite/backends/arm/math/conv_depthwise_3x3p0.cc deleted file mode 100644 index 0c050ffe6fb0f064f5c26ea0da6acee17f4403ae..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_3x3p0.cc +++ /dev/null @@ -1,4178 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/arm/math/conv_depthwise.h" -#include - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s1p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3p0_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - if (stride == 1) { - if (flag_relu) { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 5) { - conv_depthwise_3x3s1p0_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p0_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } else { //! stride = 2 - if (flag_relu) { - if (w_in > 8) { - conv_depthwise_3x3s2p0_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p0_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 8) { - conv_depthwise_3x3s2p0_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p0_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -// 4line -void conv_depthwise_3x3s1p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - // wr0 = vsetq_lane_f32(0.f, wr0, 3); - // wr1 = vsetq_lane_f32(0.f, wr1, 3); - // wr2 = vsetq_lane_f32(0.f, wr2, 3); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - - // mid - // "cmp %[cnt], #1 \n" - // "blt 5f \n" - "4: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r5 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 4b \n" - - // right - "5: \n" - "cmp %[remain], #1 \n" - "blt 0f \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v18.4s}, [%[rmask]] \n" - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - // end - "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_out; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - case 0: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3\n" - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "cmp %[remain], #1 @ check whether has " - "mid cols\n" - "blt 0f @ jump to main loop start " - "point\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - "0: \n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! end of processing mid rows - } -#endif - } -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2 - */ -// w_in > 7 -void conv_depthwise_3x3s2p0_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - - int tile_w = w_out >> 2; - int cnt_remain = w_out % 4; - - unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); - - uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - uint32x4_t wmask = - vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int dmask[12]; - - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - vst1q_u32(dmask + 8, wmask); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t vzero = vdupq_n_f32(0.f); - - float32x4_t wbias; - float bias_c = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_c = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - const float* din3_ptr = dr3; - const float* din4_ptr = dr4; - - float* doutr0 = dout_channel; - float* doutr0_ptr = nullptr; - float* doutr1_ptr = nullptr; - -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - din4_ptr = dr4; - - doutr0_ptr = doutr0; - doutr1_ptr = doutr0 + w_out; - - dr0 = dr4; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - - //! process bottom pad - if (i + 4 >= h_in) { - switch (i + 4 - h_in) { - case 4: - din1_ptr = zero_ptr; - case 3: - din2_ptr = zero_ptr; - case 2: - din3_ptr = zero_ptr; - case 1: - din4_ptr = zero_ptr; - case 0: - din4_ptr = zero_ptr; - default: - break; - } - } - //! process output pad - if (i + 2 > h_out) { - doutr1_ptr = write_ptr; - } - int cnt = tile_w; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v13.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - for (int i = 0; i < h_out; i++) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = tile_w; - unsigned int* mask_ptr = dmask; - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - // mid - "2: \n" - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "subs %[cnt], #1 \n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} - -// 4line -void conv_depthwise_3x3s1p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = w_out >> 2; - int remain = w_out % 4; - - unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); - const int remian_idx[4] = {0, 1, 2, 3}; - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - // wr0 = vsetq_lane_f32(0.f, wr0, 3); - // wr1 = vsetq_lane_f32(0.f, wr1, 3); - // wr2 = vsetq_lane_f32(0.f, wr2, 3); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_out; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 >= h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - case 0: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = tile_w; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ - - // mid - "4: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - // r5 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 4b \n" - - // right - "5: \n" - "cmp %[remain], #1 \n" - "blt 0f \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - "ld1 {v18.4s}, [%[rmask]] \n" - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - // end - "0: \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_out; i += 2) { - //! process top pad pad_h = 1 - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - //! process bottom pad - if (i + 3 >= h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - case 0: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = tile_w; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3\n" - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "cmp %[remain], #1 @ check whether has " - "mid cols\n" - "blt 0f @ jump to main loop start " - "point\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - "0: \n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero), - [remain] "r"(remain) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! end of processing mid rows - } -#endif - } -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, with reulu - */ -// w_in > 7 -void conv_depthwise_3x3s2p0_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - - int tile_w = w_out >> 2; - int cnt_remain = w_out % 4; - - unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); - - uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - uint32x4_t wmask = - vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int dmask[12]; - - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - vst1q_u32(dmask + 8, wmask); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t vzero = vdupq_n_f32(0.f); - - float32x4_t wbias; - float bias_c = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_c = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - const float* din3_ptr = dr3; - const float* din4_ptr = dr4; - - float* doutr0 = dout_channel; - float* doutr0_ptr = nullptr; - float* doutr1_ptr = nullptr; - -#ifdef __aarch64__ - for (int i = 0; i < h_out; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - din4_ptr = dr4; - - doutr0_ptr = doutr0; - doutr1_ptr = doutr0 + w_out; - - dr0 = dr4; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - - //! process bottom pad - if (i + 4 >= h_in) { - switch (i + 4 - h_in) { - case 4: - din1_ptr = zero_ptr; - case 3: - din2_ptr = zero_ptr; - case 2: - din3_ptr = zero_ptr; - case 1: - din4_ptr = zero_ptr; - case 0: - din4_ptr = zero_ptr; - default: - break; - } - } - //! process output pad - if (i + 2 > h_out) { - doutr1_ptr = write_ptr; - } - int cnt = tile_w; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v14.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - for (int i = 0; i < h_out; i++) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = tile_w; - unsigned int* mask_ptr = dmask; - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - // mid - "2: \n" - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "subs %[cnt], #1 \n" - "vmax.f32 q3, q3, q9 @ relu \n" - - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float* dr0 = din_channel + j * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - doutr1 = trash_buf; - default: - break; - } - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" - - "bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234 - - "bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234 - "bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234 - - "bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234 - "bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234 - - "bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234 - "bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234 - - "ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345 - "ext v9.16b, v0.16b, v1.16b, #8\n" // d1_3450 - - "and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias - "and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias - - // r0 - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345 - "ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450 - - // r1 - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345 - "ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450 - - // r2 - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345 - "ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450 - - // r3 - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fadd v12.4s, v12.4s, v10.4s\n" - - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "fadd v12.4s, v12.4s, v11.4s\n" // out1 - "fadd v13.4s, v13.4s, v14.4s\n" // out2 - "fadd v13.4s, v13.4s, v15.4s\n" // out2 - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - unsigned int* vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - "vadd.f32 q4, q4, q11 @ q4 += q10 \n" - - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" - "vadd.f32 q5, q5, q9 @ q4 += q10 \n" - - "vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 - */ - -void conv_depthwise_3x3s2p0_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - float out_buf[4]; - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - for (int j = 0; j < h_out; ++j) { - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - "and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - - "fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00 - "fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01 - "fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02 - - "fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12 - "fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21 - "fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v16.4s \n" - - // "fadd v4.4s, v4.4s, %[bias].4s \n" - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0} - "vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // {0,2,4,6} - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // {1,3,5,7} - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // {2,4,6,0} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(dmask) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! for 4x6 convolution window - const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp1 = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in)); - uint32x4_t vmask_rp2 = - vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_out; j += 2) { - const float* dr0 = din_channel + j * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - doutr0 = dout_channel + j * w_out; - doutr1 = doutr0 + w_out; - - if (j + 3 >= h_in) { - switch (j + 3 - h_in) { - case 3: - dr1 = zero_ptr; - case 2: - dr2 = zero_ptr; - case 1: - dr3 = zero_ptr; - doutr1 = trash_buf; - case 0: - dr3 = zero_ptr; - doutr1 = trash_buf; - default: - break; - } - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s, v1.4s}, [%[din0]]\n" - "ld1 {v2.4s, v3.4s}, [%[din1]]\n" - "ld1 {v4.4s, v5.4s}, [%[din2]]\n" - "ld1 {v6.4s, v7.4s}, [%[din3]]\n" - - "bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234 - - "bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234 - "bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234 - - "bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234 - "bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234 - - "bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234 - "bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234 - - "ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345 - "ext v9.16b, v0.16b, v1.16b, #8\n" // d1_3450 - - "and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias - "and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias - - // r0 - "fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345 - "ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450 - - // r1 - "fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - - "fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345 - "ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450 - - // r2 - "fmla v14.4s, v4.4s, %[wr1].s[0]\n" // d0_1234 * w0[0] - "fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1] - "fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2] - "fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345 - "ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450 - - // r3 - "fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0] - - "fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1] - - "fadd v12.4s, v12.4s, v10.4s\n" - - "fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2] - - "fadd v12.4s, v12.4s, v11.4s\n" // out1 - "fadd v13.4s, v13.4s, v14.4s\n" // out2 - "fadd v13.4s, v13.4s, v15.4s\n" // out2 - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - "fmax v12.4s, v12.4s, %[zero].4s \n" - "fmax v13.4s, v13.4s, %[zero].4s \n" - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vbias] "w"(wbias), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [zero] "w"(vzero), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); -#else - unsigned int* vmask_ptr = vmask; - float bias_val = flag_bias ? bias[i] : 0.f; - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - "vadd.f32 q4, q4, q10 @ q4 += q10 \n" - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - "vadd.f32 q4, q4, q11 @ q4 += q10 \n" - - "vadd.f32 q5, q5, q8 @ q4 += q10 \n" - "vadd.f32 q5, q5, q9 @ q4 += q10 \n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n" - "vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [vzero] "w"(vzero), - [bias_val] "r"(bias_val), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - // doutr0 = doutr1; - // doutr1 += w_out; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 7 - */ -void conv_depthwise_3x3s2p0_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - float out_buf[4]; - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - for (int j = 0; j < h_out; ++j) { - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]] \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - "and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - "ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 = - // {2,4,6,8} - - "fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00 - "fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01 - "fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02 - - "fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12 - "fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21 - "fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v16.4s \n" - "fmax v4.4s, v4.4s, v9.4s \n" - - // "fadd v4.4s, v4.4s, %[bias].4s \n" - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf), - [mask_ptr] "r"(mask_ptr) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0} - "vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // {0,2,4,6} - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // {1,3,5,7} - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // {2,4,6,0} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf), - [mask_ptr] "r"(mask_ptr) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_depthwise_3x3p1.cc b/lite/backends/arm/math/conv_depthwise_3x3p1.cc deleted file mode 100644 index 6f28d48d6d2bdd60e0c33f9b4b753835337fc8a4..0000000000000000000000000000000000000000 --- a/lite/backends/arm/math/conv_depthwise_3x3p1.cc +++ /dev/null @@ -1,4850 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/backends/arm/math/conv_depthwise.h" -#include - -namespace paddle { -namespace lite { -namespace arm { -namespace math { - -void conv_depthwise_3x3s1p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s1p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s1p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3s2p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -//! for input width <= 4 -void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx); - -void conv_depthwise_3x3p1_fp32(const float* din, - float* dout, - int num, - int ch_out, - int h_out, - int w_out, - int ch_in, - int h_in, - int w_in, - const float* weights, - const float* bias, - int stride, - bool flag_bias, - bool flag_relu, - ARMContext* ctx) { - if (stride == 1) { - if (flag_relu) { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 4) { - conv_depthwise_3x3s1p1_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s1p1_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } else { //! stride = 2 - if (flag_relu) { - if (w_in > 7) { - conv_depthwise_3x3s2p1_bias_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s_relu(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } else { - if (w_in > 7) { - conv_depthwise_3x3s2p1_bias(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } else { - conv_depthwise_3x3s2p1_bias_s(dout, - din, - weights, - bias, - flag_bias, - num, - ch_in, - h_in, - w_in, - h_out, - w_out, - ctx); - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width > 4 - */ -// 4line -void conv_depthwise_3x3s1p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - // printf("conv3x3_dw start \n"); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 3) >> 2; - int cnt_col = tile_w - 2; - - unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_in; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - - // left - // r0 - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * - w0[1]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * - w0[0]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * - w0[2]*/ - - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * - w1[1]*/ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ - - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ - - // r5 - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ - "cmp %[cnt], #1 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "blt 3f \n" - // mid - "1: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 1b \n" - - // right - "3: \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v18.4s}, [%[rmask]] \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" - "vext.32 q7, q8, q9, #1 @ 1234\n" - - // left - // r0 - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" - - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" - "vext.32 q7, q10, q11, #1 @ 1234\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" - "vext.32 q7, q12, q13, #1 @ 1234\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" - "vext.32 q7, q14, q15, #1 @ 1234\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "cmp %[cnt], #1 @ check whether has " - "mid cols\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - "blt 3f @ jump to main loop start " - "point\n" - - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! end of processing mid rows - } -#endif - } -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2 - */ -// w_in > 7 -void conv_depthwise_3x3s2p1_bias(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - int size_pad_bottom = h_out * 2 - h_in; - - int cnt_col = (w_out >> 2) - 2; - int size_right_remain = w_in - (7 + cnt_col * 8); - if (size_right_remain >= 9) { - cnt_col++; - size_right_remain -= 8; - } - int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // - - int size_right_pad = w_out * 2 - w_in; - - uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - uint32x4_t wmask = - vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int dmask[12]; - - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - vst1q_u32(dmask + 8, wmask); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t vzero = vdupq_n_f32(0.f); - - float32x4_t wbias; - float bias_c = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_c = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - const float* din3_ptr = dr3; - const float* din4_ptr = dr4; - - float* doutr0 = dout_channel; - float* doutr0_ptr = nullptr; - float* doutr1_ptr = nullptr; - -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - din4_ptr = dr4; - - doutr0_ptr = doutr0; - doutr1_ptr = doutr0 + w_out; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - din4_ptr = dr3; - dr0 = dr3; - dr1 = dr4; - } else { - dr0 = dr4; - dr1 = dr0 + w_in; - } - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - - //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { - case 4: - din1_ptr = zero_ptr; - case 3: - din2_ptr = zero_ptr; - case 2: - din3_ptr = zero_ptr; - case 1: - din4_ptr = zero_ptr; - default: - break; - } - } - //! process output pad - if (i / 2 + 2 > h_out) { - doutr1_ptr = write_ptr; - } - int cnt = cnt_col; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5} - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr0], %[inptr0], #4 \n" - "sub %[inptr1], %[inptr1], #4 \n" - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr2], %[inptr2], #4 \n" - "sub %[inptr3], %[inptr3], #4 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - - "fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - - "fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - "fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr4], %[inptr4], #4 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5} - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "cmp %[cnt], #1 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "blt 1f \n" - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "subs %[cnt], %[cnt], #1 \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v13.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - for (int i = 0; i < h_in; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - } - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = cnt_col; - unsigned int* mask_ptr = dmask; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11 - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13 - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15 - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vext.32 q6, q9, q11, #3 @ shift right 1 " - "data\n" // q2 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "sub %[din0_ptr], #4 @ inpitr0 - 1\n" - "sub %[din1_ptr], #4 @ inpitr1 - 1\n" - "sub %[din2_ptr], #4 @ inpitr2 - 1\n" - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, " - "out1\n" // q0 * w01 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, " - "out1\n" // q1 * w02 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, " - "out1\n" // q2 * w00 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "cmp %[cnt], #1 \n" - "blt 1f \n" - // mid - "2: \n" - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "subs %[cnt], #1 \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} - -// 4line -void conv_depthwise_3x3s1p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! pad is done implicit - const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; - //! for 4x6 convolution window - const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0}; - - // printf("conv3x3_dw start \n"); - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - int w_stride = 9; - - int tile_w = (w_in + 3) >> 2; - int tile_h = (h_in + 3) >> 2; - int cnt_col = tile_w - 2; - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); - int size_pad_bottom = (unsigned int)(1 + (tile_h << 2) - h_in); - - uint32x4_t vmask_rp1 = - vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_rp2 = - vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); - uint32x4_t vmask_result = - vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); - - unsigned int vmask[8]; - vst1q_u32(vmask, vmask_rp1); - vst1q_u32(vmask + 4, vmask_rp2); - - unsigned int rmask[4]; - vst1q_u32(rmask, vmask_result); - - float32x4_t vzero = vdupq_n_f32(0.f); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for -#ifdef __aarch64__ - for (int c = 0; c < ch_in; c++) { - float* dout_ptr = dout_batch + c * size_out_channel; - - const float* din_ch_ptr = din_batch + c * size_in_channel; - - float bias_val = flag_bias ? bias[c] : 0.f; - float vbias[4] = {bias_val, bias_val, bias_val, bias_val}; - - const float* wei_ptr = weights + c * w_stride; - - float32x4_t wr0 = vld1q_f32(wei_ptr); - float32x4_t wr1 = vld1q_f32(wei_ptr + 3); - float32x4_t wr2 = vld1q_f32(wei_ptr + 6); - - float* doutr0 = dout_ptr; - float* doutr1 = doutr0 + w_out; - float* doutr2 = doutr1 + w_out; - float* doutr3 = doutr2 + w_out; - - const float* dr0 = din_ch_ptr; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - const float* dr5 = dr4 + w_in; - - const float* din_ptr0 = dr0; - const float* din_ptr1 = dr1; - const float* din_ptr2 = dr2; - const float* din_ptr3 = dr3; - const float* din_ptr4 = dr4; - const float* din_ptr5 = dr5; - - for (int i = 0; i < h_in; i += 4) { - //! process top pad pad_h = 1 - din_ptr0 = dr0; - din_ptr1 = dr1; - din_ptr2 = dr2; - din_ptr3 = dr3; - din_ptr4 = dr4; - din_ptr5 = dr5; - - doutr0 = dout_ptr; - doutr1 = doutr0 + w_out; - doutr2 = doutr1 + w_out; - doutr3 = doutr2 + w_out; - if (i == 0) { - din_ptr0 = zero_ptr; - din_ptr1 = dr0; - din_ptr2 = dr1; - din_ptr3 = dr2; - din_ptr4 = dr3; - din_ptr5 = dr4; - dr0 = dr3; - dr1 = dr4; - dr2 = dr5; - } else { - dr0 = dr4; - dr1 = dr5; - dr2 = dr1 + w_in; - } - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - dr5 = dr4 + w_in; - - //! process bottom pad - if (i + 5 > h_in) { - switch (i + 5 - h_in) { - case 5: - din_ptr1 = zero_ptr; - case 4: - din_ptr2 = zero_ptr; - case 3: - din_ptr3 = zero_ptr; - case 2: - din_ptr4 = zero_ptr; - case 1: - din_ptr5 = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 4 > h_out) { - switch (i + 4 - h_out) { - case 3: - doutr1 = write_ptr; - case 2: - doutr2 = write_ptr; - case 1: - doutr3 = write_ptr; - default: - break; - } - } - - int cnt = cnt_col; - asm volatile( - "PRFM PLDL1KEEP, [%[din_ptr0]] \n" - "PRFM PLDL1KEEP, [%[din_ptr1]] \n" - "PRFM PLDL1KEEP, [%[din_ptr2]] \n" - "PRFM PLDL1KEEP, [%[din_ptr3]] \n" - "PRFM PLDL1KEEP, [%[din_ptr4]] \n" - "PRFM PLDL1KEEP, [%[din_ptr5]] \n" - "movi v21.4s, #0x0\n" /* out0 = 0 */ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ - - // left - // r0 - "fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * - w0[1]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * - w0[0]*/ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * - w0[2]*/ - - "ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * - w1[1]*/ - "sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ - "sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ - - "fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */ - - // r4 - "fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - "fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w2[1]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ - "st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w1[1]*/ - - "ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ - "ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - // r5 - "fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * - w1[1]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ - - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * - w0[1]*/ - - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ - - "st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ - "cmp %[cnt], #1 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "blt 3f \n" - // mid - "1: \n" - // r0 - "fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - - "ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ - "ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - "subs %[cnt], %[cnt], #1 \n" - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - "ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ - - "bne 1b \n" - - // right - "3: \n" - "ld1 {v18.4s, v19.4s}, [%[vmask]] \n" - "ld1 {v22.4s}, [%[doutr0]] \n" - "ld1 {v23.4s}, [%[doutr1]] \n" - "ld1 {v24.4s}, [%[doutr2]] \n" - "ld1 {v25.4s}, [%[doutr3]] \n" - - "bif v0.16b, %[vzero].16b, v18.16b \n" - "bif v1.16b, %[vzero].16b, v19.16b \n" - "bif v2.16b, %[vzero].16b, v18.16b \n" - "bif v3.16b, %[vzero].16b, v19.16b \n" - - "bif v4.16b, %[vzero].16b, v18.16b \n" - "bif v5.16b, %[vzero].16b, v19.16b \n" - "bif v6.16b, %[vzero].16b, v18.16b \n" - "bif v7.16b, %[vzero].16b, v19.16b \n" - - "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ - - // r0 - "fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "bif v8.16b, %[vzero].16b, v18.16b \n" - "bif v9.16b, %[vzero].16b, v19.16b \n" - "bif v10.16b, %[vzero].16b, v18.16b \n" - "bif v11.16b, %[vzero].16b, v19.16b \n" - - "fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "ld1 {v18.4s}, [%[rmask]] \n" - - "fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ - - // r1 - "fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ - - // r2 - "fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v12.16b, v22.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ - - // r3 - "fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - "fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "st1 {v12.4s}, [%[doutr0]], #16 \n" - "fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - "fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v13.16b, v23.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - "fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ - "ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ - - "st1 {v13.4s}, [%[doutr1]], #16 \n" - - // r3 - "fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * - w0[0]*/ - - "fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ - - "fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * - w0[1]*/ - - "bif v14.16b, v24.16b, v18.16b \n" - - "fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * - w0[2]*/ - - "st1 {v14.4s}, [%[doutr2]], #16 \n" - - "fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ - - "bif v15.16b, v25.16b, v18.16b \n" - - "st1 {v15.4s}, [%[doutr3]], #16 \n" - : [cnt] "+r"(cnt), - [din_ptr0] "+r"(din_ptr0), - [din_ptr1] "+r"(din_ptr1), - [din_ptr2] "+r"(din_ptr2), - [din_ptr3] "+r"(din_ptr3), - [din_ptr4] "+r"(din_ptr4), - [din_ptr5] "+r"(din_ptr5), - [doutr0] "+r"(doutr0), - [doutr1] "+r"(doutr1), - [doutr2] "+r"(doutr2), - [doutr3] "+r"(doutr3) - : [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [bias_val] "r"(vbias), - [vmask] "r"(vmask), - [rmask] "r"(rmask), - [vzero] "w"(vzero) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21", - "v22", - "v23", - "v24", - "v25"); - dout_ptr = dout_ptr + 4 * w_out; - } - } -#else - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float bias_val = flag_bias ? bias[i] : 0.f; - - float* dout_channel = dout_batch + i * size_out_channel; - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - const float* din0_ptr = nullptr; - const float* din1_ptr = nullptr; - const float* din2_ptr = nullptr; - const float* din3_ptr = nullptr; - - float* doutr0 = nullptr; - float* doutr1 = nullptr; - - float* ptr_zero = const_cast(zero); - - for (int i = 0; i < h_in; i += 2) { - //! process top pad pad_h = 1 - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - - doutr0 = dout_channel; - doutr1 = dout_channel + w_out; - // unsigned int* rst_mask = rmask; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - dr0 = dr1; - dr1 = dr2; - dr2 = dr3; - dr3 = dr2 + w_in; - } else { - dr0 = dr2; - dr1 = dr3; - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - } - //! process bottom pad - if (i + 3 > h_in) { - switch (i + 3 - h_in) { - case 3: - din1_ptr = zero_ptr; - case 2: - din2_ptr = zero_ptr; - case 1: - din3_ptr = zero_ptr; - default: - break; - } - } - //! process bottom remain - if (i + 2 > h_out) { - doutr1 = write_ptr; - } - int cnt = cnt_col; - unsigned int* rmask_ptr = rmask; - unsigned int* vmask_ptr = vmask; - asm volatile( - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" - "vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" - "vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" - - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - - "vext.32 q6, %q[vzero], q8, #3 @ 0012\n" - "vext.32 q7, q8, q9, #1 @ 1234\n" - - // left - // r0 - "vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" - "sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" - - "vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q10, #3 @ 0012\n" - "vext.32 q7, q10, q11, #1 @ 1234\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q12, #3 @ 0012\n" - "vext.32 q7, q12, q13, #1 @ 1234\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, %q[vzero], q14, #3 @ 0012\n" - "vext.32 q7, q14, q15, #1 @ 1234\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "cmp %[cnt], #1 @ check whether has " - "mid cols\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q5 - // = - // vbias - "blt 3f @ jump to main loop start " - "point\n" - - // mid - "1: @ right pad entry\n" - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - "pld [%[din3_ptr]] @ preload data\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - "vdup.32 q4, %[bias_val] @ and \n" // q4 - // = - // vbias - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - "subs %[cnt], #1 @ loop count minus 1\n" - - "vdup.32 q5, %[bias_val] @ and \n" // q4 - // = - // vbias - - "bne 1b @ jump to main loop start " - "point\n" - - // right - "3: @ right pad entry\n" - "vld1.32 {d19}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[vmask]]! @ load din r0\n" - - "vld1.32 {d27}, [%[vmask]]! @ load din r0\n" - "vld1.32 {d31}, [%[vmask]]! @ load din r0\n" - - "vbif d16, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d17, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d18, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vbif d20, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d21, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d22, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vext.32 q6, q8, q9, #1 @ 1234\n" - "vext.32 q7, q8, q9, #2 @ 2345\n" - - // r0 - "vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" - - "vbif d24, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d25, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d26, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d28, %e[vzero], d19 @ bit select, deal with " - "right pad\n" - "vbif d29, %e[vzero], d23 @ bit select, deal with " - "right pad\n" - "vbif d30, %e[vzero], d27 @ bit select, deal with " - "right pad\n" - - "vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" - - "vext.32 q6, q10, q11, #1 @ 1234\n" - "vext.32 q7, q10, q11, #2 @ 2345\n" - - // r1 - "vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d19}, [%[rmask]]! @ load din r0\n" - "vld1.32 {d23}, [%[rmask]]! @ load din r0\n" - - "vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - - "vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" - "vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" - - "vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q12, q13, #1 @ 1234\n" - "vext.32 q7, q12, q13, #2 @ 2345\n" - - // r2 - "vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" - "vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" - "vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" - - "vext.32 q6, q14, q15, #1 @ 1234\n" - "vext.32 q7, q14, q15, #2 @ 2345\n" - - // r3 - "vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" - - "vmax.f32 q4, q4, %q[vzero] @ relu \n" - - "vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" - - "vbif d8, d16, d19 @ bit select, deal with right pad\n" - "vbif d9, d17, d23 @ bit select, deal with right pad\n" - - "vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" - "vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" - - "vmax.f32 q5, q5, %q[vzero] @ relu \n" - - "vbif d10, d20, d19 @ bit select, deal with right " - "pad\n" - "vbif d11, d21, d23 @ bit select, deal with right " - "pad\n" - - "vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add " - "pointer\n" - - : [dout_ptr1] "+r"(doutr0), - [dout_ptr2] "+r"(doutr1), - [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [din3_ptr] "+r"(din3_ptr), - [cnt] "+r"(cnt), - [rmask] "+r"(rmask_ptr), - [vmask] "+r"(vmask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias_val] "r"(bias_val), - [vzero] "w"(vzero) - : "cc", - "memory", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - dout_channel += 2 * w_out; - } //! end of processing mid rows - } -#endif - } -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, with reulu - */ -// w_in > 7 -void conv_depthwise_3x3s2p1_bias_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - int size_pad_bottom = h_out * 2 - h_in; - - int cnt_col = (w_out >> 2) - 2; - int size_right_remain = w_in - (7 + cnt_col * 8); - if (size_right_remain >= 9) { - cnt_col++; - size_right_remain -= 8; - } - int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); // - - int size_right_pad = w_out * 2 - w_in; - - uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain), - vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - uint32x4_t wmask = - vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3 - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - float* zero_ptr = ctx->workspace_data(); - memset(zero_ptr, 0, w_in * sizeof(float)); - float* write_ptr = zero_ptr + w_in; - - unsigned int dmask[12]; - - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - vst1q_u32(dmask + 8, wmask); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float32x4_t vzero = vdupq_n_f32(0.f); - - float32x4_t wbias; - float bias_c = 0.f; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - bias_c = bias[i]; - } else { - wbias = vdupq_n_f32(0.f); - } - - const float* dr0 = din_channel; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - const float* dr4 = dr3 + w_in; - - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - const float* din3_ptr = dr3; - const float* din4_ptr = dr4; - - float* doutr0 = dout_channel; - float* doutr0_ptr = nullptr; - float* doutr1_ptr = nullptr; - -#ifdef __aarch64__ - for (int i = 0; i < h_in; i += 4) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - din3_ptr = dr3; - din4_ptr = dr4; - - doutr0_ptr = doutr0; - doutr1_ptr = doutr0 + w_out; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - din3_ptr = dr2; - din4_ptr = dr3; - dr0 = dr3; - dr1 = dr4; - } else { - dr0 = dr4; - dr1 = dr0 + w_in; - } - dr2 = dr1 + w_in; - dr3 = dr2 + w_in; - dr4 = dr3 + w_in; - - //! process bottom pad - if (i + 4 > h_in) { - switch (i + 4 - h_in) { - case 4: - din1_ptr = zero_ptr; - case 3: - din2_ptr = zero_ptr; - case 2: - din3_ptr = zero_ptr; - case 1: - din4_ptr = zero_ptr; - default: - break; - } - } - //! process output pad - if (i / 2 + 2 > h_out) { - doutr1_ptr = write_ptr; - } - int cnt = cnt_col; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "prfm pldl1keep, [%[inptr0]] \n" - "prfm pldl1keep, [%[inptr1]] \n" - "prfm pldl1keep, [%[inptr2]] \n" - "prfm pldl1keep, [%[inptr3]] \n" - "prfm pldl1keep, [%[inptr4]] \n" - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5} - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr0], %[inptr0], #4 \n" - "sub %[inptr1], %[inptr1], #4 \n" - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr2], %[inptr2], #4 \n" - "sub %[inptr3], %[inptr3], #4 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01 - "fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - - "fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02 - "fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - - "fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00 - "fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5} - - "sub %[inptr4], %[inptr4], #4 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00 - - "ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5} - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01 - "fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02 - "fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00 - - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - "fadd v17.4s, v17.4s, v13.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - - "ld1 {v18.4s}, [%[inptr1]] \n" - "ld1 {v19.4s}, [%[inptr2]] \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "cmp %[cnt], #1 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "blt 1f \n" - // mid - "2: \n" - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8} - "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6} - // v1={1,3,5,7} - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8} - - "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" - "ld1 {v15.4s}, [%[inptr0]] \n" - "ld1 {v18.4s}, [%[inptr1]] \n" - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "ld1 {v19.4s}, [%[inptr2]] \n" - "ld1 {v20.4s}, [%[inptr3]] \n" - "ld1 {v21.4s}, [%[inptr4]] \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fadd v17.4s, v17.4s, v14.4s \n" - - "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} - "and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias - "subs %[cnt], %[cnt], #1 \n" - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - - "and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias - - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 4f \n" - "3: \n" - "bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - "ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - "bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r0 - "fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei - "bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei - - // r1 - "fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r2 - "fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00 - "fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - - "fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01 - "fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - - "fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02 - "fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - - // r3 - "fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02 - - "ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8} - "ld1 {v0.4s}, [%[outptr0]] \n" - - "fadd v16.4s, v16.4s, v11.4s \n" - "fadd v16.4s, v16.4s, v12.4s \n" - "ld1 {v1.4s}, [%[outptr1]] \n" - - // r4 - "fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00 - "fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01 - "fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02 - - "fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ - - "fadd v17.4s, v17.4s, v13.4s \n" - - "bif v16.16b, v0.16b, %[wmask].16b \n" // pipei - - "fadd v17.4s, v17.4s, v14.4s \n" - - "st1 {v16.4s}, [%[outptr0]], #16 \n" - - "fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ - - "bif v17.16b, v1.16b, %[wmask].16b \n" // pipei - - "st1 {v17.4s}, [%[outptr1]], #16 \n" - "4: \n" - : [inptr0] "+r"(din0_ptr), - [inptr1] "+r"(din1_ptr), - [inptr2] "+r"(din2_ptr), - [inptr3] "+r"(din3_ptr), - [inptr4] "+r"(din4_ptr), - [outptr0] "+r"(doutr0_ptr), - [outptr1] "+r"(doutr1_ptr), - [cnt] "+r"(cnt) - : [vzero] "w"(vzero), - [w0] "w"(wr0), - [w1] "w"(wr1), - [w2] "w"(wr2), - [remain] "r"(cnt_remain), - [mask1] "w"(vmask_rp1), - [mask2] "w"(vmask_rp2), - [wmask] "w"(wmask), - [vbias] "w"(wbias) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17", - "v18", - "v19", - "v20", - "v21"); - doutr0 = doutr0 + 2 * w_out; - } -#else - - for (int i = 0; i < h_in; i += 2) { - din0_ptr = dr0; - din1_ptr = dr1; - din2_ptr = dr2; - - doutr0_ptr = doutr0; - - if (i == 0) { - din0_ptr = zero_ptr; - din1_ptr = dr0; - din2_ptr = dr1; - dr0 = dr1; - dr1 = dr2; - dr2 = dr1 + w_in; - } else { - dr0 = dr2; - dr1 = dr0 + w_in; - dr2 = dr1 + w_in; - } - - //! process bottom pad - if (i + 2 > h_in) { - switch (i + 2 - h_in) { - case 2: - din1_ptr = zero_ptr; - case 1: - din2_ptr = zero_ptr; - default: - break; - } - } - int cnt = cnt_col; - - unsigned int* mask_ptr = dmask; - asm volatile( - // top - // Load up 12 elements (3 vectors) from each of 8 sources. - "0: \n" - "vmov.u32 q9, #0 \n" - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11 - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13 - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15 - "pld [%[din0_ptr]] @ preload data\n" - "pld [%[din1_ptr]] @ preload data\n" - "pld [%[din2_ptr]] @ preload data\n" - - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vext.32 q6, q9, q11, #3 @ shift right 1 " - "data\n" // q2 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift right 1 " - "data\n" // q6 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "sub %[din0_ptr], #4 @ inpitr0 - 1\n" - "sub %[din1_ptr], #4 @ inpitr1 - 1\n" - "sub %[din2_ptr], #4 @ inpitr2 - 1\n" - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q11 * w01 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q12 * w02 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w00 - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, " - "out1\n" // q0 * w01 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, " - "out1\n" // q1 * w02 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, " - "out1\n" // q2 * w00 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "cmp %[cnt], #1 \n" - "blt 1f \n" - // mid - "2: \n" - "vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14} - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - "vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14} - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7} - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7} - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "subs %[cnt], #1 \n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "bne 2b \n" - - // right - "1: \n" - "cmp %[remain], #1 \n" - "blt 3f \n" - - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q10 = - // vbias - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - - "vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, " - "out0\n" // q0 * w00 - "vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w02 - - "vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8} - "vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" - - "vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, " - "out0\n" // q6 * w02 - - "vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" - - "vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, " - "out0\n" // q0 * w00 - "vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, " - "out0\n" // q1 * w01 - "vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, " - "out0\n" // q6 * w02 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu \n" - - "vbif.f32 q3, q10, q11 @ write mask\n" - - "vst1.32 {d6-d7}, [%[outptr]]! \n" - "3: \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [outptr] "+r"(doutr0_ptr), - [cnt] "+r"(cnt), - [mask_ptr] "+r"(mask_ptr) - : [remain] "r"(cnt_remain), - [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); - - doutr0 = doutr0 + w_out; - } -#endif - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s}, [%[din0]], #16\n" - "ld1 {v1.4s}, [%[din1]], #16\n" - "ld1 {v2.4s}, [%[din2]], #16\n" - "ld1 {v3.4s}, [%[din3]], #16\n" - - "bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234 - "bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234 - "bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234 - - "ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123 - "ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123 - "ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123 - "ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123 - - "ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340 - "ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340 - "ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340 - "ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340 - - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" - - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" - - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" - - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" - - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" - - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" - - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" - - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" - - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" - - "fadd v12.4s, v12.4s, v14.4s\n" - "fadd v12.4s, v12.4s, v16.4s\n" - - "fadd v13.4s, v13.4s, v15.4s\n" // out1 - "fadd v13.4s, v13.4s, v17.4s\n" // out2 - - "fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias - "fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d12-d13}, [%[din0]]!\n" - "vld1.32 {d14-d15}, [%[din1]]!\n" - "vld1.32 {d16-d17}, [%[din2]]!\n" - "vld1.32 {d18-d19}, [%[din3]]!\n" - - "vbif q6, %q[zero], %q[mask]\n" // d0_1234 - "vbif q7, %q[zero], %q[mask]\n" // d1_1234 - "vbif q8, %q[zero], %q[mask]\n" // d2_1234 - "vbif q9, %q[zero], %q[mask]\n" // d3_1234 - - "vmul.f32 q14, q6, %e[wr0][1]\n" - "vmul.f32 q15, q7, %e[wr0][1]\n" - - "vmla.f32 q14, q7, %e[wr1][1]\n" - "vmla.f32 q15, q8, %e[wr1][1]\n" - - "vmla.f32 q14, q8, %e[wr2][1]\n" - "vmla.f32 q15, q9, %e[wr2][1]\n" - - "vext.32 q10, %q[zero], q6, #3\n" // d0_0123 - "vext.32 q11, %q[zero], q7, #3\n" // d1_0123 - "vext.32 q12, %q[zero], q8, #3\n" // d2_0123 - "vext.32 q13, %q[zero], q9, #3\n" // d3_0123 - - "vmla.f32 q14, q10, %e[wr0][0]\n" - "vmla.f32 q15, q11, %e[wr0][0]\n" - - "vmla.f32 q14, q11, %e[wr1][0]\n" - "vmla.f32 q15, q12, %e[wr1][0]\n" - - "vmla.f32 q14, q12, %e[wr2][0]\n" - "vmla.f32 q15, q13, %e[wr2][0]\n" - - "vext.32 q10, q6, %q[zero], #1\n" // d0_2340 - "vext.32 q11, q7, %q[zero], #1\n" // d1_2340 - "vext.32 q12, q8, %q[zero], #1\n" // d2_2340 - "vext.32 q13, q9, %q[zero], #1\n" // d3_2340 - - "vmla.f32 q14, q10, %f[wr0][0]\n" - "vmla.f32 q15, q11, %f[wr0][0]\n" - - "vmla.f32 q14, q11, %f[wr1][0]\n" - "vmla.f32 q15, q12, %f[wr1][0]\n" - - "vmla.f32 q14, q12, %f[wr2][0]\n" // out1 - "vmla.f32 q15, q13, %f[wr2][0]\n" // out2 - - "vadd.f32 q14, q14, %q[bias]\n" // out1 add bias - "vadd.f32 q15, q15, %q[bias]\n" // out2 add bias - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vst1.32 {d28-d29}, [%[out1]]\n" - "vst1.32 {d30-d31}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 4 - */ - -void conv_depthwise_3x3s2p1_bias_s(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - int hs = -1; - int he = 2; - float out_buf[4]; - for (int j = 0; j < h_out; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - if (hs == -1) { - dr0 = zeros; - } - if (he > h_in) { - dr2 = zeros; - } - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 = - // {0,1,3,5} - "ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 = - // {0,1,3,5} - "ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 = - // {0,1,3,5} - - "fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01 - "fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02 - "fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00 - - "fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12 - "fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21 - "fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v6.4s \n" - - "fadd v4.4s, v4.4s, %[bias].4s \n" - - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, " - "out0\n" // q10 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, " - "out0\n" // q11 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w00 - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - hs += 2; - he += 2; - } - } - } -} -/** - * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias, - * width <= 4 - */ -void conv_depthwise_3x3s1p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - //! 3x3s1 convolution, implemented by direct algorithm - //! pad is done implicit - //! for 4x6 convolution window - const int right_pad_idx[4] = {3, 2, 1, 0}; - const float zero[4] = {0.f, 0.f, 0.f, 0.f}; - - float32x4_t vzero = vdupq_n_f32(0.f); - uint32x4_t vmask_rp = - vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in)); - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - float* dout_channel = dout_batch + i * size_out_channel; - const float* din_channel = din_batch + i * size_in_channel; - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - float32x4_t wbias; - if (flag_bias) { - wbias = vdupq_n_f32(bias[i]); - } else { - wbias = vdupq_n_f32(0.f); - } - - int hs = -1; - int he = 3; - - float out_buf1[4]; - float out_buf2[4]; - float trash_buf[4]; - - int h_cnt = (h_out + 1) >> 1; - float* doutr0 = dout_channel; - float* doutr1 = dout_channel + w_out; - - for (int j = 0; j < h_cnt; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - const float* dr3 = dr2 + w_in; - - if (hs == -1) { - dr0 = zero; - } - - switch (he - h_in) { - case 2: - dr2 = zero; - doutr1 = trash_buf; - case 1: - dr3 = zero; - default: - break; - } -#ifdef __aarch64__ - asm volatile( - "prfm pldl1keep, [%[din0]]\n" - "prfm pldl1keep, [%[din1]]\n" - "prfm pldl1keep, [%[din2]]\n" - "prfm pldl1keep, [%[din3]]\n" - - "ld1 {v0.4s}, [%[din0]], #16\n" - "ld1 {v1.4s}, [%[din1]], #16\n" - "ld1 {v2.4s}, [%[din2]], #16\n" - "ld1 {v3.4s}, [%[din3]], #16\n" - - "bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234 - "bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234 - "bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234 - "bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234 - - "ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123 - "ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123 - "ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123 - "ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123 - - "ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340 - "ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340 - "ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340 - "ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340 - - "fmul v12.4s, v0.4s, %[wr0].s[1]\n" - "fmul v13.4s, v1.4s, %[wr0].s[1]\n" - - "fmul v14.4s, v1.4s, %[wr1].s[1]\n" - "fmul v15.4s, v2.4s, %[wr1].s[1]\n" - - "fmul v16.4s, v2.4s, %[wr2].s[1]\n" - "fmul v17.4s, v3.4s, %[wr2].s[1]\n" - - "fmla v12.4s, v4.4s, %[wr0].s[0]\n" - "fmla v13.4s, v5.4s, %[wr0].s[0]\n" - - "fmla v14.4s, v5.4s, %[wr1].s[0]\n" - "fmla v15.4s, v6.4s, %[wr1].s[0]\n" - - "fmla v16.4s, v6.4s, %[wr2].s[0]\n" - "fmla v17.4s, v7.4s, %[wr2].s[0]\n" - - "fmla v12.4s, v8.4s, %[wr0].s[2]\n" - "fmla v13.4s, v9.4s, %[wr0].s[2]\n" - - "fmla v14.4s, v9.4s, %[wr1].s[2]\n" - "fmla v15.4s, v10.4s, %[wr1].s[2]\n" - - "fmla v16.4s, v10.4s, %[wr2].s[2]\n" - "fmla v17.4s, v11.4s, %[wr2].s[2]\n" - - "fadd v12.4s, v12.4s, v14.4s\n" - "fadd v12.4s, v12.4s, v16.4s\n" - - "fadd v13.4s, v13.4s, v15.4s\n" // out1 - "fadd v13.4s, v13.4s, v17.4s\n" // out2 - - "fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias - "fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias - - "prfm pldl1keep, [%[out1]]\n" - "prfm pldl1keep, [%[out2]]\n" - - "fmax v12.4s, v12.4s, %[zero].4s\n" // out1 -> relu - "fmax v13.4s, v13.4s, %[zero].4s\n" // out2 -> relu - - "st1 {v12.4s}, [%[out1]]\n" - "st1 {v13.4s}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "v0", - "v1", - "v2", - "v3", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15", - "v16", - "v17"); -#else - asm volatile( - "pld [%[din0]]\n" - "pld [%[din1]]\n" - "pld [%[din2]]\n" - "pld [%[din3]]\n" - - "vld1.32 {d12-d13}, [%[din0]]!\n" - "vld1.32 {d14-d15}, [%[din1]]!\n" - "vld1.32 {d16-d17}, [%[din2]]!\n" - "vld1.32 {d18-d19}, [%[din3]]!\n" - - "vbif q6, %q[zero], %q[mask]\n" // d0_1234 - "vbif q7, %q[zero], %q[mask]\n" // d1_1234 - "vbif q8, %q[zero], %q[mask]\n" // d2_1234 - "vbif q9, %q[zero], %q[mask]\n" // d3_1234 - - "vmul.f32 q14, q6, %e[wr0][1]\n" - "vmul.f32 q15, q7, %e[wr0][1]\n" - - "vmla.f32 q14, q7, %e[wr1][1]\n" - "vmla.f32 q15, q8, %e[wr1][1]\n" - - "vmla.f32 q14, q8, %e[wr2][1]\n" - "vmla.f32 q15, q9, %e[wr2][1]\n" - - "vext.32 q10, %q[zero], q6, #3\n" // d0_0123 - "vext.32 q11, %q[zero], q7, #3\n" // d1_0123 - "vext.32 q12, %q[zero], q8, #3\n" // d2_0123 - "vext.32 q13, %q[zero], q9, #3\n" // d3_0123 - - "vmla.f32 q14, q10, %e[wr0][0]\n" - "vmla.f32 q15, q11, %e[wr0][0]\n" - - "vmla.f32 q14, q11, %e[wr1][0]\n" - "vmla.f32 q15, q12, %e[wr1][0]\n" - - "vmla.f32 q14, q12, %e[wr2][0]\n" - "vmla.f32 q15, q13, %e[wr2][0]\n" - - "vext.32 q10, q6, %q[zero], #1\n" // d0_2340 - "vext.32 q11, q7, %q[zero], #1\n" // d1_2340 - "vext.32 q12, q8, %q[zero], #1\n" // d2_2340 - "vext.32 q13, q9, %q[zero], #1\n" // d3_2340 - - "vmla.f32 q14, q10, %f[wr0][0]\n" - "vmla.f32 q15, q11, %f[wr0][0]\n" - - "vmla.f32 q14, q11, %f[wr1][0]\n" - "vmla.f32 q15, q12, %f[wr1][0]\n" - - "vmla.f32 q14, q12, %f[wr2][0]\n" // out1 - "vmla.f32 q15, q13, %f[wr2][0]\n" // out2 - - "vadd.f32 q14, q14, %q[bias]\n" // out1 add bias - "vadd.f32 q15, q15, %q[bias]\n" // out2 add bias - - "pld [%[out1]]\n" - "pld [%[out2]]\n" - - "vmax.f32 q14, q14, %q[zero]\n" // out1 -> relu - "vmax.f32 q15, q15, %q[zero]\n" // out2 -> relu - - "vst1.32 {d28-d29}, [%[out1]]\n" - "vst1.32 {d30-d31}, [%[out2]]\n" - - : [din0] "+r"(dr0), - [din1] "+r"(dr1), - [din2] "+r"(dr2), - [din3] "+r"(dr3) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [zero] "w"(vzero), - [mask] "w"(vmask_rp), - [bias] "w"(wbias), - [out1] "r"(out_buf1), - [out2] "r"(out_buf2) - : "cc", - "memory", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *doutr0++ = out_buf1[w]; - *doutr1++ = out_buf2[w]; - } - doutr0 = doutr1; - doutr1 += w_out; - hs += 2; - he += 2; - } // end of processing heights - } // end of processing channels - } // end of processing batchs -} - -/** - * \brief depthwise convolution kernel 3x3, stride 2, width <= 7 - */ -void conv_depthwise_3x3s2p1_bias_s_relu(float* dout, - const float* din, - const float* weights, - const float* bias, - bool flag_bias, - const int num, - const int ch_in, - const int h_in, - const int w_in, - const int h_out, - const int w_out, - ARMContext* ctx) { - int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - int out_pad_idx[4] = {0, 1, 2, 3}; - float zeros[8] = {0.0f}; - - uint32x4_t vmask_rp1 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6 - uint32x4_t vmask_rp2 = - vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7 - - int size_in_channel = w_in * h_in; - int size_out_channel = w_out * h_out; - - unsigned int dmask[8]; - vst1q_u32(dmask, vmask_rp1); - vst1q_u32(dmask + 4, vmask_rp2); - - for (int n = 0; n < num; ++n) { - const float* din_batch = din + n * ch_in * size_in_channel; - float* dout_batch = dout + n * ch_in * size_out_channel; -#pragma omp parallel for - for (int i = 0; i < ch_in; ++i) { - const float* din_channel = din_batch + i * size_in_channel; - float* dout_channel = dout_batch + i * size_out_channel; - - const float* weight_ptr = weights + i * 9; - float32x4_t wr0 = vld1q_f32(weight_ptr); - float32x4_t wr1 = vld1q_f32(weight_ptr + 3); - float32x4_t wr2 = vld1q_f32(weight_ptr + 6); - - float bias_c = 0.f; - - if (flag_bias) { - bias_c = bias[i]; - } - float32x4_t vbias = vdupq_n_f32(bias_c); - int hs = -1; - int he = 2; - float out_buf[4]; - for (int j = 0; j < h_out; ++j) { - const float* dr0 = din_channel + hs * w_in; - const float* dr1 = dr0 + w_in; - const float* dr2 = dr1 + w_in; - if (hs == -1) { - dr0 = zeros; - } - if (he > h_in) { - dr2 = zeros; - } - const float* din0_ptr = dr0; - const float* din1_ptr = dr1; - const float* din2_ptr = dr2; - - unsigned int* mask_ptr = dmask; -#ifdef __aarch64__ - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "movi v9.4s, #0 \n" - "ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" - - "ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6} - // v11={1,3,5,7} - "ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6} - // v12={1,3,5,7} - "ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6} - // v15={1,3,5,7} - - "bif v10.16b, v9.16b, v6.16b \n" - "bif v11.16b, v9.16b, v7.16b \n" - "bif v12.16b, v9.16b, v6.16b \n" - "bif v13.16b, v9.16b, v7.16b \n" - "bif v14.16b, v9.16b, v6.16b \n" - "bif v15.16b, v9.16b, v7.16b \n" - - "ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 = - // {0,1,3,5} - "ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 = - // {0,1,3,5} - "ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 = - // {0,1,3,5} - - "fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01 - "fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02 - "fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00 - - "fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11 - "fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12 - "fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10 - - "fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20 - "fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21 - "fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22 - - "fadd v4.4s, v4.4s, v5.4s \n" - "fadd v4.4s, v4.4s, v6.4s \n" - - "fadd v4.4s, v4.4s, %[bias].4s \n" // out add bias - "fmax v4.4s, v4.4s, v9.4s \n" - - "st1 {v4.4s}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "w"(vbias), - [out] "r"(out_buf) - : "cc", - "memory", - "v4", - "v5", - "v6", - "v7", - "v8", - "v9", - "v10", - "v11", - "v12", - "v13", - "v14", - "v15"); - -#else - asm volatile( - // Load up 12 elements (3 vectors) from each of 8 sources. - "vmov.u32 q9, #0 \n" - "vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" - "vdup.32 q3, %[bias] @ and \n" // q3 = - // vbias - - "vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7} - "vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7} - "vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7} - - "vbif q10, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q11, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q12, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q13, q9, q7 @ bit select, deal " - "with right pad\n" - "vbif q14, q9, q6 @ bit select, deal " - "with right pad\n" - "vbif q15, q9, q7 @ bit select, deal " - "with right pad\n" - - "vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5} - "vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5} - "vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5} - - "vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, " - "out0\n" // q10 * w01 - "vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, " - "out0\n" // q11 * w02 - "vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, " - "out0\n" // q6 * w00 - - "vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, " - "out0\n" // q12 * w11 - "vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, " - "out0\n" // q13 * w12 - "vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, " - "out0\n" // q7 * w10 - - "vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, " - "out0\n" // q14 * w20 - "vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, " - "out0\n" // q15 * w21 - "vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, " - "out0\n" // q8 * w22 - - "vadd.f32 q3, q3, q4 @ add \n" - "vadd.f32 q3, q3, q5 @ add \n" - - "vmax.f32 q3, q3, q9 @ relu\n" - - "vst1.32 {d6-d7}, [%[out]] \n" - : [din0_ptr] "+r"(din0_ptr), - [din1_ptr] "+r"(din1_ptr), - [din2_ptr] "+r"(din2_ptr), - [mask_ptr] "+r"(mask_ptr) - : [wr0] "w"(wr0), - [wr1] "w"(wr1), - [wr2] "w"(wr2), - [bias] "r"(bias_c), - [out] "r"(out_buf) - : "cc", - "memory", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12", - "q13", - "q14", - "q15"); -#endif // __aarch64__ - for (int w = 0; w < w_out; ++w) { - *dout_channel++ = out_buf[w]; - } - hs += 2; - he += 2; - } - } - } -} - -} // namespace math -} // namespace arm -} // namespace lite -} // namespace paddle diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc index 02a49cf157296763ce3a61ea99dd4ce513dc2f30..8618baf286ecabd3e5ad47dc1c80f0c66f1ece0f 100644 --- a/lite/backends/arm/math/conv_impl.cc +++ b/lite/backends/arm/math/conv_impl.cc @@ -361,7 +361,6 @@ void conv_im2col_gemm(const float* i_data, float* tmp_work_space = ctx->workspace_data() + ctx->llc_size() / sizeof(float); - //! use gemv when the output channel size = 1 for (int b = 0; b < num; ++b) { // dC diff --git a/lite/backends/arm/math/funcs.h b/lite/backends/arm/math/funcs.h index 2d07e908c229c8d300ad64510d72fc12f8374fea..8977b5712c13dec0088d83db4cbfef8494785301 100644 --- a/lite/backends/arm/math/funcs.h +++ b/lite/backends/arm/math/funcs.h @@ -39,6 +39,7 @@ #include "lite/backends/arm/math/im2sequence.h" #include "lite/backends/arm/math/increment.h" #include "lite/backends/arm/math/interpolate.h" +#include "lite/backends/arm/math/layout.h" #include "lite/backends/arm/math/lrn.h" #include "lite/backends/arm/math/negative.h" #include "lite/backends/arm/math/norm.h" diff --git a/lite/backends/arm/math/layout.cc b/lite/backends/arm/math/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..fd9126ab48c8f829c82d0c78a338074c695f0b9c --- /dev/null +++ b/lite/backends/arm/math/layout.cc @@ -0,0 +1,668 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/layout.h" +#include +#include +#include "lite/backends/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +#ifdef __aarch64__ +#define TRANS_C4 \ + "ld1 {v0.4s}, [%[din0_ptr]] \n" \ + "ld1 {v1.4s}, [%[din1_ptr]] \n" \ + "ld1 {v2.4s}, [%[din2_ptr]] \n" \ + "ld1 {v3.4s}, [%[din3_ptr]] \n" \ + \ + "1: \n" \ + "trn1 v4.4s, v0.4s, v1.4s \n" /*00 10 02 12 */ \ + "trn1 v5.4s, v2.4s, v3.4s \n" /*20 30 22 32 */ \ + "trn2 v6.4s, v0.4s, v1.4s \n" /*01 11 03 13 */ \ + "trn2 v7.4s, v2.4s, v3.4s \n" /*21 31 23 33 */ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride] \n" /* din+=c*size*/ \ + \ + "trn1 v8.2d, v4.2d, v5.2d \n" /*00 10 20 30 */ \ + "trn1 v9.2d, v6.2d, v7.2d \n" /*01 11 21 31 */ \ + "trn2 v10.2d, v4.2d, v5.2d \n" /*02 12 22 32 */ \ + "trn2 v11.2d, v6.2d, v7.2d \n" /*03 13 23 33 */ \ + \ + "ld1 {v0.4s}, [%[din0_ptr]] \n" \ + "ld1 {v1.4s}, [%[din1_ptr]] \n" \ + "ld1 {v2.4s}, [%[din2_ptr]] \n" \ + "ld1 {v3.4s}, [%[din3_ptr]] \n" \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "str q8, [%[out0_ptr]], #16 \n" \ + "str q9, [%[out1_ptr]], #16 \n" \ + "str q10, [%[out2_ptr]], #16 \n" \ + "str q11, [%[out3_ptr]], #16 \n" \ + "bne 1b \n" + +#define TRANS_C8 \ + "1: \n" \ + "ld1 {v0.8b}, [%[din0_ptr]] \n" \ + "ld1 {v1.8b}, [%[din1_ptr]] \n" \ + "ld1 {v2.8b}, [%[din2_ptr]] \n" \ + "ld1 {v3.8b}, [%[din3_ptr]] \n" \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "trn1 v8.8b, v0.8b, v1.8b \n" /*00 10 02 12 04 14 06 16 */ \ + "trn1 v9.8b, v2.8b, v3.8b \n" /*20 30 22 32 */ \ + "trn2 v12.8b, v0.8b, v1.8b \n" /*01 11 03 13 05 15 07 17 */ \ + "trn2 v13.8b, v2.8b, v3.8b \n" /*21 31 23 33 */ \ + \ + "ld1 {v4.8b}, [%[din0_ptr]] \n" \ + "ld1 {v5.8b}, [%[din1_ptr]] \n" \ + "ld1 {v6.8b}, [%[din2_ptr]] \n" \ + "ld1 {v7.8b}, [%[din3_ptr]] \n" \ + \ + "trn1 v10.8b, v4.8b, v5.8b \n" /*40 50 42 52 */ \ + "trn1 v11.8b, v6.8b, v7.8b \n" /*60 70 62 72 */ \ + "trn2 v14.8b, v4.8b, v5.8b \n" /*41 51 43 53 */ \ + "trn2 v15.8b, v6.8b, v7.8b \n" /*61 71 63 73 */ \ + \ + "trn1 v0.4h, v8.4h, v9.4h \n" /*00 10 20 30 04 14 24 34*/ \ + "trn1 v2.4h, v12.4h, v13.4h \n" /*01 11 21 31 05 15 25 35*/ \ + "trn1 v1.4h, v10.4h, v11.4h \n" /*40 50 60 70 44 54 64 74*/ \ + "trn1 v3.4h, v14.4h, v15.4h \n" /*41 51 61 71 45 55 65 75*/ \ + \ + "trn2 v4.4h, v8.4h, v9.4h \n" /*02 10 20 30 06 14 24 34*/ \ + "trn2 v6.4h, v12.4h, v13.4h \n" /*03 11 21 31 07 15 25 35*/ \ + "trn2 v5.4h, v10.4h, v11.4h \n" /*42 50 60 70 46 54 64 74*/ \ + "trn2 v7.4h, v14.4h, v15.4h \n" /*43 51 61 71 47 55 65 75*/ \ + \ + "trn1 v8.2s, v0.2s, v1.2s \n" /*00 10 20 30 40 50 60 70*/ \ + "trn1 v9.2s, v2.2s, v3.2s \n" /*01 11 21 31 41 51 61 71*/ \ + "trn1 v10.2s, v4.2s, v5.2s \n" /*02 12 22 32 42 50 60 70*/ \ + "trn1 v11.2s, v6.2s, v7.2s \n" /*03 13 23 33 41 51 61 71*/ \ + \ + "trn2 v12.2s, v0.2s, v1.2s \n" /*04 14 24 34 44 54 64 74*/ \ + "trn2 v13.2s, v2.2s, v3.2s \n" /*05 15 25 35 45 55 65 75*/ \ + "trn2 v14.2s, v4.2s, v5.2s \n" /*06 16 22 32 42 50 60 70*/ \ + "trn2 v15.2s, v6.2s, v7.2s \n" /*07 17 23 33 41 51 61 71*/ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "subs %w[cnt], %w[cnt], #1 \n" \ + "st1 {v8.8b}, [%[out0_ptr]], #8 \n" \ + "st1 {v9.8b}, [%[out1_ptr]], #8 \n" \ + "st1 {v10.8b}, [%[out2_ptr]], #8 \n" \ + "st1 {v11.8b}, [%[out3_ptr]], #8 \n" \ + \ + "st1 {v11.8b}, [%[out4_ptr]], #8 \n" \ + "st1 {v12.8b}, [%[out5_ptr]], #8 \n" \ + "st1 {v13.8b}, [%[out6_ptr]], #8 \n" \ + "st1 {v14.8b}, [%[out7_ptr]], #8 \n" \ + "bne 1b \n" + +#else +#define TRANS_C4 \ + "1: \n" \ + "vld1.32 {d0-d1}, [%[din0_ptr]] \n" \ + "vld1.32 {d2-d3}, [%[din1_ptr]] \n" \ + "vld1.32 {d4-d5}, [%[din2_ptr]] \n" \ + "vld1.32 {d6-d7}, [%[din3_ptr]] \n" \ + \ + "vtrn.32 q0, q1 \n" /*00 10 02 12 01 11 03 13*/ \ + "vtrn.32 q2, q3 \n" /*20 30 22 32 21 31 23 33 */ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride] \n" /* din+=c*size*/ \ + "vswp d1, d4 \n" \ + "vswp d3, d6 \n" \ + \ + "subs %[cnt], %[cnt], #1 \n" \ + "vst1.32 {d0-d1}, [%[out0_ptr]]! \n" \ + "vst1.32 {d2-d3}, [%[out1_ptr]]! \n" \ + "vst1.32 {d4-d5}, [%[out2_ptr]]! \n" \ + "vst1.32 {d6-d7}, [%[out3_ptr]]! \n" \ + "bne 1b \n" + +#define TRANS_C8 \ + "1: \n" \ + "vld1.8 d0, [%[din0_ptr]] \n" \ + "vld1.8 d1, [%[din1_ptr]] \n" \ + "vld1.8 d2, [%[din2_ptr]] \n" \ + "vld1.8 d3, [%[din3_ptr]] \n" \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "vtrn.8 d0, d1 \n" /*00 10 02 12 04 14 06 16*/ \ + "vtrn.8 d2, d3 \n" /*20 30 22 32 24 34 26 36 */ \ + \ + "vld1.8 d4, [%[din0_ptr]] \n" \ + "vld1.8 d5, [%[din1_ptr]] \n" \ + "vld1.8 d6, [%[din2_ptr]] \n" \ + "vld1.8 d7, [%[din3_ptr]] \n" \ + \ + "vtrn.16 d0, d2 \n" /*00 10 20 30 04 14 24 34*/ \ + "vtrn.16 d1, d3 \n" /* 01 11 21 31 05 15 25 35 */ \ + "vtrn.8 d4, d5 \n" /*40 50 02 12 04 14 06 16*/ \ + "vtrn.8 d6, d7 \n" /*60 70 22 32 24 34 26 36 */ \ + \ + "add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \ + "add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \ + \ + "vtrn.16 d4, d6 \n" /*40 50 60 70 04 14 24 34*/ \ + "vtrn.16 d5, d7 \n" /* 41 51 61 71 05 15 25 35 */ \ + \ + "vtrn.32 d0, d4 \n" /*00 10 20 30 40 50 60 70*/ \ + "vtrn.32 d1, d5 \n" /* 01 11 21 31 41 51 61 71 */ \ + "vtrn.32 d2, d6 \n" /*02 12 22 32 42 52 62 72*/ \ + "vtrn.32 d3, d7 \n" /* 03 11 21 33 43 53 63 73 */ \ + \ + "subs %[cnt], %[cnt], #1 \n" \ + "vst1.8 {d0}, [%[out0_ptr]]! \n" \ + "vst1.8 {d1}, [%[out1_ptr]]! \n" \ + "vst1.8 {d2}, [%[out2_ptr]]! \n" \ + "vst1.8 {d3}, [%[out3_ptr]]! \n" \ + "vst1.8 {d4}, [%[out4_ptr]]! \n" \ + "vst1.8 {d5}, [%[out5_ptr]]! \n" \ + "vst1.8 {d6}, [%[out6_ptr]]! \n" \ + "vst1.8 {d7}, [%[out7_ptr]]! \n" \ + "bne 1b \n" + +#endif +template <> +void NCHW2NHWC(int N, int C, int size, const float* X, float* Y) { + int cnt = C >> 2; + int remain = C % 4; + int sum = C * size; + int stride = size << 4; // 4 * size + int stride_w = stride >> 2; + for (int n = 0; n < N; n++) { + const float* din = X + n * sum; + float* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < size - 3; s += 4) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + size; + const float* din2_ptr = din1_ptr + size; + const float* din3_ptr = din2_ptr + size; + float* out0_ptr = dout + s * C; + float* out1_ptr = out0_ptr + C; + float* out2_ptr = out1_ptr + C; + float* out3_ptr = out2_ptr + C; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12"); +#else + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + for (int i = 0; i < remain; i++) { + const float* ptr = din0_ptr; + *out0_ptr++ = *ptr++; + *out1_ptr++ = *ptr++; + *out2_ptr++ = *ptr++; + *out3_ptr++ = *ptr++; + din0_ptr += size; + } + } + // remain size + for (; s < size; s++) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + size; + const float* din2_ptr = din1_ptr + size; + const float* din3_ptr = din2_ptr + size; + float* out0_ptr = dout + s * C; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + din0_ptr += stride_w; + din1_ptr += stride_w; + din2_ptr += stride_w; + din3_ptr += stride_w; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += size; + } + } + } +} +template <> +void NCHW2NHWC(int N, int C, int size, const int8_t* X, int8_t* Y) { + int cnt = C >> 3; + int remain = C % 8; + int sum = C * size; + int stride = size << 3; // 8 * size + int stride_w = size << 4; // 4 * size * 4 + for (int n = 0; n < N; n++) { + const int8_t* din = X + n * sum; + int8_t* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < size - 7; s += 8) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + size; + const int8_t* din2_ptr = din1_ptr + size; + const int8_t* din3_ptr = din2_ptr + size; + int8_t* out0_ptr = dout + s * C; + int8_t* out1_ptr = out0_ptr + C; + int8_t* out2_ptr = out1_ptr + C; + int8_t* out3_ptr = out2_ptr + C; + int8_t* out4_ptr = out3_ptr + C; + int8_t* out5_ptr = out4_ptr + C; + int8_t* out6_ptr = out5_ptr + C; + int8_t* out7_ptr = out6_ptr + C; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + // const int8_t* din_ptr = din + 8 * cnt * size + s; // remain channel + for (int i = 0; i < remain; i++) { + const int8_t* ptr = din0_ptr; + *out0_ptr = *ptr++; + *out1_ptr = *ptr++; + *out2_ptr = *ptr++; + *out3_ptr = *ptr++; + din0_ptr += size; + *out4_ptr = *ptr++; + *out5_ptr = *ptr++; + *out6_ptr = *ptr++; + *out7_ptr = *ptr++; + } + } + // remain size + for (; s < size; s++) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + size; + const int8_t* din2_ptr = din1_ptr + size; + const int8_t* din3_ptr = din2_ptr + size; + const int8_t* din4_ptr = din3_ptr + size; + const int8_t* din5_ptr = din4_ptr + size; + const int8_t* din6_ptr = din5_ptr + size; + const int8_t* din7_ptr = din6_ptr + size; + int8_t* out0_ptr = dout + s * C; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + *out0_ptr++ = *din4_ptr; + *out0_ptr++ = *din5_ptr; + *out0_ptr++ = *din6_ptr; + *out0_ptr++ = *din7_ptr; + din0_ptr += stride; + din1_ptr += stride; + din2_ptr += stride; + din3_ptr += stride; + din4_ptr += stride; + din5_ptr += stride; + din6_ptr += stride; + din7_ptr += stride; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += size; + } + } + } +} +template <> +void NHWC2NCHW(int N, int C, int size, const float* X, float* Y) { + int cnt = size >> 2; + int remain = size % 4; + int sum = C * size; + int stride = C << 4; // 4 * size + int stride_w = C << 2; + for (int n = 0; n < N; n++) { + const float* din = X + n * sum; + float* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < C - 3; s += 4) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + C; + const float* din2_ptr = din1_ptr + C; + const float* din3_ptr = din2_ptr + C; + float* out0_ptr = dout + s * size; + float* out1_ptr = out0_ptr + size; + float* out2_ptr = out1_ptr + size; + float* out3_ptr = out2_ptr + size; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); +#else + asm volatile(TRANS_C4 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [cnt] "+r"(cnt_num), + [stride] "+r"(stride) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + for (int i = 0; i < remain; i++) { + const float* ptr = din0_ptr; + *out0_ptr++ = *ptr++; + *out1_ptr++ = *ptr++; + *out2_ptr++ = *ptr++; + *out3_ptr++ = *ptr++; + din0_ptr += C; + } + } + // remain size + for (; s < C; s++) { + const float* din0_ptr = din + s; + const float* din1_ptr = din0_ptr + C; + const float* din2_ptr = din1_ptr + C; + const float* din3_ptr = din2_ptr + C; + float* out0_ptr = dout + s * size; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + din0_ptr += stride_w; + din1_ptr += stride_w; + din2_ptr += stride_w; + din3_ptr += stride_w; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += C; + } + } + } +} +template <> +void NHWC2NCHW(int N, int C, int size, const int8_t* X, int8_t* Y) { + int cnt = size >> 3; + int remain = size % 8; + int sum = C * size; + int stride = C << 3; // 8 * size + int stride_w = C << 4; // 4 * size + for (int n = 0; n < N; n++) { + const int8_t* din = X + n * sum; + int8_t* dout = Y + n * sum; + int s = 0; +#pragma omp parallel for + for (s = 0; s < C - 7; s += 8) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + C; + const int8_t* din2_ptr = din1_ptr + C; + const int8_t* din3_ptr = din2_ptr + C; + const int8_t* din4_ptr = din3_ptr + C; + const int8_t* din5_ptr = din4_ptr + C; + const int8_t* din6_ptr = din5_ptr + C; + const int8_t* din7_ptr = din6_ptr + C; + int8_t* out0_ptr = dout + s * size; + int8_t* out1_ptr = out0_ptr + size; + int8_t* out2_ptr = out1_ptr + size; + int8_t* out3_ptr = out2_ptr + size; + int8_t* out4_ptr = out3_ptr + size; + int8_t* out5_ptr = out4_ptr + size; + int8_t* out6_ptr = out5_ptr + size; + int8_t* out7_ptr = out6_ptr + size; + int cnt_num = cnt; + if (cnt_num > 0) { +#ifdef __aarch64__ + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v12", + "v13", + "v14", + "v15"); +#else + asm volatile(TRANS_C8 + : [din0_ptr] "+r"(din0_ptr), + [din1_ptr] "+r"(din1_ptr), + [din2_ptr] "+r"(din2_ptr), + [din3_ptr] "+r"(din3_ptr), + [out0_ptr] "+r"(out0_ptr), + [out1_ptr] "+r"(out1_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [out6_ptr] "+r"(out6_ptr), + [out7_ptr] "+r"(out7_ptr), + [cnt] "+r"(cnt_num), + [stride_w] "+r"(stride_w) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); +#endif + } + for (int i = 0; i < remain; i++) { + const int8_t* ptr = din0_ptr; + *out0_ptr++ = *ptr++; + *out1_ptr++ = *ptr++; + *out2_ptr++ = *ptr++; + *out3_ptr++ = *ptr++; + *out4_ptr++ = *ptr++; + *out5_ptr++ = *ptr++; + *out6_ptr++ = *ptr++; + *out7_ptr++ = *ptr++; + din0_ptr += C; + } + } + // remain size + for (; s < C; s++) { + const int8_t* din0_ptr = din + s; + const int8_t* din1_ptr = din0_ptr + C; + const int8_t* din2_ptr = din1_ptr + C; + const int8_t* din3_ptr = din2_ptr + C; + const int8_t* din4_ptr = din3_ptr + C; + const int8_t* din5_ptr = din4_ptr + C; + const int8_t* din6_ptr = din5_ptr + C; + const int8_t* din7_ptr = din6_ptr + C; + int8_t* out0_ptr = dout + s * size; + for (int i = 0; i < cnt; i++) { + *out0_ptr++ = *din0_ptr; + *out0_ptr++ = *din1_ptr; + *out0_ptr++ = *din2_ptr; + *out0_ptr++ = *din3_ptr; + *out0_ptr++ = *din4_ptr; + *out0_ptr++ = *din5_ptr; + *out0_ptr++ = *din6_ptr; + *out0_ptr++ = *din7_ptr; + din0_ptr += stride; + din1_ptr += stride; + din2_ptr += stride; + din3_ptr += stride; + din4_ptr += stride; + din5_ptr += stride; + din6_ptr += stride; + din7_ptr += stride; + } + for (int i = 0; i < remain; i++) { + *out0_ptr++ = *din0_ptr; + din0_ptr += C; + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/layout.h b/lite/backends/arm/math/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..ed0e2f8b78a280c513161a02bb3b3b479008145a --- /dev/null +++ b/lite/backends/arm/math/layout.h @@ -0,0 +1,30 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +template +void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y); + +template +void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index 1c67ee8f3dcafe30d9bda587d62233d0e715071e..d78ae690f9b019dff7728bd3e95c0b1406bea463 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -145,6 +145,12 @@ class KernelRegistry final { KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget *, // KernelRegistryForTargettemplate Param(); \ + auto input = param.x->template data(); \ + auto input_dim = param.x->dims(); \ + CHECK(input_dim.size() == 4) \ + << "NCHW to NHWC should guarantee that the input dims should be 4"; \ + int n = input_dim[0]; \ + int c = input_dim[1]; \ + int h = input_dim[2]; \ + int w = input_dim[3]; \ + param.y->Resize({n, h, w, c}); \ + auto output = param.y->template mutable_data(TARGET(kARM)); \ + if (c == 1) { \ + memcpy(output, input, sizeof(type) * n * h * w); \ + return; \ + } \ + lite::arm::math::NCHW2NHWC(n, c, h * w, input, output); + +#define NHWCTONCHW(type) \ + auto& param = this->template Param(); \ + auto input = param.x->template data(); \ + auto input_dim = param.x->dims(); \ + CHECK(input_dim.size() == 4) \ + << "NHWC to NCHW should guarantee that the input dims should be 4"; \ + int n = input_dim[0]; \ + int h = input_dim[1]; \ + int w = input_dim[2]; \ + int c = input_dim[3]; \ + param.y->Resize({n, c, h, w}); \ + auto output = param.y->template mutable_data(TARGET(kARM)); \ + if (c == 1) { \ + memcpy(output, input, sizeof(type) * n * h * w); \ + return; \ + } \ + lite::arm::math::NHWC2NCHW(n, c, h * w, input, output); + +template <> +void NCHWToNHWCCompute::Run() { + NCHWTONHWC(float); +} + +template <> +void NCHWToNHWCCompute::Run() { + NCHWTONHWC(int8_t); +} + +template <> +void NHWCToNCHWCompute::Run() { + NHWCTONCHW(float); +} + +template <> +void NHWCToNCHWCompute::Run() { + NHWCTONCHW(int8_t); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +typedef paddle::lite::kernels::arm::NCHWToNHWCCompute + NCHW_fp32; +typedef paddle::lite::kernels::arm::NCHWToNHWCCompute + NCHW_int8; +typedef paddle::lite::kernels::arm::NHWCToNCHWCompute + NHWC_fp32; +typedef paddle::lite::kernels::arm::NHWCToNCHWCompute + NHWC_int8; + +REGISTER_LITE_KERNEL(layout, kARM, kFloat, kNCHW, NCHW_fp32, nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout, kARM, kFloat, kNCHW, NHWC_fp32, nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout, kARM, kInt8, kNCHW, NCHW_int8, int8_nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout, kARM, kInt8, kNCHW, NHWC_int8, int8_nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, kARM, kFloat, kNCHW, NCHW_fp32, nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, kARM, kFloat, kNCHW, NHWC_fp32, nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, kARM, kInt8, kNCHW, NCHW_int8, int8_nchw2nhwc) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL(layout_once, kARM, kInt8, kNCHW, NHWC_int8, int8_nhwc2nchw) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNCHW))}) + .Finalize(); diff --git a/lite/kernels/arm/layout_compute.h b/lite/kernels/arm/layout_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..13b8621029437ea18d960e9c22d53b7062983b8f --- /dev/null +++ b/lite/kernels/arm/layout_compute.h @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { +template +class NCHWToNHWCCompute : public KernelLite { + public: + using param_t = operators::LayoutParam; + void Run() override; + virtual ~NCHWToNHWCCompute() = default; +}; + +template +class NHWCToNCHWCompute : public KernelLite { + public: + using param_t = operators::LayoutParam; + void Run() override; + virtual ~NHWCToNCHWCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/tests/math/CMakeLists.txt b/lite/tests/math/CMakeLists.txt index b199f655239150438ecba881d5e1e4fa1e5dfa31..7dd4f522dbc0f10e8cfb7d19e95da4354ac4b779 100644 --- a/lite/tests/math/CMakeLists.txt +++ b/lite/tests/math/CMakeLists.txt @@ -8,4 +8,10 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH lite_cc_test(conv_transpose_compute_test SRCS conv_transpose_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(pool_compute_test SRCS pool_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + + if(LITE_BUILD_EXTRA) + lite_cc_test(layout_compute_test SRCS layout_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + endif() + + endif() diff --git a/lite/tests/math/layout_compute_test.cc b/lite/tests/math/layout_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..29f8f749db5a8f0f899500e4356bd472bca4fd13 --- /dev/null +++ b/lite/tests/math/layout_compute_test.cc @@ -0,0 +1,608 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/core/context.h" +#include "lite/operators/op_params.h" +#include "lite/tests/utils/naive_math_impl.h" +#include "lite/tests/utils/tensor_utils.h" +#include "lite/tests/utils/timer.h" + +#ifdef LITE_WITH_ARM +#include "lite/kernels/arm/layout_compute.h" +#endif // LITE_WITH_ARM + +DEFINE_int32(power_mode, + 3, + "power mode: " + "0 for POWER_HIGH;" + "1 for POWER_LOW;" + "2 for POWER_FULL;" + "3 for NO_BIND"); +DEFINE_int32(threads, 1, "threads num"); +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); +DEFINE_bool(basic_test, false, "do all tests"); +DEFINE_bool(check_result, true, "check the result"); + +DEFINE_int32(batch, 1, "batch size"); +DEFINE_int32(in_channel, 32, "input channel"); +DEFINE_int32(in_height, 112, "input height"); +DEFINE_int32(in_width, 112, "input width"); + +DEFINE_bool(flag_nchw, true, "do nchw to nhwc"); + +typedef paddle::lite::DDim DDim; +typedef paddle::lite::Tensor Tensor; +typedef paddle::lite::operators::LayoutParam LayoutParam; + +using paddle::lite::Timer; + +#define IN(n, c, h, w) \ + input_data[w + h * input_w + c * input_h * input_w + \ + n * input_c * input_h * input_w] +#define OUT(n, c, h, w) \ + output_data[w + h * output_w + c * output_h * output_w + \ + n * output_c * output_h * output_w] + +template +void nchw2nhwc_ref(const Tensor* input, Tensor* output) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_c = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, h, w, c) = IN(n, c, h, w); + } + } + } + } +} +#undef IN +#undef OUT + +#define IN(n, h, w, c) \ + input_data[c + w * input_c + h * input_w * input_c + \ + n * input_h * input_w * input_c] +#define OUT(n, h, w, c) \ + output_data[c + w * output_c + h * output_w * output_c + \ + n * output_h * output_w * output_c] +template +void nhwc2nchw_ref(const Tensor* input, Tensor* output) { + auto* input_data = input->data(); + auto* output_data = output->mutable_data(); + + int input_n = input->dims()[0]; + int input_h = input->dims()[1]; + int input_w = input->dims()[2]; + int input_c = input->dims()[3]; + int output_h = output->dims()[1]; + int output_w = output->dims()[2]; + int output_c = output->dims()[3]; + + for (int n = 0; n < input_n; ++n) { + for (int c = 0; c < input_c; ++c) { + for (int h = 0; h < input_h; ++h) { + for (int w = 0; w < input_w; ++w) { + OUT(n, c, h, w) = IN(n, h, w, c); + } + } + } + } +} + +#ifdef LITE_WITH_ARM +void test_layout_fp32_nchw(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LayoutParam param; + param.x = new Tensor; + const_cast(param.x)->set_precision(PRECISION(kFloat)); + + param.y = new Tensor; + param.y->set_precision(PRECISION(kFloat)); + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + paddle::lite::kernels::arm::NCHWToNHWCCompute layout; + DDim dim_out({dim_in[0], dim_in[2], dim_in[3], dim_in[1]}); + + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), th); + /// set param and context + const_cast(param.x)->Resize(dim_in); + param.y->Resize(dim_out); + + layout.SetParam(param); + + paddle::lite::fill_tensor_rand( + *(const_cast(param.x)), -1.f, 1.f); + // paddle::lite::fill_tensor_const(*param.x, 1.f); + + auto din = param.x->data(); + + Tensor tout_basic; + + if (FLAGS_check_result) { + tout_basic.set_precision(PRECISION(kFloat)); + tout_basic.Resize(dim_out); + fill_tensor_const(tout_basic, 0.f); + auto dout_basic = tout_basic.mutable_data(); + nchw2nhwc_ref(param.x, &tout_basic); + } + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + layout.Run(); + } + /// compute + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + layout.Run(); + t0.end(); + } + double gops = 2.0 * dim_out.production(); + LOG(INFO) << "layout fp32: input shape: " << dim_in << ", output shape" + << dim_out << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-3f) { + if (max_diff > 5e-4f) { + LOG(WARNING) << "din"; + print_tensor(*(const_cast(param.x))); + LOG(WARNING) << "basic result"; + print_tensor(tout_basic); + LOG(WARNING) << "lite result"; + print_tensor(*param.y); + Tensor tdiff; + tdiff.Resize(tout_basic.dims()); + tdiff.set_precision(PRECISION(kFloat)); + tensor_diff(tout_basic, *param.y, tdiff); + print_tensor(tdiff); + LOG(FATAL) << "test fp32 layout: input: " << dim_in + << ", output: " << dim_out << ", flag_nchw: " + << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + LOG(INFO) << "test fp32 layout: input: " << dim_in + << ", output: " << dim_out + << ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.y; +} +void test_layout_fp32_nhwc(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + + LayoutParam param; + param.x = new Tensor; + const_cast(param.x)->set_precision(PRECISION(kFloat)); + + param.y = new Tensor; + param.y->set_precision(PRECISION(kFloat)); + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + paddle::lite::kernels::arm::NHWCToNCHWCompute layout; + // n h w c == n c h w + DDim dim_out({dim_in[0], dim_in[3], dim_in[1], dim_in[2]}); + + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), th); + /// set param and context + const_cast(param.x)->Resize(dim_in); + param.y->Resize(dim_out); + + layout.SetParam(param); + + paddle::lite::fill_tensor_rand( + *(const_cast(param.x)), -1.f, 1.f); + // paddle::lite::fill_tensor_const(*param.x, 1.f); + + auto din = param.x->data(); + + Tensor tout_basic; + + if (FLAGS_check_result) { + tout_basic.set_precision(PRECISION(kFloat)); + tout_basic.Resize(dim_out); + fill_tensor_const(tout_basic, 0.f); + auto dout_basic = tout_basic.mutable_data(); + nhwc2nchw_ref(param.x, &tout_basic); + } + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + layout.Run(); + } + /// compute + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + layout.Run(); + t0.end(); + } + double gops = 2.0 * dim_out.production(); + LOG(INFO) << "layout fp32: input shape: " << dim_in << ", output shape" + << dim_out << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-3f) { + if (max_diff > 5e-4f) { + LOG(WARNING) << "din"; + print_tensor(*(const_cast(param.x))); + LOG(WARNING) << "basic result"; + print_tensor(tout_basic); + LOG(WARNING) << "lite result"; + print_tensor(*param.y); + Tensor tdiff; + tdiff.Resize(tout_basic.dims()); + tdiff.set_precision(PRECISION(kFloat)); + tensor_diff(tout_basic, *param.y, tdiff); + print_tensor(tdiff); + LOG(FATAL) << "test fp32 layout: input: " << dim_in + << ", output: " << dim_out << ", flag_nchw: " + << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + LOG(INFO) << "test fp32 layout: input: " << dim_in + << ", output: " << dim_out + << ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.y; +} +void test_layout_int8_nchw(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + + LayoutParam param; + param.x = new Tensor; + const_cast(param.x)->set_precision(PRECISION(kInt8)); + + param.y = new Tensor; + param.y->set_precision(PRECISION(kInt8)); + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + paddle::lite::kernels::arm::NCHWToNHWCCompute layout; + DDim dim_out({dim_in[0], dim_in[2], dim_in[3], dim_in[1]}); + + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), th); + /// set param and context + const_cast(param.x)->Resize(dim_in); + param.y->Resize(dim_out); + + layout.SetParam(param); + + paddle::lite::fill_tensor_rand(*(const_cast(param.x))); + // paddle::lite::fill_tensor_const(*param.x, 1.f); + + auto din = param.x->data(); + + Tensor tout_basic; + + if (FLAGS_check_result) { + tout_basic.set_precision(PRECISION(kInt8)); + tout_basic.Resize(dim_out); + fill_tensor_const(tout_basic, 0); + auto dout_basic = tout_basic.mutable_data(); + nchw2nhwc_ref(param.x, &tout_basic); + } + LOG(INFO) << "saber compute"; + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + layout.Run(); + } + /// compute + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + layout.Run(); + t0.end(); + } + LOG(INFO) << "saber compute end"; + double gops = 2.0 * dim_out.production(); + LOG(INFO) << "layout int8: input shape: " << dim_in << ", output shape" + << dim_out << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-3f) { + if (max_diff > 5e-4f) { + LOG(WARNING) << "din"; + print_tensor(*(const_cast(param.x))); + LOG(WARNING) << "basic result"; + print_tensor(tout_basic); + LOG(WARNING) << "lite result"; + print_tensor(*param.y); + Tensor tdiff; + tdiff.Resize(tout_basic.dims()); + tdiff.set_precision(PRECISION(kInt8)); + tensor_diff(tout_basic, *param.y, tdiff); + print_tensor(tdiff); + LOG(FATAL) << "test int8 layout: input: " << dim_in + << ", output: " << dim_out << ", flag_nchw: " + << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + LOG(INFO) << "test int8 layout: input: " << dim_in + << ", output: " << dim_out + << ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.y; +} +void test_layout_int8_nhwc(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + + LayoutParam param; + param.x = new Tensor; + const_cast(param.x)->set_precision(PRECISION(kInt8)); + + param.y = new Tensor; + param.y->set_precision(PRECISION(kInt8)); + + for (auto& cls : power_mode) { + for (auto& th : thread_num) { + paddle::lite::kernels::arm::NHWCToNCHWCompute layout; + // n h w c == n c h w + DDim dim_out({dim_in[0], dim_in[3], dim_in[1], dim_in[2]}); + + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), th); + /// set param and context + const_cast(param.x)->Resize(dim_in); + param.y->Resize(dim_out); + + layout.SetParam(param); + + paddle::lite::fill_tensor_rand(*(const_cast(param.x))); + // paddle::lite::fill_tensor_const(*param.x, 1.f); + + auto din = param.x->data(); + + Tensor tout_basic; + + if (FLAGS_check_result) { + tout_basic.set_precision(PRECISION(kInt8)); + tout_basic.Resize(dim_out); + fill_tensor_const(tout_basic, 0.f); + auto dout_basic = tout_basic.mutable_data(); + nhwc2nchw_ref(param.x, &tout_basic); + } + LOG(INFO) << "saber compute"; + /// warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + layout.Run(); + } + /// compute + Timer t0; + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.start(); + layout.Run(); + t0.end(); + } + LOG(INFO) << "run"; + double gops = 2.0 * dim_out.production(); + LOG(INFO) << "layout int8: input shape: " << dim_in << ", output shape" + << dim_out << ",running time, avg: " << t0.get_average_ms() + << ", min time: " << t0.get_min_time() + << ", total GOPS: " << 1e-9 * gops + << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms() + << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time(); + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-3f) { + if (max_diff > 5e-4f) { + LOG(WARNING) << "din"; + print_tensor(*(const_cast(param.x))); + LOG(WARNING) << "basic result"; + print_tensor(tout_basic); + LOG(WARNING) << "lite result"; + print_tensor(*param.y); + Tensor tdiff; + tdiff.Resize(tout_basic.dims()); + tdiff.set_precision(PRECISION(kInt8)); + tensor_diff(tout_basic, *param.y, tdiff); + print_tensor(tdiff); + LOG(FATAL) << "test int8 layout: input: " << dim_in + << ", output: " << dim_out << ", flag_nchw: " + << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " failed!!\n"; + } + } + LOG(INFO) << "test int8 layout: input: " << dim_in + << ", output: " << dim_out + << ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw") + << ", threads: " << th << ", power_mode: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.y; +} +#else +void test_layout_fp32_nchw(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) {} +void test_layout_fp32_nhwc(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) {} +void test_layout_int8_nchw(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) {} +void test_layout_int8_nhwc(DDim dim_in, + bool flag_nchw, + const std::vector& thread_num, + const std::vector& power_mode) {} +#endif // LITE_WITH_ARM + +#if 1 // +TEST(TestLayout, test_Layout_fp32) { + if (FLAGS_basic_test) { + for (auto n : {1, 3}) { + for (auto c : {1, 3, 5, 32}) { + for (auto h : {3, 16, 20, 32}) { + for (auto w : {3, 4, 32, 112}) { + for (auto nchw2nhwc : {true, false}) { + DDim dim_in({n, c, h, w}); + if (nchw2nhwc) { + LOG(INFO) << "NCHW2NHWC"; + test_layout_fp32_nchw( + dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode}); + } else { + LOG(INFO) << "NHWC2NCHW"; + test_layout_fp32_nhwc( + dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode}); + } + } + } + } + } + } + } +} +#endif +#if 1 +TEST(TestLayout, test_Layout_int8) { + if (FLAGS_basic_test) { + for (auto n : {1, 3}) { + for (auto c : {1, 3, 5, 32}) { + for (auto h : {3, 16, 20, 32}) { + for (auto w : {3, 4, 32, 112}) { + for (auto nchw2nhwc : {true, false}) { + DDim dim_in({n, c, h, w}); + if (nchw2nhwc) { + LOG(INFO) << "NCHW2NHWC int8"; + test_layout_int8_nchw( + dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode}); + } else { + LOG(INFO) << "NHWC2NCHW int8"; + test_layout_int8_nhwc( + dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode}); + } + } + } + } + } + } + } +} +#endif + +#if 1 /// custom +TEST(TestLayoutCustom, test_Layout_custom_size) { + test_layout_fp32_nchw( + {DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})}, + true, + {FLAGS_threads}, + {FLAGS_power_mode}); +} +#endif // custom