Commit b1e4f4fd authored by HappyAngel, committed by GitHub

add NHWC NCHW transform, test=develop (#2381)

* add nhwc to nchw

* add layout in funcs

* change layout as extra, test=develop

* change make, test=develop

* use template class method to update layout NCHW and NHWC transform, test=develop

* fix cmake error, set layout to extra, test=develop

* fix test_layout_compute_arm test, it's extra

* layout is extra, test=develop

* fix error in kernels/arm/layout_comput.cc when registering the kernel; DataLayout must be NCHW, test=develop

* delete extra note, test=develop

* delete extra test

* delete layout_test, test=develop; it's in tests/math/layout_comput_test

* delete extra test, test=develop
Parent 2f3336c1
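The PR adds an NHWC<->NCHW layout-transform kernel (layout.cc in the CMake hunk below). For orientation, here is a minimal scalar sketch of the index remapping such a transform performs, assuming dense fp32 tensors; the function name is illustrative, not the API this PR registers:

void nchw_to_nhwc_ref(const float* src, float* dst, int n, int c, int h, int w) {
  // NCHW offset: ((in * c + ic) * h + ih) * w + iw
  // NHWC offset: ((in * h + ih) * w + iw) * c + ic
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          dst[((in * h + ih) * w + iw) * c + ic] =
              src[((in * c + ic) * h + ih) * w + iw];
}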
@@ -57,8 +57,8 @@ endif()
if (NOT HAS_ARM_MATH_LIB_DIR)
# TODO(xxx): separate them and do not depend on proto, eigen3
cc_library(math_arm SRCS
funcs.cc
packed_sgemm.cc
packed_sgemm_c4.cc
sgemm.cc
@@ -68,8 +68,10 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
gemv_arm_int8.cc
conv3x3s1_direct_fp32.cc
conv3x3s2_direct_fp32.cc
conv3x3s1_depthwise_fp32.cc
conv3x3s2_depthwise_fp32.cc
conv3x3s1p01_depthwise_fp32.cc
conv3x3s2p01_depthwise_fp32.cc
conv3x3s1px_depthwise_fp32.cc
conv3x3s2px_depthwise_fp32.cc
conv3x3s1_direct_int8.cc
conv3x3s2_direct_int8.cc
conv3x3s1_depthwise_int8.cc
@@ -77,16 +79,13 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s1_depthwise_int8.cc
conv5x5s1_depthwise_fp32.cc
conv5x5s2_depthwise_fp32.cc
conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc
conv_depthwise_3x3s1.cc
conv_depthwise_3x3s2.cc
conv_winograd_3x3.cc
conv_impl.cc
softmax.cc
scale.cc
pooling.cc
elementwise.cc
layout.cc
lrn.cc
decode_bboxes.cc
concat.cc
@@ -122,4 +121,3 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
anchor_generator.cc
DEPS ${lite_kernel_deps} context tensor)
endif()
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h"
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle {
namespace lite {
......
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h"
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle {
namespace lite {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
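// Common signature for the kernels in this file: din/dout point at NCHW
// tensors, weights holds 9 floats (one 3x3 kernel) per channel, and
// flag_bias selects whether bias (one float per channel) is added.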
void conv_depthwise_3x3s1p0_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s1p0_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s2p0_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s1p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
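// Dispatcher: stride selects the s1/s2 kernels; the *_s variants cover
// narrow inputs (w_in <= 5 for stride 1, w_in <= 8 for stride 2) where a
// full 4-wide output tile cannot be formed; the *_relu variants fuse
// max(x, 0) into the store path.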
void conv_depthwise_3x3p0_fp32(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int stride,
bool flag_bias,
bool flag_relu,
ARMContext* ctx) {
if (stride == 1) {
if (flag_relu) {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
} else { //! stride = 2
if (flag_relu) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
}
}
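// For orientation only - a scalar sketch (hypothetical helper, not part of
// this PR) of what the NEON kernels below compute for a single channel with
// pad = 0:
//   dout[oh][ow] = bias + sum_{kh,kw in 0..2}
//                  din[oh*stride + kh][ow*stride + kw] * weights[kh*3 + kw]
inline void conv_depthwise_3x3p0_ref_channel(float* dout,
                                             const float* din,
                                             const float* weights,
                                             float bias,
                                             int h_out,
                                             int w_out,
                                             int w_in,
                                             int stride) {
  for (int oh = 0; oh < h_out; ++oh) {
    for (int ow = 0; ow < w_out; ++ow) {
      float sum = bias;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          sum += din[(oh * stride + kh) * w_in + ow * stride + kw] *
                 weights[kh * 3 + kw];
        }
      }
      dout[oh * w_out + ow] = sum;
    }
  }
}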
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width > 4
 */
// processes 4 output rows per iteration
void conv_depthwise_3x3s1p0_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remain_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remain_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
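// Right-pad masking: each main-loop tile yields 4 outputs from a 6-wide
// input window. size_pad_right is how many lanes of the last 6-wide window
// fall beyond w_in; vmask keeps only the valid input lanes and rmask keeps
// only the `remain` valid output lanes. Both feed the bif (bit-select)
// instructions in the right-pad block of the assembly below.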
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
#ifdef __aarch64__
for (int c = 0; c < ch_in; c++) {
float* dout_ptr = dout_batch + c * size_out_channel;
const float* din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float* wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
// wr0 = vsetq_lane_f32(0.f, wr0, 3);
// wr1 = vsetq_lane_f32(0.f, wr1, 3);
// wr2 = vsetq_lane_f32(0.f, wr2, 3);
float* doutr0 = dout_ptr;
float* doutr1 = doutr0 + w_out;
float* doutr2 = doutr1 + w_out;
float* doutr3 = doutr2 + w_out;
const float* dr0 = din_ch_ptr;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* dr5 = dr4 + w_in;
const float* din_ptr0 = dr0;
const float* din_ptr1 = dr1;
const float* din_ptr2 = dr2;
const float* din_ptr3 = dr3;
const float* din_ptr4 = dr4;
const float* din_ptr5 = dr5;
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
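// note: the cases below fall through on purpose (no break), so every
// input row at or beyond h_in is redirected to the zeroed scratch row.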
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 5:
din_ptr1 = zero_ptr;
case 4:
din_ptr2 = zero_ptr;
case 3:
din_ptr3 = zero_ptr;
case 2:
din_ptr4 = zero_ptr;
case 1:
din_ptr5 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
"PRFM PLDL1KEEP, [%[din_ptr0]] \n"
"PRFM PLDL1KEEP, [%[din_ptr1]] \n"
"PRFM PLDL1KEEP, [%[din_ptr2]] \n"
"PRFM PLDL1KEEP, [%[din_ptr3]] \n"
"PRFM PLDL1KEEP, [%[din_ptr4]] \n"
"PRFM PLDL1KEEP, [%[din_ptr5]] \n"
"movi v21.4s, #0x0\n" /* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
// mid
// "cmp %[cnt], #1 \n"
// "blt 5f \n"
"4: \n"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
// r4
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
// r5
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"subs %[cnt], %[cnt], #1 \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"bne 4b \n"
// right
"5: \n"
"cmp %[remain], #1 \n"
"blt 0f \n"
"ld1 {v18.4s, v19.4s}, [%[vmask]] \n"
"ld1 {v22.4s}, [%[doutr0]] \n"
"ld1 {v23.4s}, [%[doutr1]] \n"
"ld1 {v24.4s}, [%[doutr2]] \n"
"ld1 {v25.4s}, [%[doutr3]] \n"
"bif v0.16b, %[vzero].16b, v18.16b \n"
"bif v1.16b, %[vzero].16b, v19.16b \n"
"bif v2.16b, %[vzero].16b, v18.16b \n"
"bif v3.16b, %[vzero].16b, v19.16b \n"
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v4.16b, %[vzero].16b, v18.16b \n"
"bif v5.16b, %[vzero].16b, v19.16b \n"
"bif v6.16b, %[vzero].16b, v18.16b \n"
"bif v7.16b, %[vzero].16b, v19.16b \n"
"fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v8.16b, %[vzero].16b, v18.16b \n"
"bif v9.16b, %[vzero].16b, v19.16b \n"
"bif v10.16b, %[vzero].16b, v18.16b \n"
"bif v11.16b, %[vzero].16b, v19.16b \n"
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
"ld1 {v18.4s}, [%[rmask]] \n"
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v12.16b, v22.16b, v18.16b \n"
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v13.16b, v23.16b, v18.16b \n"
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v14.16b, v24.16b, v18.16b \n"
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"bif v15.16b, v25.16b, v18.16b \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
// end
"0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
}
#else
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_val = flag_bias ? bias[i] : 0.f;
float* dout_channel = dout_batch + i * size_out_channel;
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* din0_ptr = nullptr;
const float* din1_ptr = nullptr;
const float* din2_ptr = nullptr;
const float* din3_ptr = nullptr;
float* doutr0 = nullptr;
float* doutr1 = nullptr;
float* ptr_zero = const_cast<float*>(zero);
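// 32-bit path: with fewer NEON registers, each iteration computes two
// output rows (vs. four on aarch64) from four input rows dr0..dr3.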
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
doutr0 = dout_channel;
doutr1 = dout_channel + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din1_ptr = zero_ptr;
case 2:
din2_ptr = zero_ptr;
case 1:
din3_ptr = zero_ptr;
case 0:
din3_ptr = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int* rmask_ptr = rmask;
unsigned int* vmask_ptr = vmask;
asm volatile(
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// mid
"1: @ right pad entry\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vdup.32 q5, %[bias_val] @ and \n" // q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point\n"
// right
"3: @ right pad entry\n"
"cmp %[remain], #1 @ check whether has "
"mid cols\n"
"blt 0f @ jump to main loop start "
"point\n"
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d31}, [%[vmask]]! @ load din r0\n"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d19}, [%[rmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[rmask]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vbif d8, d16, d19 @ bit select, deal with right pad\n"
"vbif d9, d17, d23 @ bit select, deal with right pad\n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad\n"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_channel += 2 * w_out;
} //! end of processing mid rows
}
#endif
}
}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 2, pad 0
 */
// w_in > 8
void conv_depthwise_3x3s2p0_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
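// Stride-2 trick: ld2/vld2 deinterleaves each 8-wide input chunk into even
// lanes {0,2,4,6} and odd lanes {1,3,5,7}, so one 3-tap kernel row becomes
// three vector multiplies: even * w[0], odd * w[1], shifted-even * w[2].
// dmask packs the two input masks plus the 4-lane output-write mask.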
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t wbias;
float bias_c = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_c = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 >= h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
asm volatile(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0: \n"
"prfm pldl1keep, [%[inptr0]] \n"
"prfm pldl1keep, [%[inptr1]] \n"
"prfm pldl1keep, [%[inptr2]] \n"
"prfm pldl1keep, [%[inptr3]] \n"
"prfm pldl1keep, [%[inptr4]] \n"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
// mid
"2: \n"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fadd v17.4s, v17.4s, v13.4s \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"subs %[cnt], %[cnt], #1 \n"
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 4f \n"
"3: \n"
"bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]] \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
"ld1 {v1.4s}, [%[outptr1]] \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"bif v16.16b, v0.16b, %[wmask].16b \n" // pipei
"fadd v17.4s, v17.4s, v13.4s \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"bif v17.16b, v1.16b, %[wmask].16b \n" // pipei
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
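// 32-bit path: one output row per iteration from three input rows
// dr0..dr2, advancing the row window by two rows each step.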
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i + 2 > h_in) {
switch (i + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"0: \n"
"vmov.u32 q9, #0 \n"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n"
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n"
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
// mid
"2: \n"
"vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"subs %[cnt], #1 \n"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14}
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 3f \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output\n"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vbif.f32 q3, q10, q11 @ write mask\n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"3: \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
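// The *_relu variants below repeat the kernels above with the activation
// fused in: an fmax/vmax against zero right before each store.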
// processes 4 output rows per iteration
void conv_depthwise_3x3s1p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remain_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remain_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
#ifdef __aarch64__
for (int c = 0; c < ch_in; c++) {
float* dout_ptr = dout_batch + c * size_out_channel;
const float* din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float* wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
// wr0 = vsetq_lane_f32(0.f, wr0, 3);
// wr1 = vsetq_lane_f32(0.f, wr1, 3);
// wr2 = vsetq_lane_f32(0.f, wr2, 3);
float* doutr0 = dout_ptr;
float* doutr1 = doutr0 + w_out;
float* doutr2 = doutr1 + w_out;
float* doutr3 = doutr2 + w_out;
const float* dr0 = din_ch_ptr;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* dr5 = dr4 + w_in;
const float* din_ptr0 = dr0;
const float* din_ptr1 = dr1;
const float* din_ptr2 = dr2;
const float* din_ptr3 = dr3;
const float* din_ptr4 = dr4;
const float* din_ptr5 = dr5;
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 5:
din_ptr1 = zero_ptr;
case 4:
din_ptr2 = zero_ptr;
case 3:
din_ptr3 = zero_ptr;
case 2:
din_ptr4 = zero_ptr;
case 1:
din_ptr5 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
"PRFM PLDL1KEEP, [%[din_ptr0]] \n"
"PRFM PLDL1KEEP, [%[din_ptr1]] \n"
"PRFM PLDL1KEEP, [%[din_ptr2]] \n"
"PRFM PLDL1KEEP, [%[din_ptr3]] \n"
"PRFM PLDL1KEEP, [%[din_ptr4]] \n"
"PRFM PLDL1KEEP, [%[din_ptr5]] \n"
"movi v21.4s, #0x0\n" /* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
// mid
"4: \n"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
// r4
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
// r5
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */
"subs %[cnt], %[cnt], #1 \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"bne 4b \n"
// right
"5: \n"
"cmp %[remain], #1 \n"
"blt 0f \n"
"ld1 {v18.4s, v19.4s}, [%[vmask]] \n"
"ld1 {v22.4s}, [%[doutr0]] \n"
"ld1 {v23.4s}, [%[doutr1]] \n"
"ld1 {v24.4s}, [%[doutr2]] \n"
"ld1 {v25.4s}, [%[doutr3]] \n"
"bif v0.16b, %[vzero].16b, v18.16b \n"
"bif v1.16b, %[vzero].16b, v19.16b \n"
"bif v2.16b, %[vzero].16b, v18.16b \n"
"bif v3.16b, %[vzero].16b, v19.16b \n"
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v4.16b, %[vzero].16b, v18.16b \n"
"bif v5.16b, %[vzero].16b, v19.16b \n"
"bif v6.16b, %[vzero].16b, v18.16b \n"
"bif v7.16b, %[vzero].16b, v19.16b \n"
"fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v8.16b, %[vzero].16b, v18.16b \n"
"bif v9.16b, %[vzero].16b, v19.16b \n"
"bif v10.16b, %[vzero].16b, v18.16b \n"
"bif v11.16b, %[vzero].16b, v19.16b \n"
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
"ld1 {v18.4s}, [%[rmask]] \n"
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v12.4s, v12.4s, %[vzero].4s \n" /* relu */
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v12.16b, v22.16b, v18.16b \n"
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
"st1 {v12.4s}, [%[doutr0]], #16 \n"
// r4
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v13.4s, v13.4s, %[vzero].4s \n" /* relu */
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v13.16b, v23.16b, v18.16b \n"
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v14.4s, v14.4s, %[vzero].4s \n" /* relu */
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v14.16b, v24.16b, v18.16b \n"
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmax v15.4s, v15.4s, %[vzero].4s \n" /* relu */
"bif v15.16b, v25.16b, v18.16b \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
// end
"0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
}
#else
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_val = flag_bias ? bias[i] : 0.f;
float* dout_channel = dout_batch + i * size_out_channel;
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* din0_ptr = nullptr;
const float* din1_ptr = nullptr;
const float* din2_ptr = nullptr;
const float* din3_ptr = nullptr;
float* doutr0 = nullptr;
float* doutr1 = nullptr;
float* ptr_zero = const_cast<float*>(zero);
for (int i = 0; i < h_out; i += 2) {
//! process top pad pad_h = 1
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
doutr0 = dout_channel;
doutr1 = dout_channel + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din1_ptr = zero_ptr;
case 2:
din2_ptr = zero_ptr;
case 1:
din3_ptr = zero_ptr;
case 0:
din3_ptr = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int* rmask_ptr = rmask;
unsigned int* vmask_ptr = vmask;
asm volatile(
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r1\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r2\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r3\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// mid
"1: @ right pad entry\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n"
"vmax.f32 q4, q4, %q[vzero] @ relu \n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
"vmax.f32 q5, q5, %q[vzero] @ relu \n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vdup.32 q5, %[bias_val] @ and \n" // q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point\n"
// right
"3: @ right pad entry\n"
"cmp %[remain], #1 @ check whether has "
"mid cols\n"
"blt 0f @ jump to main loop start "
"point\n"
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d31}, [%[vmask]]! @ load din r0\n"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d19}, [%[rmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[rmask]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vmax.f32 q4, q4, %q[vzero] @ relu \n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vbif d8, d16, d19 @ bit select, deal with right pad\n"
"vbif d9, d17, d23 @ bit select, deal with right pad\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmax.f32 q5, q5, %q[vzero] @ relu \n"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad\n"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_channel += 2 * w_out;
} //! end of processing mid rows
}
#endif
}
}
/**
 * \brief depthwise convolution kernel 3x3, stride 2, pad 0, with bias and relu
 */
//! for input width > 7
void conv_depthwise_3x3s2p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
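  // dmask layout: dmask[0..3] masks the even lanes {0,2,4,6} and dmask[4..7]
  // the odd lanes {1,3,5,7} of the last partial ld2 load; dmask[8..11]
  // (wmask) guards the final partial output store. E.g. w_in = 15 gives
  // w_out = 7, tile_w = 1, size_right_remain = 7 and cnt_remain = 3: all even
  // lanes and the first three odd lanes survive, and only the first three
  // lanes of the tail output vector are written back.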
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t wbias;
float bias_c = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_c = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 >= h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
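          // The asm below produces two output rows per iteration, four
          // columns at a time:
          //   out[r][c] = max(0, bias + sum_{m,n} w[m][n] * in[2r+m][2c+n])
          // ld2 deinterleaves each input row into even/odd column vectors;
          // v16 accumulates row r (from inptr0..2) and v17 accumulates
          // row r+1 (from inptr2..4).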
asm volatile(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0: \n"
"prfm pldl1keep, [%[inptr0]] \n"
"prfm pldl1keep, [%[inptr1]] \n"
"prfm pldl1keep, [%[inptr2]] \n"
"prfm pldl1keep, [%[inptr3]] \n"
"prfm pldl1keep, [%[inptr4]] \n"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
// mid
"2: \n"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */
"fadd v17.4s, v17.4s, v13.4s \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */
"subs %[cnt], %[cnt], #1 \n"
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 4f \n"
"3: \n"
"bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]] \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
"ld1 {v1.4s}, [%[outptr1]] \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */
"fadd v17.4s, v17.4s, v13.4s \n"
"bif v16.16b, v0.16b, %[wmask].16b \n" // pipei
"fadd v17.4s, v17.4s, v14.4s \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */
"bif v17.16b, v1.16b, %[wmask].16b \n" // pipei
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i + 2 > h_in) {
switch (i + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
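        // armv7 path: one output row per iteration. q4/q5/q3 accumulate the
        // even-lane, odd-lane and shifted-lane products of the three input
        // rows and are summed (then relu'd) before the store.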
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"0: \n"
"vmov.u32 q9, #0 \n"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n"
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n"
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
// mid
"2: \n"
"vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"subs %[cnt], #1 \n"
"vmax.f32 q3, q3, q9 @ relu \n"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14}
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 3f \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output\n"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vmax.f32 q3, q3, q9 @ relu \n"
"vbif.f32 q3, q10, q11 @ write mask\n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"3: \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width <= 4
 */
void conv_depthwise_3x3s1p0_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
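  // Small-width path: each row is read as six floats (a quad load plus a
  // half load) and vmask zeroes the lanes at and past w_in. right_pad_idx =
  // {5,4,3,2,1,0,...} compared with 6 - w_in keeps exactly the first w_in
  // lanes; e.g. w_in = 4 keeps lanes 0..3 and zeroes lanes 4..5.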
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float* dout_channel = dout_batch + i * size_out_channel;
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float* doutr0 = dout_channel;
float* doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float* dr0 = din_channel + j * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
doutr1 = trash_buf;
default:
break;
}
}
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[din0]]\n"
"prfm pldl1keep, [%[din1]]\n"
"prfm pldl1keep, [%[din2]]\n"
"prfm pldl1keep, [%[din3]]\n"
"ld1 {v0.4s, v1.4s}, [%[din0]]\n"
"ld1 {v2.4s, v3.4s}, [%[din1]]\n"
"ld1 {v4.4s, v5.4s}, [%[din2]]\n"
"ld1 {v6.4s, v7.4s}, [%[din3]]\n"
"bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234
"bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234
"bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234
"bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234
"bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234
"bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234
"bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234
"bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234
"ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345
"ext v9.16b, v0.16b, v1.16b, #8\n" // d1_3450
"and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias
"and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias
// r0
"fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0]
"fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1]
"fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2]
"ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345
"ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450
// r1
"fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0]
"fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0]
"fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2]
"ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345
"ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450
// r2
"fmla v14.4s, v4.4s, %[wr1].s[0]\n" // d0_1234 * w0[0]
"fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2]
"ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345
"ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450
// r3
"fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1]
"fadd v12.4s, v12.4s, v10.4s\n"
"fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2]
"fadd v12.4s, v12.4s, v11.4s\n" // out1
"fadd v13.4s, v13.4s, v14.4s\n" // out2
"fadd v13.4s, v13.4s, v15.4s\n" // out2
"prfm pldl1keep, [%[out1]]\n"
"prfm pldl1keep, [%[out2]]\n"
"st1 {v12.4s}, [%[out1]]\n"
"st1 {v13.4s}, [%[out2]]\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
unsigned int* vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
asm volatile(
"pld [%[din0]]\n"
"pld [%[din1]]\n"
"pld [%[din2]]\n"
"pld [%[din3]]\n"
"vld1.32 {d16-d18}, [%[din0]] @ load din r0\n"
"vld1.32 {d20-d22}, [%[din1]] @ load din r1\n"
"vld1.32 {d24-d26}, [%[din2]] @ load din r2\n"
"vld1.32 {d28-d30}, [%[din3]] @ load din r3\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vadd.f32 q4, q4, q10 @ q4 += q10 \n"
"pld [%[out1]]\n"
"pld [%[out2]]\n"
"vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vadd.f32 q4, q4, q11 @ q4 += q10 \n"
"vadd.f32 q5, q5, q8 @ q4 += q10 \n"
"vadd.f32 q5, q5, q9 @ q4 += q10 \n"
"vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
} // end of processing heights
} // end of processing channels
  } // end of processing batches
}
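// Hedged reference sketch (not part of the original file): a scalar version
// of the pad-0 depthwise convolution these NEON kernels implement. The name
// and signature are illustrative assumptions, useful only for cross-checking
// kernel output.
// static void depthwise_3x3p0_ref(float* dout, const float* din,
//                                 const float* w9, float bias, bool relu,
//                                 int stride, int h_in, int w_in,
//                                 int h_out, int w_out) {
//   for (int oh = 0; oh < h_out; ++oh) {
//     for (int ow = 0; ow < w_out; ++ow) {
//       float sum = bias;
//       for (int kh = 0; kh < 3; ++kh) {
//         for (int kw = 0; kw < 3; ++kw) {
//           sum += w9[kh * 3 + kw] *
//                  din[(oh * stride + kh) * w_in + (ow * stride + kw)];
//         }
//       }
//       dout[oh * w_out + ow] = relu ? (sum > 0.f ? sum : 0.f) : sum;
//     }
//   }
// }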
/**
 * \brief depthwise convolution kernel 3x3, stride 2, pad 0, with bias,
 * width <= 4
 */
void conv_depthwise_3x3s2p0_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
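  // Same even/odd masking idea as in conv_depthwise_3x3s2p0_bias_relu above,
  // but compared against w_in directly: the single 8-wide load covers the
  // whole row, so input lane k survives only when k < w_in.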
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
for (int j = 0; j < h_out; ++j) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
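        // One pass over dr0..dr2 yields the single (<= 4-wide) output row:
        //   out_buf[c] = bias + sum_{m,n} w[3*m + n] * in[m][2*c + n]
        // with the masked loads supplying zeros past w_in.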
#ifdef __aarch64__
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"movi v9.4s, #0 \n"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6}
// v15={1,3,5,7}
"and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias
"bif v10.16b, v9.16b, v6.16b \n"
"bif v11.16b, v9.16b, v7.16b \n"
"bif v12.16b, v9.16b, v6.16b \n"
"bif v13.16b, v9.16b, v7.16b \n"
"bif v14.16b, v9.16b, v6.16b \n"
"bif v15.16b, v9.16b, v7.16b \n"
"ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 =
// {2,4,6,8}
"ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 =
// {2,4,6,8}
"ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 =
// {2,4,6,8}
"fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00
"fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01
"fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02
"fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12
"fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21
"fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22
"fadd v4.4s, v4.4s, v5.4s \n"
"fadd v4.4s, v4.4s, v16.4s \n"
// "fadd v4.4s, v4.4s, %[bias].4s \n"
"st1 {v4.4s}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"vmov.u32 q9, #0 \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n"
"vdup.32 q3, %[bias] @ and \n" // q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0}
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0}
"vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // {0,2,4,6}
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // {1,3,5,7}
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // {2,4,6,0}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q12 * w11
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q13 * w12
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q7 * w10
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q14 * w20
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q15 * w21
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, "
"out0\n" // q8 * w22
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vst1.32 {d6-d7}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
}
}
}
}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias
 * and relu, width <= 4
 */
void conv_depthwise_3x3s1p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float* dout_channel = dout_batch + i * size_out_channel;
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float* doutr0 = dout_channel;
float* doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float* dr0 = din_channel + j * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
doutr1 = trash_buf;
default:
break;
}
}
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[din0]]\n"
"prfm pldl1keep, [%[din1]]\n"
"prfm pldl1keep, [%[din2]]\n"
"prfm pldl1keep, [%[din3]]\n"
"ld1 {v0.4s, v1.4s}, [%[din0]]\n"
"ld1 {v2.4s, v3.4s}, [%[din1]]\n"
"ld1 {v4.4s, v5.4s}, [%[din2]]\n"
"ld1 {v6.4s, v7.4s}, [%[din3]]\n"
"bif v0.16b, %[zero].16b, %[mask1].16b\n" // d0_1234
"bif v1.16b, %[zero].16b, %[mask2].16b\n" // d0_1234
"bif v2.16b, %[zero].16b, %[mask1].16b\n" // d1_1234
"bif v3.16b, %[zero].16b, %[mask2].16b\n" // d1_1234
"bif v4.16b, %[zero].16b, %[mask1].16b\n" // d2_1234
"bif v5.16b, %[zero].16b, %[mask2].16b\n" // d2_1234
"bif v6.16b, %[zero].16b, %[mask1].16b\n" // d3_1234
"bif v7.16b, %[zero].16b, %[mask2].16b\n" // d3_1234
"ext v8.16b, v0.16b, v1.16b, #4\n" // d1_2345
"ext v9.16b, v0.16b, v1.16b, #8\n" // d1_3450
"and v12.16b, %[vbias].16b, %[vbias].16b \n" // v12 = vbias
"and v13.16b, %[vbias].16b, %[vbias].16b \n" // v13 = vbias
// r0
"fmul v10.4s, v0.4s, %[wr0].s[0]\n" // d0_1234 * w0[0]
"fmul v11.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1]
"fmla v12.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2]
"ext v8.16b, v2.16b, v3.16b, #4\n" // d1_2345
"ext v9.16b, v2.16b, v3.16b, #8\n" // d1_3450
// r1
"fmul v14.4s, v2.4s, %[wr0].s[0]\n" // d0_1234 * w0[0]
"fmla v10.4s, v2.4s, %[wr1].s[0]\n" // d0_1234 * w0[0]
"fmul v15.4s, v8.4s, %[wr0].s[1]\n" // d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr0].s[2]\n" // d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2]
"ext v8.16b, v4.16b, v5.16b, #4\n" // d1_2345
"ext v9.16b, v4.16b, v5.16b, #8\n" // d1_3450
// r2
"fmla v14.4s, v4.4s, %[wr1].s[0]\n" // d0_1234 * w0[0]
"fmla v10.4s, v4.4s, %[wr2].s[0]\n" // d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr1].s[1]\n" // d1_2345 * w0[1]
"fmla v11.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1]
"fmla v13.4s, v9.4s, %[wr1].s[2]\n" // d0_3456 * w0[2]
"fmla v12.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2]
"ext v8.16b, v6.16b, v7.16b, #4\n" // d1_2345
"ext v9.16b, v6.16b, v7.16b, #8\n" // d1_3450
// r3
"fmla v14.4s, v6.4s, %[wr2].s[0]\n" // d0_1234 * w0[0]
"fmla v15.4s, v8.4s, %[wr2].s[1]\n" // d1_2345 * w0[1]
"fadd v12.4s, v12.4s, v10.4s\n"
"fmla v13.4s, v9.4s, %[wr2].s[2]\n" // d0_3456 * w0[2]
"fadd v12.4s, v12.4s, v11.4s\n" // out1
"fadd v13.4s, v13.4s, v14.4s\n" // out2
"fadd v13.4s, v13.4s, v15.4s\n" // out2
"prfm pldl1keep, [%[out1]]\n"
"prfm pldl1keep, [%[out2]]\n"
"fmax v12.4s, v12.4s, %[zero].4s \n"
"fmax v13.4s, v13.4s, %[zero].4s \n"
"st1 {v12.4s}, [%[out1]]\n"
"st1 {v13.4s}, [%[out2]]\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
unsigned int* vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
asm volatile(
"pld [%[din0]]\n"
"pld [%[din1]]\n"
"pld [%[din2]]\n"
"pld [%[din3]]\n"
"vld1.32 {d16-d18}, [%[din0]] @ load din r0\n"
"vld1.32 {d20-d22}, [%[din1]] @ load din r1\n"
"vld1.32 {d24-d26}, [%[din2]] @ load din r2\n"
"vld1.32 {d28-d30}, [%[din3]] @ load din r3\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vadd.f32 q4, q4, q10 @ q4 += q10 \n"
"pld [%[out1]]\n"
"pld [%[out2]]\n"
"vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vadd.f32 q4, q4, q11 @ q4 += q10 \n"
"vadd.f32 q5, q5, q8 @ q4 += q10 \n"
"vadd.f32 q5, q5, q9 @ q4 += q10 \n"
"vmax.f32 q4, q4, %q[vzero] @ relu \n"
"vmax.f32 q5, q5, %q[vzero] @ relu \n"
"vst1.32 {d8-d9}, [%[out1]] @ store result, add pointer\n"
"vst1.32 {d10-d11}, [%[out2]] @ store result, add pointer\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
} // end of processing heights
} // end of processing channels
  } // end of processing batches
}
/**
 * \brief depthwise convolution kernel 3x3, stride 2, pad 0, with bias and
 * relu, width <= 7
 */
void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
for (int j = 0; j < h_out; ++j) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"movi v9.4s, #0 \n"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]] \n"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6}
// v15={1,3,5,7}
"and v4.16b, %[bias].16b, %[bias].16b \n" // v10 = vbias
"bif v10.16b, v9.16b, v6.16b \n"
"bif v11.16b, v9.16b, v7.16b \n"
"bif v12.16b, v9.16b, v6.16b \n"
"bif v13.16b, v9.16b, v7.16b \n"
"bif v14.16b, v9.16b, v6.16b \n"
"bif v15.16b, v9.16b, v7.16b \n"
"ext v6.16b, v10.16b, v9.16b, #4 \n" // v6 =
// {2,4,6,8}
"ext v7.16b, v12.16b, v9.16b, #4 \n" // v6 =
// {2,4,6,8}
"ext v8.16b, v14.16b, v9.16b, #4 \n" // v6 =
// {2,4,6,8}
"fmla v4.4s, v10.4s, %[wr0].s[0] \n" // 0246 * w00
"fmul v5.4s, v11.4s, %[wr0].s[1] \n" // 1357 * w01
"fmul v16.4s, v6.4s, %[wr0].s[2] \n" // 2468 * w02
"fmla v4.4s, v12.4s, %[wr1].s[0] \n" // v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[1] \n" // v13 * w12
"fmla v16.4s, v7.4s, %[wr1].s[2] \n" // v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[0] \n" // v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[1] \n" // v15 * w21
"fmla v16.4s, v8.4s, %[wr2].s[2] \n" // v8 * w22
"fadd v4.4s, v4.4s, v5.4s \n"
"fadd v4.4s, v4.4s, v16.4s \n"
"fmax v4.4s, v4.4s, v9.4s \n"
// "fadd v4.4s, v4.4s, %[bias].4s \n"
"st1 {v4.4s}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf),
[mask_ptr] "r"(mask_ptr)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(
// Load up 12 elements (3 vectors) from each of 8 sources.
"vmov.u32 q9, #0 \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n"
"vdup.32 q3, %[bias] @ and \n" // q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,0}
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q7 = {2,4,6,0}
"vext.32 q8, q14, q9, #1 @ shift left 1 \n" // q8 = {2,4,6,0}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // {0,2,4,6}
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // {1,3,5,7}
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // {2,4,6,0}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q12 * w11
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q13 * w12
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q7 * w10
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q14 * w20
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q15 * w21
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, "
"out0\n" // q8 * w22
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vmax.f32 q3, q3, q9 @ relu \n"
"vst1.32 {d6-d7}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(mask_ptr)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_depthwise.h"
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_depthwise_3x3s1p1_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s1p1_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s2p1_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s1p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
//! for input width <= 4
void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3p1_fp32(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int stride,
bool flag_bias,
bool flag_relu,
ARMContext* ctx) {
if (stride == 1) {
if (flag_relu) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
} else { //! stride = 2
if (flag_relu) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
}
}
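// Illustrative dispatch example (values assumed, not from this file): a 3x3
// stride-1 depthwise layer on a 1x8x32x32 NCHW tensor with fused bias + relu
// takes the w_in > 4 branch and runs conv_depthwise_3x3s1p1_bias_relu:
//   conv_depthwise_3x3p1_fp32(din, dout, 1, 8, 32, 32, 8, 32, 32,
//                             weights, bias, /*stride=*/1,
//                             /*flag_bias=*/true, /*flag_relu=*/true, ctx);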
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
//! processes 4 output rows per loop iteration
void conv_depthwise_3x3s1p1_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
// printf("conv3x3_dw start \n");
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = (w_in + 3) >> 2;
int cnt_col = tile_w - 2;
unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in);
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
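  // Tiling for pad = 1: tile_w = ceil(w_in / 4) blocks of 4 outputs, split
  // into one left block (implicit left zero), cnt_col = tile_w - 2 main
  // blocks, and one right block whose loads/stores are guarded by
  // vmask/rmask. size_pad_right = 1 + 4 * tile_w - w_in counts the tail
  // lanes to zero; e.g. w_in = 10 gives tile_w = 3, cnt_col = 1 and
  // size_pad_right = 3.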
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
#ifdef __aarch64__
for (int c = 0; c < ch_in; c++) {
float* dout_ptr = dout_batch + c * size_out_channel;
const float* din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float* wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float* doutr0 = dout_ptr;
float* doutr1 = doutr0 + w_out;
float* doutr2 = doutr1 + w_out;
float* doutr3 = doutr2 + w_out;
const float* dr0 = din_ch_ptr;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* dr5 = dr4 + w_in;
const float* din_ptr0 = dr0;
const float* din_ptr1 = dr1;
const float* din_ptr2 = dr2;
const float* din_ptr3 = dr3;
const float* din_ptr4 = dr4;
const float* din_ptr5 = dr5;
for (int i = 0; i < h_in; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
if (i == 0) {
din_ptr0 = zero_ptr;
din_ptr1 = dr0;
din_ptr2 = dr1;
din_ptr3 = dr2;
din_ptr4 = dr3;
din_ptr5 = dr4;
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
} else {
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
}
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 > h_in) {
switch (i + 5 - h_in) {
case 5:
din_ptr1 = zero_ptr;
case 4:
din_ptr2 = zero_ptr;
case 3:
din_ptr3 = zero_ptr;
case 2:
din_ptr4 = zero_ptr;
case 1:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
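        // Each iteration consumes six input rows (din_ptr0..5, with zero_ptr
        // standing in for the implicit top/bottom pad rows) and emits four
        // output rows; rows past h_out are redirected to write_ptr, a scratch
        // line whose contents are discarded.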
int cnt = cnt_col;
asm volatile(
"PRFM PLDL1KEEP, [%[din_ptr0]] \n"
"PRFM PLDL1KEEP, [%[din_ptr1]] \n"
"PRFM PLDL1KEEP, [%[din_ptr2]] \n"
"PRFM PLDL1KEEP, [%[din_ptr3]] \n"
"PRFM PLDL1KEEP, [%[din_ptr4]] \n"
"PRFM PLDL1KEEP, [%[din_ptr5]] \n"
"movi v21.4s, #0x0\n" /* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
// left
// r0
"fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 *
w0[1]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */
"sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */
"fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 *
w0[0]*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */
"sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 *
w0[2]*/
"ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 *
w1[1]*/
"sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */
"sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */
"fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 *
w1[1]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w2[1]*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w1[1]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w2[1]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w1[1]*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */
// r4
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w2[1]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w1[1]*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */
// r5
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */
"cmp %[cnt], #1 \n"
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"blt 3f \n"
// mid
"1: \n"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"subs %[cnt], %[cnt], #1 \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"bne 1b \n"
// right
"3: \n"
"ld1 {v18.4s, v19.4s}, [%[vmask]] \n"
"ld1 {v22.4s}, [%[doutr0]] \n"
"ld1 {v23.4s}, [%[doutr1]] \n"
"ld1 {v24.4s}, [%[doutr2]] \n"
"ld1 {v25.4s}, [%[doutr3]] \n"
"bif v0.16b, %[vzero].16b, v18.16b \n"
"bif v1.16b, %[vzero].16b, v19.16b \n"
"bif v2.16b, %[vzero].16b, v18.16b \n"
"bif v3.16b, %[vzero].16b, v19.16b \n"
"bif v4.16b, %[vzero].16b, v18.16b \n"
"bif v5.16b, %[vzero].16b, v19.16b \n"
"bif v6.16b, %[vzero].16b, v18.16b \n"
"bif v7.16b, %[vzero].16b, v19.16b \n"
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v8.16b, %[vzero].16b, v18.16b \n"
"bif v9.16b, %[vzero].16b, v19.16b \n"
"bif v10.16b, %[vzero].16b, v18.16b \n"
"bif v11.16b, %[vzero].16b, v19.16b \n"
"fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v18.4s}, [%[rmask]] \n"
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v12.16b, v22.16b, v18.16b \n"
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v13.16b, v23.16b, v18.16b \n"
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v14.16b, v24.16b, v18.16b \n"
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"bif v15.16b, v25.16b, v18.16b \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
}
#else
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_val = flag_bias ? bias[i] : 0.f;
float* dout_channel = dout_batch + i * size_out_channel;
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* din0_ptr = nullptr;
const float* din1_ptr = nullptr;
const float* din2_ptr = nullptr;
const float* din3_ptr = nullptr;
float* doutr0 = nullptr;
float* doutr1 = nullptr;
float* ptr_zero = const_cast<float*>(zero);
for (int h = 0; h < h_in; h += 2) {
//! process top pad pad_h = 1
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
doutr0 = dout_channel;
doutr1 = dout_channel + w_out;
if (h == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
dr2 = dr3;
dr3 = dr2 + w_in;
} else {
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
}
//! process bottom pad
if (h + 3 > h_in) {
switch (h + 3 - h_in) {
case 3:
din1_ptr = zero_ptr;
case 2:
din2_ptr = zero_ptr;
case 1:
din3_ptr = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (h + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = cnt_col;
unsigned int* rmask_ptr = rmask;
unsigned int* vmask_ptr = vmask;
asm volatile(
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n"
"vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n"
"vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n"
"vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"vext.32 q6, %q[vzero], q8, #3 @ 0012\n"
"vext.32 q7, q8, q9, #1 @ 1234\n"
// left
// r0
"vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n"
"sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n"
"sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n"
"sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n"
"vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, %q[vzero], q10, #3 @ 0012\n"
"vext.32 q7, q10, q11, #1 @ 1234\n"
// r1
"vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, %q[vzero], q12, #3 @ 0012\n"
"vext.32 q7, q12, q13, #1 @ 1234\n"
// r2
"vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, %q[vzero], q14, #3 @ 0012\n"
"vext.32 q7, q14, q15, #1 @ 1234\n"
// r3
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
"cmp %[cnt], #1 @ check whether has "
"mid cols\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"blt 3f @ jump to main loop start "
"point\n"
// mid
"1: @ right pad entry\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vdup.32 q5, %[bias_val] @ and \n" // q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point\n"
// right
"3: @ right pad entry\n"
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d31}, [%[vmask]]! @ load din r0\n"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d19}, [%[rmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[rmask]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vbif d8, d16, d19 @ bit select, deal with right pad\n"
"vbif d9, d17, d23 @ bit select, deal with right pad\n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad\n"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_channel += 2 * w_out;
} //! end of processing mid rows
}
#endif
}
}
/**
 * \brief depthwise convolution kernel 3x3, stride 2, pad 1
 * requires w_in > 7
 */
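// For orientation, a minimal scalar sketch of what this kernel computes for
// one channel (assuming pad = 1 with implicit zero padding; din, weights,
// dout and the loop bounds are illustrative names, not locals of this file):
//
//   for (int oh = 0; oh < h_out; ++oh) {
//     for (int ow = 0; ow < w_out; ++ow) {
//       float sum = flag_bias ? bias[c] : 0.f;
//       for (int kh = 0; kh < 3; ++kh) {
//         for (int kw = 0; kw < 3; ++kw) {
//           int ih = oh * 2 - 1 + kh;  // stride 2, top/left pad 1
//           int iw = ow * 2 - 1 + kw;
//           if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
//             sum += din[ih * w_in + iw] * weights[kh * 3 + kw];
//           }
//         }
//       }
//       dout[oh * w_out + ow] = sum;
//     }
//   }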
void conv_depthwise_3x3s2p1_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
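// column tiling: the left block produces 4 outputs while advancing the
// input pointers by 7 columns (its leftmost tap is the implicit pad and
// the next block re-reads one overlapping column); each main-loop
// iteration then consumes 8 columns for 4 more outputs, and the masked
// right block covers the remainder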
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
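// dmask layout: lanes 0-3 and 4-7 mask the even/odd input columns of the
// final partial block, lanes 8-11 (wmask) mask which of its 4 outputs are
// actually written back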
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t wbias;
float bias_c = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_c = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
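// wbias feeds the aarch64 asm as a vector operand; bias_c feeds the
// armv7 asm as a scalar that is vdup'ed inside the kernel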
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int h = 0; h < h_in; h += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (h == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
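// stride 2: each iteration reads 5 input rows (4 new plus 1 overlap) and
// emits 2 output rows, so the row window advances by 4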
//! process bottom pad
if (h + 4 > h_in) {
switch (h + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (h / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0: \n"
"prfm pldl1keep, [%[inptr0]] \n"
"prfm pldl1keep, [%[inptr1]] \n"
"prfm pldl1keep, [%[inptr2]] \n"
"prfm pldl1keep, [%[inptr3]] \n"
"prfm pldl1keep, [%[inptr4]] \n"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5}
// r0
"fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01
"fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5}
"sub %[inptr0], %[inptr0], #4 \n"
"sub %[inptr1], %[inptr1], #4 \n"
// r1
"fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01
"fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5}
"sub %[inptr2], %[inptr2], #4 \n"
"sub %[inptr3], %[inptr3], #4 \n"
// r2
"fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01
"fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01
"fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02
"fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00
"fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5}
"sub %[inptr4], %[inptr4], #4 \n"
// r3
"fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01
"fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5}
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01
"fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
"fadd v17.4s, v17.4s, v13.4s \n"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"ld1 {v15.4s}, [%[inptr0]] \n"
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"fadd v17.4s, v17.4s, v14.4s \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"cmp %[cnt], #1 \n"
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"blt 1f \n"
// mid
"2: \n"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fadd v17.4s, v17.4s, v13.4s \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"subs %[cnt], %[cnt], #1 \n"
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 4f \n"
"3: \n"
"bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]] \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
"ld1 {v1.4s}, [%[outptr1]] \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"bif v16.16b, v0.16b, %[wmask].16b \n" // pipei
"fadd v17.4s, v17.4s, v13.4s \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"bif v17.16b, v1.16b, %[wmask].16b \n" // pipei
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
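// armv7 path: each iteration reads 3 input rows (2 new plus 1 overlap)
// and emits 1 output row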
for (int h = 0; h < h_in; h += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (h == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (h + 2 > h_in) {
switch (h + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(
// top
// Load up 12 elements (3 vectors) from each of 8 sources.
"0: \n"
"vmov.u32 q9, #0 \n"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"vext.32 q6, q9, q11, #3 @ shift right 1 "
"data\n" // q2 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift right 1 "
"data\n" // q6 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift right 1 "
"data\n" // q6 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, "
"out0\n" // q11 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, "
"out0\n" // q12 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, "
"out0\n" // q6 * w00
"sub %[din0_ptr], #4 @ inpitr0 - 1\n"
"sub %[din1_ptr], #4 @ inpitr1 - 1\n"
"sub %[din2_ptr], #4 @ inpitr2 - 1\n"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0\n" // q11 * w01
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0\n" // q12 * w02
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w00
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, "
"out1\n" // q0 * w01
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, "
"out1\n" // q1 * w02
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, "
"out1\n" // q2 * w00
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"cmp %[cnt], #1 \n"
"blt 1f \n"
// mid
"2: \n"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"subs %[cnt], #1 \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 3f \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output\n"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vbif.f32 q3, q10, q11 @ write mask\n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"3: \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
//! 3x3 stride-1 pad-1 depthwise kernel with fused relu; the aarch64 path
//! computes 4 output rows per iteration
void conv_depthwise_3x3s1p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
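// the descending indices above are compared against size_pad_right below:
// vcgeq (>=) builds the masks that zero out-of-range input lanes, while
// the stricter vcgtq (>) picks the output lanes that are written back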
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = (w_in + 3) >> 2;
int tile_h = (h_in + 3) >> 2;
int cnt_col = tile_w - 2;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in);
int size_pad_bottom = 1 + (tile_h << 2) - h_in;
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
#ifdef __aarch64__
for (int c = 0; c < ch_in; c++) {
float* dout_ptr = dout_batch + c * size_out_channel;
const float* din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float* wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float* doutr0 = dout_ptr;
float* doutr1 = doutr0 + w_out;
float* doutr2 = doutr1 + w_out;
float* doutr3 = doutr2 + w_out;
const float* dr0 = din_ch_ptr;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* dr5 = dr4 + w_in;
const float* din_ptr0 = dr0;
const float* din_ptr1 = dr1;
const float* din_ptr2 = dr2;
const float* din_ptr3 = dr3;
const float* din_ptr4 = dr4;
const float* din_ptr5 = dr5;
for (int i = 0; i < h_in; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
if (i == 0) {
din_ptr0 = zero_ptr;
din_ptr1 = dr0;
din_ptr2 = dr1;
din_ptr3 = dr2;
din_ptr4 = dr3;
din_ptr5 = dr4;
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
} else {
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
}
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 > h_in) {
switch (i + 5 - h_in) {
case 5:
din_ptr1 = zero_ptr;
case 4:
din_ptr2 = zero_ptr;
case 3:
din_ptr3 = zero_ptr;
case 2:
din_ptr4 = zero_ptr;
case 1:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = cnt_col;
asm volatile(
"PRFM PLDL1KEEP, [%[din_ptr0]] \n"
"PRFM PLDL1KEEP, [%[din_ptr1]] \n"
"PRFM PLDL1KEEP, [%[din_ptr2]] \n"
"PRFM PLDL1KEEP, [%[din_ptr3]] \n"
"PRFM PLDL1KEEP, [%[din_ptr4]] \n"
"PRFM PLDL1KEEP, [%[din_ptr5]] \n"
"movi v21.4s, #0x0\n" /* out0 = 0 */
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
// left
// r0
"fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 *
w0[1]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */
"sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */
"fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 *
w0[0]*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */
"sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 *
w0[2]*/
"ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 *
w1[1]*/
"sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */
"sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */
"fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 *
w1[1]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w2[1]*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w1[1]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w2[1]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w1[1]*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */
// r4
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w2[1]*/
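/* fused relu: clamp the finished accumulators at zero before storing */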
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w1[1]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w1[1]*/
"ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/
"ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
// r5
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 *
w1[1]*/
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 *
w0[1]*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */
"cmp %[cnt], #1 \n"
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"blt 3f \n"
// mid
"1: \n"
// r0
"fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v13.4s}, [%[doutr1]], #16 \n"
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
"subs %[cnt], %[cnt], #1 \n"
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/
"st1 {v15.4s}, [%[doutr3]], #16 \n"
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
"bne 1b \n"
// right
"3: \n"
"ld1 {v18.4s, v19.4s}, [%[vmask]] \n"
"ld1 {v22.4s}, [%[doutr0]] \n"
"ld1 {v23.4s}, [%[doutr1]] \n"
"ld1 {v24.4s}, [%[doutr2]] \n"
"ld1 {v25.4s}, [%[doutr3]] \n"
"bif v0.16b, %[vzero].16b, v18.16b \n"
"bif v1.16b, %[vzero].16b, v19.16b \n"
"bif v2.16b, %[vzero].16b, v18.16b \n"
"bif v3.16b, %[vzero].16b, v19.16b \n"
"bif v4.16b, %[vzero].16b, v18.16b \n"
"bif v5.16b, %[vzero].16b, v19.16b \n"
"bif v6.16b, %[vzero].16b, v18.16b \n"
"bif v7.16b, %[vzero].16b, v19.16b \n"
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */
// r0
"fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"bif v8.16b, %[vzero].16b, v18.16b \n"
"bif v9.16b, %[vzero].16b, v19.16b \n"
"bif v10.16b, %[vzero].16b, v18.16b \n"
"bif v11.16b, %[vzero].16b, v19.16b \n"
"fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"ld1 {v18.4s}, [%[rmask]] \n"
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */
// r1
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */
// r2
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v12.16b, v22.16b, v18.16b \n"
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */
// r3
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"st1 {v12.4s}, [%[doutr0]], #16 \n"
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v13.16b, v23.16b, v18.16b \n"
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */
"st1 {v13.4s}, [%[doutr1]], #16 \n"
// r3
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 *
w0[0]*/
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 *
w0[1]*/
"bif v14.16b, v24.16b, v18.16b \n"
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 *
w0[2]*/
"st1 {v14.4s}, [%[doutr2]], #16 \n"
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/
"bif v15.16b, v25.16b, v18.16b \n"
"st1 {v15.4s}, [%[doutr3]], #16 \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
}
#else
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_val = flag_bias ? bias[i] : 0.f;
float* dout_channel = dout_batch + i * size_out_channel;
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* din0_ptr = nullptr;
const float* din1_ptr = nullptr;
const float* din2_ptr = nullptr;
const float* din3_ptr = nullptr;
float* doutr0 = nullptr;
float* doutr1 = nullptr;
float* ptr_zero = const_cast<float*>(zero);
for (int i = 0; i < h_in; i += 2) {
//! process top pad pad_h = 1
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
doutr0 = dout_channel;
doutr1 = dout_channel + w_out;
// unsigned int* rst_mask = rmask;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
dr2 = dr3;
dr3 = dr2 + w_in;
} else {
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
}
//! process bottom pad
if (i + 3 > h_in) {
switch (i + 3 - h_in) {
case 3:
din1_ptr = zero_ptr;
case 2:
din2_ptr = zero_ptr;
case 1:
din3_ptr = zero_ptr;
default:
break;
}
}
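//! note: the missing `break`s in the switch above are intentional; each
//! case falls through so that every row below the input bottom is
//! redirected to the zero row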
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = cnt_col;
unsigned int* rmask_ptr = rmask;
unsigned int* vmask_ptr = vmask;
asm volatile(
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n"
"vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n"
"vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n"
"vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"vext.32 q6, %q[vzero], q8, #3 @ 0012\n"
"vext.32 q7, q8, q9, #1 @ 1234\n"
// left
// r0
"vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n"
"sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n"
"sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n"
"sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n"
"vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, %q[vzero], q10, #3 @ 0012\n"
"vext.32 q7, q10, q11, #1 @ 1234\n"
// r1
"vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, %q[vzero], q12, #3 @ 0012\n"
"vext.32 q7, q12, q13, #1 @ 1234\n"
// r2
"vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, %q[vzero], q14, #3 @ 0012\n"
"vext.32 q7, q14, q15, #1 @ 1234\n"
// r3
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n"
"vmax.f32 q4, q4, %q[vzero] @ relu \n"
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vmax.f32 q5, q5, %q[vzero] @ relu \n"
"cmp %[cnt], #1 @ check whether has "
"mid cols\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"vdup.32 q5, %[bias_val] @ and \n" // q5
// =
// vbias
"blt 3f @ jump to main loop start "
"point\n"
// mid
"1: @ right pad entry\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"pld [%[din3_ptr]] @ preload data\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n"
"vmax.f32 q4, q4, %q[vzero] @ relu \n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
"vdup.32 q4, %[bias_val] @ and \n" // q4
// =
// vbias
"vmax.f32 q5, q5, %q[vzero] @ relu \n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"vdup.32 q5, %[bias_val] @ and \n" // q4
// =
// vbias
"bne 1b @ jump to main loop start "
"point\n"
// right
"3: @ right pad entry\n"
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n"
"vld1.32 {d31}, [%[vmask]]! @ load din r0\n"
"vbif d16, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d17, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d18, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vbif d20, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d21, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d22, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vext.32 q6, q8, q9, #1 @ 1234\n"
"vext.32 q7, q8, q9, #2 @ 2345\n"
// r0
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n"
"vbif d24, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d25, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d26, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vbif d28, %e[vzero], d19 @ bit select, deal with "
"right pad\n"
"vbif d29, %e[vzero], d23 @ bit select, deal with "
"right pad\n"
"vbif d30, %e[vzero], d27 @ bit select, deal with "
"right pad\n"
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n"
"vext.32 q6, q10, q11, #1 @ 1234\n"
"vext.32 q7, q10, q11, #2 @ 2345\n"
// r1
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d19}, [%[rmask]]! @ load din r0\n"
"vld1.32 {d23}, [%[rmask]]! @ load din r0\n"
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n"
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n"
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q12, q13, #1 @ 1234\n"
"vext.32 q7, q12, q13, #2 @ 2345\n"
// r2
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n"
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n"
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n"
"vext.32 q6, q14, q15, #1 @ 1234\n"
"vext.32 q7, q14, q15, #2 @ 2345\n"
// r3
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n"
"vmax.f32 q4, q4, %q[vzero] @ relu \n"
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n"
"vbif d8, d16, d19 @ bit select, deal with right pad\n"
"vbif d9, d17, d23 @ bit select, deal with right pad\n"
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n"
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n"
"vmax.f32 q5, q5, %q[vzero] @ relu \n"
"vbif d10, d20, d19 @ bit select, deal with right "
"pad\n"
"vbif d11, d21, d23 @ bit select, deal with right "
"pad\n"
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add "
"pointer\n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_channel += 2 * w_out;
} //! end of processing mid rows
}
#endif
}
}
/**
* \brief depthwise convolution kernel 3x3, stride 2, with relu
*/
// w_in > 7
void conv_depthwise_3x3s2p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in;
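// Column tiling (editor's note): the kernel emits one left block of 4
// outputs that consumes the implicit left pad, cnt_col middle blocks of 4
// outputs (8 input columns each), then one masked right block.
// vmask_rp1/vmask_rp2 below mark the valid even/odd input lanes of that
// last block; wmask marks its valid output lanes.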
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t wbias;
float bias_c = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_c = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(
// top
// Load two de-interleaved vectors (even/odd columns) from each of the 5 input rows.
"0: \n"
"prfm pldl1keep, [%[inptr0]] \n"
"prfm pldl1keep, [%[inptr1]] \n"
"prfm pldl1keep, [%[inptr2]] \n"
"prfm pldl1keep, [%[inptr3]] \n"
"prfm pldl1keep, [%[inptr4]] \n"
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"ext v10.16b, %[vzero].16b, v1.16b, #12 \n" // v10 = {0,1,3,5}
// r0
"fmul v11.4s, v0.4s, %[w0].s[1] \n" // {0,2,4,6} * w01
"fmul v12.4s, v1.4s, %[w0].s[2] \n" // {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v3.16b, #12 \n" // v10 = {0,1,3,5}
"sub %[inptr0], %[inptr0], #4 \n"
"sub %[inptr1], %[inptr1], #4 \n"
// r1
"fmla v11.4s, v2.4s, %[w1].s[1] \n" // {0,2,4,6} * w01
"fmla v12.4s, v3.4s, %[w1].s[2] \n" // {1,3,5,7} * w02
"fmla v16.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v5.16b, #12 \n" // v10 = {0,1,3,5}
"sub %[inptr2], %[inptr2], #4 \n"
"sub %[inptr3], %[inptr3], #4 \n"
// r2
"fmul v13.4s, v4.4s, %[w0].s[1] \n" // {0,2,4,6} * w01
"fmla v11.4s, v4.4s, %[w2].s[1] \n" // {0,2,4,6} * w01
"fmul v14.4s, v5.4s, %[w0].s[2] \n" // {1,3,5,7} * w02
"fmla v12.4s, v5.4s, %[w2].s[2] \n" // {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w0].s[0] \n" // {0,1,3,5} * w00
"fmla v16.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v7.16b, #12 \n" // v10 = {0,1,3,5}
"sub %[inptr4], %[inptr4], #4 \n"
// r3
"fmla v13.4s, v6.4s, %[w1].s[1] \n" // {0,2,4,6} * w01
"fmla v14.4s, v7.4s, %[w1].s[2] \n" // {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w1].s[0] \n" // {0,1,3,5} * w00
"ext v10.16b, %[vzero].16b, v9.16b, #12 \n" // v10 = {0,1,3,5}
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[1] \n" // {0,2,4,6} * w01
"fmla v14.4s, v9.4s, %[w2].s[2] \n" // {1,3,5,7} * w02
"fmla v17.4s, v10.4s, %[w2].s[0] \n" // {0,1,3,5} * w00
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
"fadd v17.4s, v17.4s, v13.4s \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"ld1 {v15.4s}, [%[inptr0]] \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"cmp %[cnt], #1 \n"
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"blt 1f \n"
// mid
"2: \n"
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, v18.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" // v0={0,2,4,6}
// v1={1,3,5,7}
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, v19.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n"
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, v20.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n"
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, v21.16b, #4 \n" // v10 = {2,4,6,8}
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n"
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */
"fadd v17.4s, v17.4s, v13.4s \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fadd v17.4s, v17.4s, v14.4s \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"and v16.16b, %[vbias].16b, %[vbias].16b \n" // v10 = vbias
"subs %[cnt], %[cnt], #1 \n"
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"and v17.16b, %[vbias].16b, %[vbias].16b \n" // v16 = vbias
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 4f \n"
"3: \n"
"bif v0.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v1.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v2.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v3.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"bif v4.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v5.16b, %[vzero].16b, %[mask2].16b \n" // pipei
"ext v10.16b, v0.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v6.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v7.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r0
"fmul v11.4s, v0.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmul v12.4s, v1.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v2.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"bif v8.16b, %[vzero].16b, %[mask1].16b \n" // pipei
"bif v9.16b, %[vzero].16b, %[mask2].16b \n" // pipei
// r1
"fmla v11.4s, v2.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v12.4s, v3.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v16.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v4.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r2
"fmul v13.4s, v4.4s, %[w0].s[0] \n" // {0,2,4,6} * w00
"fmla v11.4s, v4.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmul v14.4s, v5.4s, %[w0].s[1] \n" // {1,3,5,7} * w01
"fmla v12.4s, v5.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w0].s[2] \n" // {2,4,6,8} * w02
"fmla v16.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v6.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
// r3
"fmla v13.4s, v6.4s, %[w1].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v7.4s, %[w1].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w1].s[2] \n" // {2,4,6,8} * w02
"ext v10.16b, v8.16b, %[vzero].16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v0.4s}, [%[outptr0]] \n"
"fadd v16.4s, v16.4s, v11.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n"
"ld1 {v1.4s}, [%[outptr1]] \n"
// r4
"fmla v13.4s, v8.4s, %[w2].s[0] \n" // {0,2,4,6} * w00
"fmla v14.4s, v9.4s, %[w2].s[1] \n" // {1,3,5,7} * w01
"fmla v17.4s, v10.4s, %[w2].s[2] \n" // {2,4,6,8} * w02
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */
"fadd v17.4s, v17.4s, v13.4s \n"
"bif v16.16b, v0.16b, %[wmask].16b \n" // pipei
"fadd v17.4s, v17.4s, v14.4s \n"
"st1 {v16.4s}, [%[outptr0]], #16 \n"
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */
"bif v17.16b, v1.16b, %[wmask].16b \n" // pipei
"st1 {v17.4s}, [%[outptr1]], #16 \n"
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (i + 2 > h_in) {
switch (i + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(
// top
// Load two de-interleaved vectors (even/odd columns) from each of the 3 input rows.
"0: \n"
"vmov.u32 q9, #0 \n"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q10, q11
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v11={0,2,4,6} v12={1,3,5,7}, q12, q13
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v13={0,2,4,6} v14={1,3,5,7}, q14, q15
"pld [%[din0_ptr]] @ preload data\n"
"pld [%[din1_ptr]] @ preload data\n"
"pld [%[din2_ptr]] @ preload data\n"
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"vext.32 q6, q9, q11, #3 @ shift right 1 "
"data\n" // q2 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift right 1 "
"data\n" // q6 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift right 1 "
"data\n" // q6 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, "
"out0\n" // q11 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, "
"out0\n" // q12 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, "
"out0\n" // q6 * w00
"sub %[din0_ptr], #4 @ inpitr0 - 1\n"
"sub %[din1_ptr], #4 @ inpitr1 - 1\n"
"sub %[din2_ptr], #4 @ inpitr2 - 1\n"
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0\n" // q11 * w01
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0\n" // q12 * w02
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w00
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, "
"out1\n" // q0 * w01
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, "
"out1\n" // q1 * w02
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, "
"out1\n" // q2 * w00
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" // v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vmax.f32 q3, q3, q9 @ relu \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"cmp %[cnt], #1 \n"
"blt 1f \n"
// mid
"2: \n"
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" // q2={8,10,12,14}
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"vext.32 q6, q10, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q7, q12, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" // q2={8,10,12,14}
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q8, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // v0={0,2,4,6} v1={1,3,5,7}
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // v4={0,2,4,6} v5={1,3,5,7}
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vmax.f32 q3, q3, q9 @ relu \n"
"subs %[cnt], #1 \n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"bne 2b \n"
// right
"1: \n"
"cmp %[remain], #1 \n"
"blt 3f \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
"vdup.32 q3, %[bias] @ and \n" // q10 =
// vbias
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, "
"out0\n" // q0 * w00
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w02
"vext.32 q6, q14, q9, #1 @ shift left 1 \n" // q6 = {2,4,6,8}
"vld1.f32 {d20-d21}, [%[outptr]] @ load output\n"
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, "
"out0\n" // q0 * w00
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, "
"out0\n" // q1 * w01
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, "
"out0\n" // q6 * w02
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n"
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, "
"out0\n" // q0 * w00
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, "
"out0\n" // q1 * w01
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, "
"out0\n" // q6 * w02
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vmax.f32 q3, q3, q9 @ relu \n"
"vbif.f32 q3, q10, q11 @ write mask\n"
"vst1.32 {d6-d7}, [%[outptr]]! \n"
"3: \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
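// Editor's note: every hand-tuned kernel in this file specializes the scalar
// reference below for a fixed stride/pad/width class (a minimal sketch,
// assuming NCHW layout and implicit zero padding; the name is illustrative
// and not part of the library API). It is handy for cross-checking the NEON
// paths in unit tests.
static void conv_depthwise_3x3_ref(float* dout,
                                   const float* din,
                                   const float* weights,
                                   const float* bias,
                                   bool flag_bias,
                                   bool flag_relu,
                                   int num,
                                   int ch_in,
                                   int h_in,
                                   int w_in,
                                   int h_out,
                                   int w_out,
                                   int stride,
                                   int pad) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch_in; ++c) {
      const float* src = din + (n * ch_in + c) * h_in * w_in;
      float* dst = dout + (n * ch_in + c) * h_out * w_out;
      const float* w = weights + c * 9;  // one 3x3 filter per channel
      float b = flag_bias ? bias[c] : 0.f;
      for (int oh = 0; oh < h_out; ++oh) {
        for (int ow = 0; ow < w_out; ++ow) {
          float sum = b;
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              int ih = oh * stride - pad + kh;
              int iw = ow * stride - pad + kw;
              // out-of-range taps read the implicit zero padding
              if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
                sum += src[ih * w_in + iw] * w[kh * 3 + kw];
              }
            }
          }
          dst[oh * w_out + ow] = flag_relu && sum < 0.f ? 0.f : sum;
        }
      }
    }
  }
}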
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float* dout_channel = dout_batch + i * size_out_channel;
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float* doutr0 = dout_channel;
float* doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
}
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[din0]]\n"
"prfm pldl1keep, [%[din1]]\n"
"prfm pldl1keep, [%[din2]]\n"
"prfm pldl1keep, [%[din3]]\n"
"ld1 {v0.4s}, [%[din0]], #16\n"
"ld1 {v1.4s}, [%[din1]], #16\n"
"ld1 {v2.4s}, [%[din2]], #16\n"
"ld1 {v3.4s}, [%[din3]], #16\n"
"bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234
"bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234
"bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234
"bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234
"ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123
"ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123
"ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123
"ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123
"ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340
"ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340
"ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340
"ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340
"fmul v12.4s, v0.4s, %[wr0].s[1]\n"
"fmul v13.4s, v1.4s, %[wr0].s[1]\n"
"fmul v14.4s, v1.4s, %[wr1].s[1]\n"
"fmul v15.4s, v2.4s, %[wr1].s[1]\n"
"fmul v16.4s, v2.4s, %[wr2].s[1]\n"
"fmul v17.4s, v3.4s, %[wr2].s[1]\n"
"fmla v12.4s, v4.4s, %[wr0].s[0]\n"
"fmla v13.4s, v5.4s, %[wr0].s[0]\n"
"fmla v14.4s, v5.4s, %[wr1].s[0]\n"
"fmla v15.4s, v6.4s, %[wr1].s[0]\n"
"fmla v16.4s, v6.4s, %[wr2].s[0]\n"
"fmla v17.4s, v7.4s, %[wr2].s[0]\n"
"fmla v12.4s, v8.4s, %[wr0].s[2]\n"
"fmla v13.4s, v9.4s, %[wr0].s[2]\n"
"fmla v14.4s, v9.4s, %[wr1].s[2]\n"
"fmla v15.4s, v10.4s, %[wr1].s[2]\n"
"fmla v16.4s, v10.4s, %[wr2].s[2]\n"
"fmla v17.4s, v11.4s, %[wr2].s[2]\n"
"fadd v12.4s, v12.4s, v14.4s\n"
"fadd v12.4s, v12.4s, v16.4s\n"
"fadd v13.4s, v13.4s, v15.4s\n" // out1
"fadd v13.4s, v13.4s, v17.4s\n" // out2
"fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias
"fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias
"prfm pldl1keep, [%[out1]]\n"
"prfm pldl1keep, [%[out2]]\n"
"st1 {v12.4s}, [%[out1]]\n"
"st1 {v13.4s}, [%[out2]]\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(
"pld [%[din0]]\n"
"pld [%[din1]]\n"
"pld [%[din2]]\n"
"pld [%[din3]]\n"
"vld1.32 {d12-d13}, [%[din0]]!\n"
"vld1.32 {d14-d15}, [%[din1]]!\n"
"vld1.32 {d16-d17}, [%[din2]]!\n"
"vld1.32 {d18-d19}, [%[din3]]!\n"
"vbif q6, %q[zero], %q[mask]\n" // d0_1234
"vbif q7, %q[zero], %q[mask]\n" // d1_1234
"vbif q8, %q[zero], %q[mask]\n" // d2_1234
"vbif q9, %q[zero], %q[mask]\n" // d3_1234
"vmul.f32 q14, q6, %e[wr0][1]\n"
"vmul.f32 q15, q7, %e[wr0][1]\n"
"vmla.f32 q14, q7, %e[wr1][1]\n"
"vmla.f32 q15, q8, %e[wr1][1]\n"
"vmla.f32 q14, q8, %e[wr2][1]\n"
"vmla.f32 q15, q9, %e[wr2][1]\n"
"vext.32 q10, %q[zero], q6, #3\n" // d0_0123
"vext.32 q11, %q[zero], q7, #3\n" // d1_0123
"vext.32 q12, %q[zero], q8, #3\n" // d2_0123
"vext.32 q13, %q[zero], q9, #3\n" // d3_0123
"vmla.f32 q14, q10, %e[wr0][0]\n"
"vmla.f32 q15, q11, %e[wr0][0]\n"
"vmla.f32 q14, q11, %e[wr1][0]\n"
"vmla.f32 q15, q12, %e[wr1][0]\n"
"vmla.f32 q14, q12, %e[wr2][0]\n"
"vmla.f32 q15, q13, %e[wr2][0]\n"
"vext.32 q10, q6, %q[zero], #1\n" // d0_2340
"vext.32 q11, q7, %q[zero], #1\n" // d1_2340
"vext.32 q12, q8, %q[zero], #1\n" // d2_2340
"vext.32 q13, q9, %q[zero], #1\n" // d3_2340
"vmla.f32 q14, q10, %f[wr0][0]\n"
"vmla.f32 q15, q11, %f[wr0][0]\n"
"vmla.f32 q14, q11, %f[wr1][0]\n"
"vmla.f32 q15, q12, %f[wr1][0]\n"
"vmla.f32 q14, q12, %f[wr2][0]\n" // out1
"vmla.f32 q15, q13, %f[wr2][0]\n" // out2
"vadd.f32 q14, q14, %q[bias]\n" // out1 add bias
"vadd.f32 q15, q15, %q[bias]\n" // out2 add bias
"pld [%[out1]]\n"
"pld [%[out2]]\n"
"vst1.32 {d28-d29}, [%[out1]]\n"
"vst1.32 {d30-d31}, [%[out2]]\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
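// Editor's note: the small-width kernels above replace a scalar tail loop
// with one masked vector per row. A minimal sketch of that technique
// (illustrative name; assumes w_in <= 4 as in the kernel above):
static inline float32x4_t load_masked_row_f32(const float* row, int w_in) {
  const int right_pad_idx[4] = {3, 2, 1, 0};
  // lane i is valid iff i < w_in  <=>  right_pad_idx[i] >= 4 - w_in
  uint32x4_t mask =
      vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
  // load 4 lanes (may read past column w_in; those lanes are masked below)
  float32x4_t data = vld1q_f32(row);
  return vbslq_f32(mask, data, vdupq_n_f32(0.f));  // select data else zero
}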
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 7
*/
void conv_depthwise_3x3s2p1_bias_s(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(
// Load two de-interleaved vectors (even/odd columns) from each of the 3 input rows.
"movi v9.4s, #0 \n"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6}
// v15={1,3,5,7}
"bif v10.16b, v9.16b, v6.16b \n"
"bif v11.16b, v9.16b, v7.16b \n"
"bif v12.16b, v9.16b, v6.16b \n"
"bif v13.16b, v9.16b, v7.16b \n"
"bif v14.16b, v9.16b, v6.16b \n"
"bif v15.16b, v9.16b, v7.16b \n"
"ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 =
// {0,1,3,5}
"ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 =
// {0,1,3,5}
"ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 =
// {0,1,3,5}
"fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01
"fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02
"fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00
"fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12
"fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21
"fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22
"fadd v4.4s, v4.4s, v5.4s \n"
"fadd v4.4s, v4.4s, v6.4s \n"
"fadd v4.4s, v4.4s, %[bias].4s \n"
"st1 {v4.4s}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(
// Load two de-interleaved vectors (even/odd columns) from each of the 3 input rows.
"vmov.u32 q9, #0 \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
"vdup.32 q3, %[bias] @ and \n" // q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, "
"out0\n" // q10 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, "
"out0\n" // q11 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w00
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0\n" // q12 * w11
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0\n" // q13 * w12
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0\n" // q7 * w10
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, "
"out0\n" // q14 * w20
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, "
"out0\n" // q15 * w21
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, "
"out0\n" // q8 * w22
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vst1.32 {d6-d7}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
hs += 2;
he += 2;
}
}
}
}
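// Editor's note: the stride-2 kernels use de-interleaving loads (vld2/ld2) so
// that each 3-tap filter row reduces to three vector multiply-accumulates.
// A sketch of the idea for one row of the left-padded block (illustrative
// name; `prev_odd` carries input column -1, i.e. zero for the left pad):
static inline float32x4_t row3_s2_f32(const float* row,
                                      float32x4_t prev_odd,
                                      float w0, float w1, float w2) {
  // de-interleave 8 input columns into even/odd lanes, as vld2.32 / ld2 do
  float32x4x2_t v = vld2q_f32(row);  // val[0]={c0,c2,c4,c6} val[1]={c1,c3,c5,c7}
  // left taps {c-1,c1,c3,c5}: shift the previous odd vector in by one lane
  float32x4_t left = vextq_f32(prev_odd, v.val[1], 3);
  float32x4_t acc = vmulq_n_f32(v.val[0], w1);  // center column of the filter
  acc = vmlaq_n_f32(acc, v.val[1], w2);         // right column
  acc = vmlaq_n_f32(acc, left, w0);             // left column
  return acc;
}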
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float* dout_channel = dout_batch + i * size_out_channel;
const float* din_channel = din_batch + i * size_in_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float* doutr0 = dout_channel;
float* doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
}
#ifdef __aarch64__
asm volatile(
"prfm pldl1keep, [%[din0]]\n"
"prfm pldl1keep, [%[din1]]\n"
"prfm pldl1keep, [%[din2]]\n"
"prfm pldl1keep, [%[din3]]\n"
"ld1 {v0.4s}, [%[din0]], #16\n"
"ld1 {v1.4s}, [%[din1]], #16\n"
"ld1 {v2.4s}, [%[din2]], #16\n"
"ld1 {v3.4s}, [%[din3]], #16\n"
"bif v0.16b, %[zero].16b, %[mask].16b\n" // d0_1234
"bif v1.16b, %[zero].16b, %[mask].16b\n" // d1_1234
"bif v2.16b, %[zero].16b, %[mask].16b\n" // d2_1234
"bif v3.16b, %[zero].16b, %[mask].16b\n" // d3_1234
"ext v4.16b, %[zero].16b, v0.16b, #12\n" // d0_0123
"ext v5.16b, %[zero].16b, v1.16b, #12\n" // d1_0123
"ext v6.16b, %[zero].16b, v2.16b, #12\n" // d2_0123
"ext v7.16b, %[zero].16b, v3.16b, #12\n" // d3_0123
"ext v8.16b, v0.16b, %[zero].16b, #4\n" // d0_2340
"ext v9.16b, v1.16b, %[zero].16b, #4\n" // d1_2340
"ext v10.16b, v2.16b, %[zero].16b, #4\n" // d2_2340
"ext v11.16b, v3.16b, %[zero].16b, #4\n" // d3_2340
"fmul v12.4s, v0.4s, %[wr0].s[1]\n"
"fmul v13.4s, v1.4s, %[wr0].s[1]\n"
"fmul v14.4s, v1.4s, %[wr1].s[1]\n"
"fmul v15.4s, v2.4s, %[wr1].s[1]\n"
"fmul v16.4s, v2.4s, %[wr2].s[1]\n"
"fmul v17.4s, v3.4s, %[wr2].s[1]\n"
"fmla v12.4s, v4.4s, %[wr0].s[0]\n"
"fmla v13.4s, v5.4s, %[wr0].s[0]\n"
"fmla v14.4s, v5.4s, %[wr1].s[0]\n"
"fmla v15.4s, v6.4s, %[wr1].s[0]\n"
"fmla v16.4s, v6.4s, %[wr2].s[0]\n"
"fmla v17.4s, v7.4s, %[wr2].s[0]\n"
"fmla v12.4s, v8.4s, %[wr0].s[2]\n"
"fmla v13.4s, v9.4s, %[wr0].s[2]\n"
"fmla v14.4s, v9.4s, %[wr1].s[2]\n"
"fmla v15.4s, v10.4s, %[wr1].s[2]\n"
"fmla v16.4s, v10.4s, %[wr2].s[2]\n"
"fmla v17.4s, v11.4s, %[wr2].s[2]\n"
"fadd v12.4s, v12.4s, v14.4s\n"
"fadd v12.4s, v12.4s, v16.4s\n"
"fadd v13.4s, v13.4s, v15.4s\n" // out1
"fadd v13.4s, v13.4s, v17.4s\n" // out2
"fadd v12.4s, v12.4s, %[bias].4s\n" // out1 add bias
"fadd v13.4s, v13.4s, %[bias].4s\n" // out2 add bias
"prfm pldl1keep, [%[out1]]\n"
"prfm pldl1keep, [%[out2]]\n"
"fmax v12.4s, v12.4s, %[zero].4s\n" // out1 -> relu
"fmax v13.4s, v13.4s, %[zero].4s\n" // out2 -> relu
"st1 {v12.4s}, [%[out1]]\n"
"st1 {v13.4s}, [%[out2]]\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(
"pld [%[din0]]\n"
"pld [%[din1]]\n"
"pld [%[din2]]\n"
"pld [%[din3]]\n"
"vld1.32 {d12-d13}, [%[din0]]!\n"
"vld1.32 {d14-d15}, [%[din1]]!\n"
"vld1.32 {d16-d17}, [%[din2]]!\n"
"vld1.32 {d18-d19}, [%[din3]]!\n"
"vbif q6, %q[zero], %q[mask]\n" // d0_1234
"vbif q7, %q[zero], %q[mask]\n" // d1_1234
"vbif q8, %q[zero], %q[mask]\n" // d2_1234
"vbif q9, %q[zero], %q[mask]\n" // d3_1234
"vmul.f32 q14, q6, %e[wr0][1]\n"
"vmul.f32 q15, q7, %e[wr0][1]\n"
"vmla.f32 q14, q7, %e[wr1][1]\n"
"vmla.f32 q15, q8, %e[wr1][1]\n"
"vmla.f32 q14, q8, %e[wr2][1]\n"
"vmla.f32 q15, q9, %e[wr2][1]\n"
"vext.32 q10, %q[zero], q6, #3\n" // d0_0123
"vext.32 q11, %q[zero], q7, #3\n" // d1_0123
"vext.32 q12, %q[zero], q8, #3\n" // d2_0123
"vext.32 q13, %q[zero], q9, #3\n" // d3_0123
"vmla.f32 q14, q10, %e[wr0][0]\n"
"vmla.f32 q15, q11, %e[wr0][0]\n"
"vmla.f32 q14, q11, %e[wr1][0]\n"
"vmla.f32 q15, q12, %e[wr1][0]\n"
"vmla.f32 q14, q12, %e[wr2][0]\n"
"vmla.f32 q15, q13, %e[wr2][0]\n"
"vext.32 q10, q6, %q[zero], #1\n" // d0_2340
"vext.32 q11, q7, %q[zero], #1\n" // d1_2340
"vext.32 q12, q8, %q[zero], #1\n" // d2_2340
"vext.32 q13, q9, %q[zero], #1\n" // d3_2340
"vmla.f32 q14, q10, %f[wr0][0]\n"
"vmla.f32 q15, q11, %f[wr0][0]\n"
"vmla.f32 q14, q11, %f[wr1][0]\n"
"vmla.f32 q15, q12, %f[wr1][0]\n"
"vmla.f32 q14, q12, %f[wr2][0]\n" // out1
"vmla.f32 q15, q13, %f[wr2][0]\n" // out2
"vadd.f32 q14, q14, %q[bias]\n" // out1 add bias
"vadd.f32 q15, q15, %q[bias]\n" // out2 add bias
"pld [%[out1]]\n"
"pld [%[out2]]\n"
"vmax.f32 q14, q14, %q[zero]\n" // out1 -> relu
"vmax.f32 q15, q15, %q[zero]\n" // out2 -> relu
"vst1.32 {d28-d29}, [%[out1]]\n"
"vst1.32 {d30-d31}, [%[out2]]\n"
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
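// Editor's note: the relu variants differ from the plain-bias ones only in a
// fused epilogue, out = max(acc + bias, 0). A minimal intrinsics sketch of
// that epilogue (name is illustrative, not part of this file's API):
static inline float32x4_t bias_relu_f32x4(float32x4_t acc, float32x4_t vbias) {
  acc = vaddq_f32(acc, vbias);              // add per-channel bias
  return vmaxq_f32(acc, vdupq_n_f32(0.f));  // clamp negatives to zero (relu)
}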
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 7
*/
void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(
// Load two de-interleaved vectors (even/odd columns) from each of the 3 input rows.
"movi v9.4s, #0 \n"
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n"
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" // v10={0,2,4,6}
// v11={1,3,5,7}
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" // v13={0,2,4,6}
// v12={1,3,5,7}
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" // v14={0,2,4,6}
// v15={1,3,5,7}
"bif v10.16b, v9.16b, v6.16b \n"
"bif v11.16b, v9.16b, v7.16b \n"
"bif v12.16b, v9.16b, v6.16b \n"
"bif v13.16b, v9.16b, v7.16b \n"
"bif v14.16b, v9.16b, v6.16b \n"
"bif v15.16b, v9.16b, v7.16b \n"
"ext v6.16b, v9.16b, v11.16b, #12 \n" // v6 =
// {0,1,3,5}
"ext v7.16b, v9.16b, v13.16b, #12 \n" // v7 =
// {0,1,3,5}
"ext v8.16b, v9.16b, v15.16b, #12 \n" // v8 =
// {0,1,3,5}
"fmul v4.4s, v10.4s, %[wr0].s[1] \n" // v10 * w01
"fmul v5.4s, v11.4s, %[wr0].s[2] \n" // v11 * w02
"fmul v6.4s, v6.4s, %[wr0].s[0] \n" // v6 * w00
"fmla v4.4s, v12.4s, %[wr1].s[1] \n" // v12 * w11
"fmla v5.4s, v13.4s, %[wr1].s[2] \n" // v13 * w12
"fmla v6.4s, v7.4s, %[wr1].s[0] \n" // v7 * w10
"fmla v4.4s, v14.4s, %[wr2].s[1] \n" // v14 * w20
"fmla v5.4s, v15.4s, %[wr2].s[2] \n" // v15 * w21
"fmla v6.4s, v8.4s, %[wr2].s[0] \n" // v8 * w22
"fadd v4.4s, v4.4s, v5.4s \n"
"fadd v4.4s, v4.4s, v6.4s \n"
"fadd v4.4s, v4.4s, %[bias].4s \n" // out add bias
"fmax v4.4s, v4.4s, v9.4s \n"
"st1 {v4.4s}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(
// Load two de-interleaved vectors (even/odd columns) from each of the 3 input rows.
"vmov.u32 q9, #0 \n"
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n"
"vdup.32 q3, %[bias] @ and \n" // q3 =
// vbias
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" // q10={0,2,4,6} q11={1,3,5,7}
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" // q13={0,2,4,6} q12={1,3,5,7}
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" // q14={0,2,4,6} q15={1,3,5,7}
"vbif q10, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q11, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q12, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q13, q9, q7 @ bit select, deal "
"with right pad\n"
"vbif q14, q9, q6 @ bit select, deal "
"with right pad\n"
"vbif q15, q9, q7 @ bit select, deal "
"with right pad\n"
"vext.32 q6, q9, q11, #3 @ shift left 1 \n" // q6 = {0,1,3,5}
"vext.32 q7, q9, q13, #3 @ shift left 1 \n" // q7 = {0,1,3,5}
"vext.32 q8, q9, q15, #3 @ shift left 1 \n" // q8 = {0,1,3,5}
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, "
"out0\n" // q10 * w01
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, "
"out0\n" // q11 * w02
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, "
"out0\n" // q6 * w00
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, "
"out0\n" // q12 * w11
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, "
"out0\n" // q13 * w12
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, "
"out0\n" // q7 * w10
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, "
"out0\n" // q14 * w20
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, "
"out0\n" // q15 * w21
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, "
"out0\n" // q8 * w22
"vadd.f32 q3, q3, q4 @ add \n"
"vadd.f32 q3, q3, q5 @ add \n"
"vmax.f32 q3, q3, q9 @ relu\n"
"vst1.32 {d6-d7}, [%[out]] \n"
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif // __aarch64__
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
hs += 2;
he += 2;
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -361,7 +361,6 @@ void conv_im2col_gemm(const float* i_data,
float* tmp_work_space =
ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
//! use gemv when the output channel size = 1
for (int b = 0; b < num; ++b) {
// dC
......
......@@ -39,6 +39,7 @@
#include "lite/backends/arm/math/im2sequence.h"
#include "lite/backends/arm/math/increment.h"
#include "lite/backends/arm/math/interpolate.h"
#include "lite/backends/arm/math/layout.h"
#include "lite/backends/arm/math/lrn.h"
#include "lite/backends/arm/math/negative.h"
#include "lite/backends/arm/math/norm.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/layout.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
#define TRANS_C4 \
"ld1 {v0.4s}, [%[din0_ptr]] \n" \
"ld1 {v1.4s}, [%[din1_ptr]] \n" \
"ld1 {v2.4s}, [%[din2_ptr]] \n" \
"ld1 {v3.4s}, [%[din3_ptr]] \n" \
\
"1: \n" \
"trn1 v4.4s, v0.4s, v1.4s \n" /*00 10 02 12 */ \
"trn1 v5.4s, v2.4s, v3.4s \n" /*20 30 22 32 */ \
"trn2 v6.4s, v0.4s, v1.4s \n" /*01 11 03 13 */ \
"trn2 v7.4s, v2.4s, v3.4s \n" /*21 31 23 33 */ \
\
"add %[din0_ptr], %[din0_ptr], %[stride] \n" /* din+=c*size*/ \
"add %[din1_ptr], %[din1_ptr], %[stride] \n" /* din+=c*size*/ \
"add %[din2_ptr], %[din2_ptr], %[stride] \n" /* din+=c*size*/ \
"add %[din3_ptr], %[din3_ptr], %[stride] \n" /* din+=c*size*/ \
\
"trn1 v8.2d, v4.2d, v5.2d \n" /*00 10 20 30 */ \
"trn1 v9.2d, v6.2d, v7.2d \n" /*01 11 21 31 */ \
"trn2 v10.2d, v4.2d, v5.2d \n" /*02 12 22 32 */ \
"trn2 v11.2d, v6.2d, v7.2d \n" /*03 13 23 33 */ \
\
"ld1 {v0.4s}, [%[din0_ptr]] \n" \
"ld1 {v1.4s}, [%[din1_ptr]] \n" \
"ld1 {v2.4s}, [%[din2_ptr]] \n" \
"ld1 {v3.4s}, [%[din3_ptr]] \n" \
\
"subs %w[cnt], %w[cnt], #1 \n" \
"str q8, [%[out0_ptr]], #16 \n" \
"str q9, [%[out1_ptr]], #16 \n" \
"str q10, [%[out2_ptr]], #16 \n" \
"str q11, [%[out3_ptr]], #16 \n" \
"bne 1b \n"
#define TRANS_C8 \
"1: \n" \
"ld1 {v0.8b}, [%[din0_ptr]] \n" \
"ld1 {v1.8b}, [%[din1_ptr]] \n" \
"ld1 {v2.8b}, [%[din2_ptr]] \n" \
"ld1 {v3.8b}, [%[din3_ptr]] \n" \
\
"add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \
\
"trn1 v8.8b, v0.8b, v1.8b \n" /*00 10 02 12 04 14 06 16 */ \
"trn1 v9.8b, v2.8b, v3.8b \n" /*20 30 22 32 */ \
"trn2 v12.8b, v0.8b, v1.8b \n" /*01 11 03 13 05 15 07 17 */ \
"trn2 v13.8b, v2.8b, v3.8b \n" /*21 31 23 33 */ \
\
"ld1 {v4.8b}, [%[din0_ptr]] \n" \
"ld1 {v5.8b}, [%[din1_ptr]] \n" \
"ld1 {v6.8b}, [%[din2_ptr]] \n" \
"ld1 {v7.8b}, [%[din3_ptr]] \n" \
\
"trn1 v10.8b, v4.8b, v5.8b \n" /*40 50 42 52 */ \
"trn1 v11.8b, v6.8b, v7.8b \n" /*60 70 62 72 */ \
"trn2 v14.8b, v4.8b, v5.8b \n" /*41 51 43 53 */ \
"trn2 v15.8b, v6.8b, v7.8b \n" /*61 71 63 73 */ \
\
"trn1 v0.4h, v8.4h, v9.4h \n" /*00 10 20 30 04 14 24 34*/ \
"trn1 v2.4h, v12.4h, v13.4h \n" /*01 11 21 31 05 15 25 35*/ \
"trn1 v1.4h, v10.4h, v11.4h \n" /*40 50 60 70 44 54 64 74*/ \
"trn1 v3.4h, v14.4h, v15.4h \n" /*41 51 61 71 45 55 65 75*/ \
\
"trn2 v4.4h, v8.4h, v9.4h \n" /*02 10 20 30 06 14 24 34*/ \
"trn2 v6.4h, v12.4h, v13.4h \n" /*03 11 21 31 07 15 25 35*/ \
"trn2 v5.4h, v10.4h, v11.4h \n" /*42 50 60 70 46 54 64 74*/ \
"trn2 v7.4h, v14.4h, v15.4h \n" /*43 51 61 71 47 55 65 75*/ \
\
"trn1 v8.2s, v0.2s, v1.2s \n" /*00 10 20 30 40 50 60 70*/ \
"trn1 v9.2s, v2.2s, v3.2s \n" /*01 11 21 31 41 51 61 71*/ \
"trn1 v10.2s, v4.2s, v5.2s \n" /*02 12 22 32 42 50 60 70*/ \
"trn1 v11.2s, v6.2s, v7.2s \n" /*03 13 23 33 41 51 61 71*/ \
\
"trn2 v12.2s, v0.2s, v1.2s \n" /*04 14 24 34 44 54 64 74*/ \
"trn2 v13.2s, v2.2s, v3.2s \n" /*05 15 25 35 45 55 65 75*/ \
"trn2 v14.2s, v4.2s, v5.2s \n" /*06 16 22 32 42 50 60 70*/ \
"trn2 v15.2s, v6.2s, v7.2s \n" /*07 17 23 33 41 51 61 71*/ \
\
"add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \
\
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v8.8b}, [%[out0_ptr]], #8 \n" \
"st1 {v9.8b}, [%[out1_ptr]], #8 \n" \
"st1 {v10.8b}, [%[out2_ptr]], #8 \n" \
"st1 {v11.8b}, [%[out3_ptr]], #8 \n" \
\
"st1 {v11.8b}, [%[out4_ptr]], #8 \n" \
"st1 {v12.8b}, [%[out5_ptr]], #8 \n" \
"st1 {v13.8b}, [%[out6_ptr]], #8 \n" \
"st1 {v14.8b}, [%[out7_ptr]], #8 \n" \
"bne 1b \n"
#else
#define TRANS_C4 \
"1: \n" \
"vld1.32 {d0-d1}, [%[din0_ptr]] \n" \
"vld1.32 {d2-d3}, [%[din1_ptr]] \n" \
"vld1.32 {d4-d5}, [%[din2_ptr]] \n" \
"vld1.32 {d6-d7}, [%[din3_ptr]] \n" \
\
"vtrn.32 q0, q1 \n" /*00 10 02 12 01 11 03 13*/ \
"vtrn.32 q2, q3 \n" /*20 30 22 32 21 31 23 33 */ \
\
"add %[din0_ptr], %[din0_ptr], %[stride] \n" /* din+=c*size*/ \
"add %[din1_ptr], %[din1_ptr], %[stride] \n" /* din+=c*size*/ \
"add %[din2_ptr], %[din2_ptr], %[stride] \n" /* din+=c*size*/ \
"add %[din3_ptr], %[din3_ptr], %[stride] \n" /* din+=c*size*/ \
"vswp d1, d4 \n" \
"vswp d3, d6 \n" \
\
"subs %[cnt], %[cnt], #1 \n" \
"vst1.32 {d0-d1}, [%[out0_ptr]]! \n" \
"vst1.32 {d2-d3}, [%[out1_ptr]]! \n" \
"vst1.32 {d4-d5}, [%[out2_ptr]]! \n" \
"vst1.32 {d6-d7}, [%[out3_ptr]]! \n" \
"bne 1b \n"
#define TRANS_C8 \
"1: \n" \
"vld1.8 d0, [%[din0_ptr]] \n" \
"vld1.8 d1, [%[din1_ptr]] \n" \
"vld1.8 d2, [%[din2_ptr]] \n" \
"vld1.8 d3, [%[din3_ptr]] \n" \
\
"add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \
\
"vtrn.8 d0, d1 \n" /*00 10 02 12 04 14 06 16*/ \
"vtrn.8 d2, d3 \n" /*20 30 22 32 24 34 26 36 */ \
\
"vld1.8 d4, [%[din0_ptr]] \n" \
"vld1.8 d5, [%[din1_ptr]] \n" \
"vld1.8 d6, [%[din2_ptr]] \n" \
"vld1.8 d7, [%[din3_ptr]] \n" \
\
"vtrn.16 d0, d2 \n" /*00 10 20 30 04 14 24 34*/ \
"vtrn.16 d1, d3 \n" /* 01 11 21 31 05 15 25 35 */ \
"vtrn.8 d4, d5 \n" /*40 50 02 12 04 14 06 16*/ \
"vtrn.8 d6, d7 \n" /*60 70 22 32 24 34 26 36 */ \
\
"add %[din0_ptr], %[din0_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din1_ptr], %[din1_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din2_ptr], %[din2_ptr], %[stride_w] \n" /* din+=c*size*/ \
"add %[din3_ptr], %[din3_ptr], %[stride_w] \n" /* din+=c*size*/ \
\
"vtrn.16 d4, d6 \n" /*40 50 60 70 04 14 24 34*/ \
"vtrn.16 d5, d7 \n" /* 41 51 61 71 05 15 25 35 */ \
\
"vtrn.32 d0, d4 \n" /*00 10 20 30 40 50 60 70*/ \
"vtrn.32 d1, d5 \n" /* 01 11 21 31 41 51 61 71 */ \
"vtrn.32 d2, d6 \n" /*02 12 22 32 42 52 62 72*/ \
"vtrn.32 d3, d7 \n" /* 03 11 21 33 43 53 63 73 */ \
\
"subs %[cnt], %[cnt], #1 \n" \
"vst1.8 {d0}, [%[out0_ptr]]! \n" \
"vst1.8 {d1}, [%[out1_ptr]]! \n" \
"vst1.8 {d2}, [%[out2_ptr]]! \n" \
"vst1.8 {d3}, [%[out3_ptr]]! \n" \
"vst1.8 {d4}, [%[out4_ptr]]! \n" \
"vst1.8 {d5}, [%[out5_ptr]]! \n" \
"vst1.8 {d6}, [%[out6_ptr]]! \n" \
"vst1.8 {d7}, [%[out7_ptr]]! \n" \
"bne 1b \n"
#endif
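// Illustration (not part of this patch): TRANS_C4 / TRANS_C8 are in-register
// square-tile transposes (4x4 floats, 8x8 int8 bytes). Each iteration loads a
// tile whose rows come from four (or eight) source rows, transposes it via
// the trn/vtrn lane shuffles, and stores the transposed rows through the
// output pointers. A scalar sketch of what one TRANS_C4 iteration computes,
// with a hypothetical helper name:
static void tile_transpose_4x4_sketch(const float* in[4], float* out[4]) {
  for (int r = 0; r < 4; ++r) {    // r: source row
    for (int c = 0; c < 4; ++c) {  // c: source column
      out[c][r] = in[r][c];        // transposed write
    }
  }
}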
template <>
void NCHW2NHWC<float>(int N, int C, int size, const float* X, float* Y) {
int cnt = C >> 2;
int remain = C % 4;
int sum = C * size;
int stride = size << 4;      // 4 * size * sizeof(float): asm byte step over 4 channels
int stride_w = stride >> 2;  // 4 * size: element step for the scalar tail
for (int n = 0; n < N; n++) {
const float* din = X + n * sum;
float* dout = Y + n * sum;
int s = 0;
#pragma omp parallel for
for (s = 0; s < size - 3; s += 4) {
const float* din0_ptr = din + s;
const float* din1_ptr = din0_ptr + size;
const float* din2_ptr = din1_ptr + size;
const float* din3_ptr = din2_ptr + size;
float* out0_ptr = dout + s * C;
float* out1_ptr = out0_ptr + C;
float* out2_ptr = out1_ptr + C;
float* out3_ptr = out2_ptr + C;
int cnt_num = cnt;
if (cnt_num > 0) {
#ifdef __aarch64__
asm volatile(TRANS_C4
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[cnt] "+r"(cnt_num),
[stride] "+r"(stride)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12");
#else
asm volatile(TRANS_C4
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[cnt] "+r"(cnt_num),
[stride] "+r"(stride)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif
}
for (int i = 0; i < remain; i++) {
const float* ptr = din0_ptr;
*out0_ptr++ = *ptr++;
*out1_ptr++ = *ptr++;
*out2_ptr++ = *ptr++;
*out3_ptr++ = *ptr++;
din0_ptr += size;
}
}
// remain size: recompute the start, since the value of s after an OpenMP
// parallel loop is unspecified
for (s = (size >> 2) << 2; s < size; s++) {
const float* din0_ptr = din + s;
const float* din1_ptr = din0_ptr + size;
const float* din2_ptr = din1_ptr + size;
const float* din3_ptr = din2_ptr + size;
float* out0_ptr = dout + s * C;
for (int i = 0; i < cnt; i++) {
*out0_ptr++ = *din0_ptr;
*out0_ptr++ = *din1_ptr;
*out0_ptr++ = *din2_ptr;
*out0_ptr++ = *din3_ptr;
din0_ptr += stride_w;
din1_ptr += stride_w;
din2_ptr += stride_w;
din3_ptr += stride_w;
}
for (int i = 0; i < remain; i++) {
*out0_ptr++ = *din0_ptr;
din0_ptr += size;
}
}
}
}
template <>
void NCHW2NHWC<int8_t>(int N, int C, int size, const int8_t* X, int8_t* Y) {
int cnt = C >> 3;
int remain = C % 8;
int sum = C * size;
int stride = size << 3;    // 8 * size: element step over 8 int8 channels
int stride_w = size << 2;  // 4 * size bytes: int8 elements are 1 byte wide
for (int n = 0; n < N; n++) {
const int8_t* din = X + n * sum;
int8_t* dout = Y + n * sum;
int s = 0;
#pragma omp parallel for
for (s = 0; s < size - 7; s += 8) {
const int8_t* din0_ptr = din + s;
const int8_t* din1_ptr = din0_ptr + size;
const int8_t* din2_ptr = din1_ptr + size;
const int8_t* din3_ptr = din2_ptr + size;
int8_t* out0_ptr = dout + s * C;
int8_t* out1_ptr = out0_ptr + C;
int8_t* out2_ptr = out1_ptr + C;
int8_t* out3_ptr = out2_ptr + C;
int8_t* out4_ptr = out3_ptr + C;
int8_t* out5_ptr = out4_ptr + C;
int8_t* out6_ptr = out5_ptr + C;
int8_t* out7_ptr = out6_ptr + C;
int cnt_num = cnt;
if (cnt_num > 0) {
#ifdef __aarch64__
asm volatile(TRANS_C8
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[out4_ptr] "+r"(out4_ptr),
[out5_ptr] "+r"(out5_ptr),
[out6_ptr] "+r"(out6_ptr),
[out7_ptr] "+r"(out7_ptr),
[cnt] "+r"(cnt_num),
[stride_w] "+r"(stride_w)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(TRANS_C8
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[out4_ptr] "+r"(out4_ptr),
[out5_ptr] "+r"(out5_ptr),
[out6_ptr] "+r"(out6_ptr),
[out7_ptr] "+r"(out7_ptr),
[cnt] "+r"(cnt_num),
[stride_w] "+r"(stride_w)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif
}
// const int8_t* din_ptr = din + 8 * cnt * size + s; // remain channel
for (int i = 0; i < remain; i++) {
  const int8_t* ptr = din0_ptr;
  *out0_ptr++ = *ptr++;
  *out1_ptr++ = *ptr++;
  *out2_ptr++ = *ptr++;
  *out3_ptr++ = *ptr++;
  *out4_ptr++ = *ptr++;
  *out5_ptr++ = *ptr++;
  *out6_ptr++ = *ptr++;
  *out7_ptr++ = *ptr++;
  din0_ptr += size;
}
}
// remain size: recompute the start, since the value of s after an OpenMP
// parallel loop is unspecified
for (s = (size >> 3) << 3; s < size; s++) {
const int8_t* din0_ptr = din + s;
const int8_t* din1_ptr = din0_ptr + size;
const int8_t* din2_ptr = din1_ptr + size;
const int8_t* din3_ptr = din2_ptr + size;
const int8_t* din4_ptr = din3_ptr + size;
const int8_t* din5_ptr = din4_ptr + size;
const int8_t* din6_ptr = din5_ptr + size;
const int8_t* din7_ptr = din6_ptr + size;
int8_t* out0_ptr = dout + s * C;
for (int i = 0; i < cnt; i++) {
*out0_ptr++ = *din0_ptr;
*out0_ptr++ = *din1_ptr;
*out0_ptr++ = *din2_ptr;
*out0_ptr++ = *din3_ptr;
*out0_ptr++ = *din4_ptr;
*out0_ptr++ = *din5_ptr;
*out0_ptr++ = *din6_ptr;
*out0_ptr++ = *din7_ptr;
din0_ptr += stride;
din1_ptr += stride;
din2_ptr += stride;
din3_ptr += stride;
din4_ptr += stride;
din5_ptr += stride;
din6_ptr += stride;
din7_ptr += stride;
}
for (int i = 0; i < remain; i++) {
*out0_ptr++ = *din0_ptr;
din0_ptr += size;
}
}
}
}
template <>
void NHWC2NCHW<float>(int N, int C, int size, const float* X, float* Y) {
int cnt = size >> 2;
int remain = size % 4;
int sum = C * size;
int stride = C << 4;    // 4 * C * sizeof(float): asm byte step over 4 rows
int stride_w = C << 2;  // 4 * C: element step for the scalar tail
for (int n = 0; n < N; n++) {
const float* din = X + n * sum;
float* dout = Y + n * sum;
int s = 0;
#pragma omp parallel for
for (s = 0; s < C - 3; s += 4) {
const float* din0_ptr = din + s;
const float* din1_ptr = din0_ptr + C;
const float* din2_ptr = din1_ptr + C;
const float* din3_ptr = din2_ptr + C;
float* out0_ptr = dout + s * size;
float* out1_ptr = out0_ptr + size;
float* out2_ptr = out1_ptr + size;
float* out3_ptr = out2_ptr + size;
int cnt_num = cnt;
if (cnt_num > 0) {
#ifdef __aarch64__
asm volatile(TRANS_C4
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[cnt] "+r"(cnt_num),
[stride] "+r"(stride)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11");
#else
asm volatile(TRANS_C4
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[cnt] "+r"(cnt_num),
[stride] "+r"(stride)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif
}
for (int i = 0; i < remain; i++) {
const float* ptr = din0_ptr;
*out0_ptr++ = *ptr++;
*out1_ptr++ = *ptr++;
*out2_ptr++ = *ptr++;
*out3_ptr++ = *ptr++;
din0_ptr += C;
}
}
// remain channels: recompute the start, since the value of s after an
// OpenMP parallel loop is unspecified
for (s = (C >> 2) << 2; s < C; s++) {
const float* din0_ptr = din + s;
const float* din1_ptr = din0_ptr + C;
const float* din2_ptr = din1_ptr + C;
const float* din3_ptr = din2_ptr + C;
float* out0_ptr = dout + s * size;
for (int i = 0; i < cnt; i++) {
*out0_ptr++ = *din0_ptr;
*out0_ptr++ = *din1_ptr;
*out0_ptr++ = *din2_ptr;
*out0_ptr++ = *din3_ptr;
din0_ptr += stride_w;
din1_ptr += stride_w;
din2_ptr += stride_w;
din3_ptr += stride_w;
}
for (int i = 0; i < remain; i++) {
*out0_ptr++ = *din0_ptr;
din0_ptr += C;
}
}
}
}
template <>
void NHWC2NCHW<int8_t>(int N, int C, int size, const int8_t* X, int8_t* Y) {
int cnt = size >> 3;
int remain = size % 8;
int sum = C * size;
int stride = C << 3;    // 8 * C: element step over 8 rows
int stride_w = C << 2;  // 4 * C bytes: int8 elements are 1 byte wide
for (int n = 0; n < N; n++) {
const int8_t* din = X + n * sum;
int8_t* dout = Y + n * sum;
int s = 0;
#pragma omp parallel for
for (s = 0; s < C - 7; s += 8) {
const int8_t* din0_ptr = din + s;
const int8_t* din1_ptr = din0_ptr + C;
const int8_t* din2_ptr = din1_ptr + C;
const int8_t* din3_ptr = din2_ptr + C;
const int8_t* din4_ptr = din3_ptr + C;
const int8_t* din5_ptr = din4_ptr + C;
const int8_t* din6_ptr = din5_ptr + C;
const int8_t* din7_ptr = din6_ptr + C;
int8_t* out0_ptr = dout + s * size;
int8_t* out1_ptr = out0_ptr + size;
int8_t* out2_ptr = out1_ptr + size;
int8_t* out3_ptr = out2_ptr + size;
int8_t* out4_ptr = out3_ptr + size;
int8_t* out5_ptr = out4_ptr + size;
int8_t* out6_ptr = out5_ptr + size;
int8_t* out7_ptr = out6_ptr + size;
int cnt_num = cnt;
if (cnt_num > 0) {
#ifdef __aarch64__
asm volatile(TRANS_C8
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[out4_ptr] "+r"(out4_ptr),
[out5_ptr] "+r"(out5_ptr),
[out6_ptr] "+r"(out6_ptr),
[out7_ptr] "+r"(out7_ptr),
[cnt] "+r"(cnt_num),
[stride_w] "+r"(stride_w)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(TRANS_C8
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[din3_ptr] "+r"(din3_ptr),
[out0_ptr] "+r"(out0_ptr),
[out1_ptr] "+r"(out1_ptr),
[out2_ptr] "+r"(out2_ptr),
[out3_ptr] "+r"(out3_ptr),
[out4_ptr] "+r"(out4_ptr),
[out5_ptr] "+r"(out5_ptr),
[out6_ptr] "+r"(out6_ptr),
[out7_ptr] "+r"(out7_ptr),
[cnt] "+r"(cnt_num),
[stride_w] "+r"(stride_w)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif
}
for (int i = 0; i < remain; i++) {
const int8_t* ptr = din0_ptr;
*out0_ptr++ = *ptr++;
*out1_ptr++ = *ptr++;
*out2_ptr++ = *ptr++;
*out3_ptr++ = *ptr++;
*out4_ptr++ = *ptr++;
*out5_ptr++ = *ptr++;
*out6_ptr++ = *ptr++;
*out7_ptr++ = *ptr++;
din0_ptr += C;
}
}
// remain channels: recompute the start, since the value of s after an
// OpenMP parallel loop is unspecified
for (s = (C >> 3) << 3; s < C; s++) {
const int8_t* din0_ptr = din + s;
const int8_t* din1_ptr = din0_ptr + C;
const int8_t* din2_ptr = din1_ptr + C;
const int8_t* din3_ptr = din2_ptr + C;
const int8_t* din4_ptr = din3_ptr + C;
const int8_t* din5_ptr = din4_ptr + C;
const int8_t* din6_ptr = din5_ptr + C;
const int8_t* din7_ptr = din6_ptr + C;
int8_t* out0_ptr = dout + s * size;
for (int i = 0; i < cnt; i++) {
*out0_ptr++ = *din0_ptr;
*out0_ptr++ = *din1_ptr;
*out0_ptr++ = *din2_ptr;
*out0_ptr++ = *din3_ptr;
*out0_ptr++ = *din4_ptr;
*out0_ptr++ = *din5_ptr;
*out0_ptr++ = *din6_ptr;
*out0_ptr++ = *din7_ptr;
din0_ptr += stride;
din1_ptr += stride;
din2_ptr += stride;
din3_ptr += stride;
din4_ptr += stride;
din5_ptr += stride;
din6_ptr += stride;
din7_ptr += stride;
}
for (int i = 0; i < remain; i++) {
*out0_ptr++ = *din0_ptr;
din0_ptr += C;
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y);
template <typename T>
void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
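// Illustration (not part of this patch): minimal usage of the two templates
// declared above on caller-owned buffers; the shapes and the function name
// below are this sketch's own.
#include <cstddef>
#include <vector>
void example_nchw_to_nhwc() {
  const int N = 1, C = 3, H = 2, W = 2;
  std::vector<float> x(N * C * H * W), y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);
  // Afterwards y[((n * H + h) * W + w) * C + c] == x[((n * C + c) * H + h) * W + w].
  paddle::lite::arm::math::NCHW2NHWC<float>(N, C, H * W, x.data(), y.data());
}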
......@@ -145,6 +145,12 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kOpenCL),
PRECISION(kFloat),
......
......@@ -46,6 +46,7 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li
add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layout_compute_arm ARM extra SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_mean_compute_arm ARM extra SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -101,7 +102,6 @@ lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS tran
lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm)
if(LITE_BUILD_EXTRA)
lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layout_compute.h"
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
#define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
CHECK(input_dim.size() == 4)                                        \
    << "NCHW to NHWC conversion requires a 4-D input";              \
int n = input_dim[0]; \
int c = input_dim[1]; \
int h = input_dim[2]; \
int w = input_dim[3]; \
param.y->Resize({n, h, w, c}); \
auto output = param.y->template mutable_data<type>(TARGET(kARM)); \
if (c == 1) { \
memcpy(output, input, sizeof(type) * n * h * w); \
return; \
} \
lite::arm::math::NCHW2NHWC<type>(n, c, h * w, input, output);
#define NHWCTONCHW(type) \
auto& param = this->template Param<param_t>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
CHECK(input_dim.size() == 4)                                        \
    << "NHWC to NCHW conversion requires a 4-D input";              \
int n = input_dim[0]; \
int h = input_dim[1]; \
int w = input_dim[2]; \
int c = input_dim[3]; \
param.y->Resize({n, c, h, w}); \
auto output = param.y->template mutable_data<type>(TARGET(kARM)); \
if (c == 1) { \
memcpy(output, input, sizeof(type) * n * h * w); \
return; \
} \
lite::arm::math::NHWC2NCHW<type>(n, c, h * w, input, output);
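// Note (illustrative, not part of the kernel): both macros short-circuit to a
// memcpy when c == 1, because a single-channel tensor linearizes identically
// in either layout:
//   NCHW offset: ((n * 1 + 0) * h + i) * w + j == (n * h + i) * w + j
//   NHWC offset: ((n * h + i) * w + j) * 1 + 0 == (n * h + i) * w + j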
template <>
void NCHWToNHWCCompute<PRECISION(kFloat)>::Run() {
NCHWTONHWC(float);
}
template <>
void NCHWToNHWCCompute<PRECISION(kInt8)>::Run() {
NCHWTONHWC(int8_t);
}
template <>
void NHWCToNCHWCompute<PRECISION(kFloat)>::Run() {
NHWCTONCHW(float);
}
template <>
void NHWCToNCHWCompute<PRECISION(kInt8)>::Run() {
NHWCTONCHW(int8_t);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
typedef paddle::lite::kernels::arm::NCHWToNHWCCompute<PRECISION(kFloat)>
NCHW_fp32;
typedef paddle::lite::kernels::arm::NCHWToNHWCCompute<PRECISION(kInt8)>
NCHW_int8;
typedef paddle::lite::kernels::arm::NHWCToNCHWCompute<PRECISION(kFloat)>
NHWC_fp32;
typedef paddle::lite::kernels::arm::NHWCToNCHWCompute<PRECISION(kInt8)>
NHWC_int8;
REGISTER_LITE_KERNEL(layout, kARM, kFloat, kNCHW, NCHW_fp32, nchw2nhwc)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout, kARM, kFloat, kNCHW, NHWC_fp32, nhwc2nchw)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout, kARM, kInt8, kNCHW, NCHW_int8, int8_nchw2nhwc)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout, kARM, kInt8, kNCHW, NHWC_int8, int8_nhwc2nchw)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout_once, kARM, kFloat, kNCHW, NCHW_fp32, nchw2nhwc)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout_once, kARM, kFloat, kNCHW, NHWC_fp32, nhwc2nchw)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(layout_once, kARM, kInt8, kNCHW, NCHW_int8, int8_nchw2nhwc)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(layout_once, kARM, kInt8, kNCHW, NHWC_int8, int8_nhwc2nchw)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <PrecisionType Ptype>
class NCHWToNHWCCompute : public KernelLite<TARGET(kARM), Ptype> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NCHWToNHWCCompute() = default;
};
template <PrecisionType Ptype>
class NHWCToNCHWCompute : public KernelLite<TARGET(kARM), Ptype> {
public:
using param_t = operators::LayoutParam;
void Run() override;
virtual ~NHWCToNCHWCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -8,4 +8,10 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(conv_transpose_compute_test SRCS conv_transpose_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(pool_compute_test SRCS pool_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(layout_compute_test SRCS layout_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/core/context.h"
#include "lite/operators/op_params.h"
#include "lite/tests/utils/naive_math_impl.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
#ifdef LITE_WITH_ARM
#include "lite/kernels/arm/layout_compute.h"
#endif // LITE_WITH_ARM
DEFINE_int32(power_mode,
3,
"power mode: "
"0 for POWER_HIGH;"
"1 for POWER_LOW;"
"2 for POWER_FULL;"
"3 for NO_BIND");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
DEFINE_int32(in_channel, 32, "input channel");
DEFINE_int32(in_height, 112, "input height");
DEFINE_int32(in_width, 112, "input width");
DEFINE_bool(flag_nchw, true, "do nchw to nhwc");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::LayoutParam LayoutParam;
using paddle::lite::Timer;
#define IN(n, c, h, w) \
input_data[w + h * input_w + c * input_h * input_w + \
n * input_c * input_h * input_w]
#define OUT(n, c, h, w) \
output_data[w + h * output_w + c * output_h * output_w + \
n * output_c * output_h * output_w]
template <typename Dtype>
void nchw2nhwc_ref(const Tensor* input, Tensor* output) {
auto* input_data = input->data<Dtype>();
auto* output_data = output->mutable_data<Dtype>();
int input_n = input->dims()[0];
int input_c = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
OUT(n, h, w, c) = IN(n, c, h, w);
}
}
}
}
}
#undef IN
#undef OUT
#define IN(n, h, w, c) \
input_data[c + w * input_c + h * input_w * input_c + \
n * input_h * input_w * input_c]
#define OUT(n, h, w, c) \
output_data[c + w * output_c + h * output_w * output_c + \
n * output_h * output_w * output_c]
template <typename Dtype>
void nhwc2nchw_ref(const Tensor* input, Tensor* output) {
auto* input_data = input->data<Dtype>();
auto* output_data = output->mutable_data<Dtype>();
int input_n = input->dims()[0];
int input_h = input->dims()[1];
int input_w = input->dims()[2];
int input_c = input->dims()[3];
int output_h = output->dims()[1];
int output_w = output->dims()[2];
int output_c = output->dims()[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
OUT(n, c, h, w) = IN(n, h, w, c);
}
}
}
}
}
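// Illustration (not part of this patch): the two references above pin down
// the flat-index mapping the ARM kernels must reproduce; written out directly
// with hypothetical helper names:
static inline int nchw_offset(int n, int c, int h, int w, int C, int H, int W) {
  return ((n * C + c) * H + h) * W + w;  // channels outermost after batch
}
static inline int nhwc_offset(int n, int c, int h, int w, int C, int H, int W) {
  return ((n * H + h) * W + w) * C + c;  // channels innermost
}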
#ifdef LITE_WITH_ARM
void test_layout_fp32_nchw(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LayoutParam param;
param.x = new Tensor;
const_cast<Tensor*>(param.x)->set_precision(PRECISION(kFloat));
param.y = new Tensor;
param.y->set_precision(PRECISION(kFloat));
for (auto& cls : power_mode) {
for (auto& th : thread_num) {
paddle::lite::kernels::arm::NCHWToNHWCCompute<PRECISION(kFloat)> layout;
DDim dim_out({dim_in[0], dim_in[2], dim_in[3], dim_in[1]});
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
const_cast<Tensor*>(param.x)->Resize(dim_in);
param.y->Resize(dim_out);
layout.SetParam(param);
paddle::lite::fill_tensor_rand(
*(const_cast<Tensor*>(param.x)), -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>();
Tensor tout_basic;
if (FLAGS_check_result) {
tout_basic.set_precision(PRECISION(kFloat));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<float>();
nchw2nhwc_ref<float>(param.x, &tout_basic);
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
layout.Run();
}
/// compute
Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
layout.Run();
t0.end();
}
double gops = 2.0 * dim_out.production();
LOG(INFO) << "layout fp32: input shape: " << dim_in << ", output shape"
<< dim_out << ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "din";
print_tensor(*(const_cast<Tensor*>(param.x)));
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "lite result";
print_tensor(*param.y);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic, *param.y, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test fp32 layout: input: " << dim_in
<< ", output: " << dim_out << ", flag_nchw: "
<< (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
}
}
LOG(INFO) << "test fp32 layout: input: " << dim_in
<< ", output: " << dim_out
<< ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
}
}
}
delete param.x;
delete param.y;
}
void test_layout_fp32_nhwc(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LayoutParam param;
param.x = new Tensor;
const_cast<Tensor*>(param.x)->set_precision(PRECISION(kFloat));
param.y = new Tensor;
param.y->set_precision(PRECISION(kFloat));
for (auto& cls : power_mode) {
for (auto& th : thread_num) {
paddle::lite::kernels::arm::NHWCToNCHWCompute<PRECISION(kFloat)> layout;
// n h w c == n c h w
DDim dim_out({dim_in[0], dim_in[3], dim_in[1], dim_in[2]});
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
const_cast<Tensor*>(param.x)->Resize(dim_in);
param.y->Resize(dim_out);
layout.SetParam(param);
paddle::lite::fill_tensor_rand(
*(const_cast<Tensor*>(param.x)), -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>();
Tensor tout_basic;
if (FLAGS_check_result) {
tout_basic.set_precision(PRECISION(kFloat));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<float>();
nhwc2nchw_ref<float>(param.x, &tout_basic);
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
layout.Run();
}
/// compute
Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
layout.Run();
t0.end();
}
double gops = 2.0 * dim_out.production();
LOG(INFO) << "layout fp32: input shape: " << dim_in << ", output shape"
<< dim_out << ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "din";
print_tensor(*(const_cast<Tensor*>(param.x)));
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "lite result";
print_tensor(*param.y);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kFloat));
tensor_diff(tout_basic, *param.y, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test fp32 layout: input: " << dim_in
<< ", output: " << dim_out << ", flag_nchw: "
<< (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
}
}
LOG(INFO) << "test fp32 layout: input: " << dim_in
<< ", output: " << dim_out
<< ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
}
}
}
delete param.x;
delete param.y;
}
void test_layout_int8_nchw(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LayoutParam param;
param.x = new Tensor;
const_cast<Tensor*>(param.x)->set_precision(PRECISION(kInt8));
param.y = new Tensor;
param.y->set_precision(PRECISION(kInt8));
for (auto& cls : power_mode) {
for (auto& th : thread_num) {
paddle::lite::kernels::arm::NCHWToNHWCCompute<PRECISION(kInt8)> layout;
DDim dim_out({dim_in[0], dim_in[2], dim_in[3], dim_in[1]});
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
const_cast<Tensor*>(param.x)->Resize(dim_in);
param.y->Resize(dim_out);
layout.SetParam(param);
paddle::lite::fill_tensor_rand(*(const_cast<Tensor*>(param.x)));
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<int8_t>();
Tensor tout_basic;
if (FLAGS_check_result) {
tout_basic.set_precision(PRECISION(kInt8));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0);
auto dout_basic = tout_basic.mutable_data<int8_t>();
nchw2nhwc_ref<int8_t>(param.x, &tout_basic);
}
LOG(INFO) << "saber compute";
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
layout.Run();
}
/// compute
Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
layout.Run();
t0.end();
}
LOG(INFO) << "saber compute end";
double gops = 2.0 * dim_out.production();
LOG(INFO) << "layout int8: input shape: " << dim_in << ", output shape"
<< dim_out << ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "din";
print_tensor(*(const_cast<Tensor*>(param.x)));
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "lite result";
print_tensor(*param.y);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kInt8));
tensor_diff(tout_basic, *param.y, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test int8 layout: input: " << dim_in
<< ", output: " << dim_out << ", flag_nchw: "
<< (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
}
}
LOG(INFO) << "test int8 layout: input: " << dim_in
<< ", output: " << dim_out
<< ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
}
}
}
delete param.x;
delete param.y;
}
void test_layout_int8_nhwc(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LayoutParam param;
param.x = new Tensor;
const_cast<Tensor*>(param.x)->set_precision(PRECISION(kInt8));
param.y = new Tensor;
param.y->set_precision(PRECISION(kInt8));
for (auto& cls : power_mode) {
for (auto& th : thread_num) {
paddle::lite::kernels::arm::NHWCToNCHWCompute<PRECISION(kInt8)> layout;
// n h w c == n c h w
DDim dim_out({dim_in[0], dim_in[3], dim_in[1], dim_in[2]});
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
const_cast<Tensor*>(param.x)->Resize(dim_in);
param.y->Resize(dim_out);
layout.SetParam(param);
paddle::lite::fill_tensor_rand(*(const_cast<Tensor*>(param.x)));
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<int8_t>();
Tensor tout_basic;
if (FLAGS_check_result) {
tout_basic.set_precision(PRECISION(kInt8));
tout_basic.Resize(dim_out);
fill_tensor_const(tout_basic, 0.f);
auto dout_basic = tout_basic.mutable_data<int8_t>();
nhwc2nchw_ref<int8_t>(param.x, &tout_basic);
}
LOG(INFO) << "saber compute";
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
layout.Run();
}
/// compute
Timer t0;
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
layout.Run();
t0.end();
}
LOG(INFO) << "run";
double gops = 2.0 * dim_out.production();
LOG(INFO) << "layout int8: input shape: " << dim_in << ", output shape"
<< dim_out << ",running time, avg: " << t0.get_average_ms()
<< ", min time: " << t0.get_min_time()
<< ", total GOPS: " << 1e-9 * gops
<< " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
<< " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tout_basic, *param.y, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-3f) {
if (max_diff > 5e-4f) {
LOG(WARNING) << "din";
print_tensor(*(const_cast<Tensor*>(param.x)));
LOG(WARNING) << "basic result";
print_tensor(tout_basic);
LOG(WARNING) << "lite result";
print_tensor(*param.y);
Tensor tdiff;
tdiff.Resize(tout_basic.dims());
tdiff.set_precision(PRECISION(kInt8));
tensor_diff(tout_basic, *param.y, tdiff);
print_tensor(tdiff);
LOG(FATAL) << "test int8 layout: input: " << dim_in
<< ", output: " << dim_out << ", flag_nchw: "
<< (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
}
}
LOG(INFO) << "test int8 layout: input: " << dim_in
<< ", output: " << dim_out
<< ", flag_nchw: " << (flag_nchw ? "nchw2nhwc" : "nhwc2nchw")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
}
}
}
delete param.x;
delete param.y;
}
#else
void test_layout_fp32_nchw(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {}
void test_layout_fp32_nhwc(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {}
void test_layout_int8_nchw(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {}
void test_layout_int8_nhwc(DDim dim_in,
bool flag_nchw,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {}
#endif // LITE_WITH_ARM
#if 1 //
TEST(TestLayout, test_Layout_fp32) {
if (FLAGS_basic_test) {
for (auto n : {1, 3}) {
for (auto c : {1, 3, 5, 32}) {
for (auto h : {3, 16, 20, 32}) {
for (auto w : {3, 4, 32, 112}) {
for (auto nchw2nhwc : {true, false}) {
DDim dim_in({n, c, h, w});
if (nchw2nhwc) {
LOG(INFO) << "NCHW2NHWC";
test_layout_fp32_nchw(
dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode});
} else {
LOG(INFO) << "NHWC2NCHW";
test_layout_fp32_nhwc(
dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode});
}
}
}
}
}
}
}
}
#endif
#if 1
TEST(TestLayout, test_Layout_int8) {
if (FLAGS_basic_test) {
for (auto n : {1, 3}) {
for (auto c : {1, 3, 5, 32}) {
for (auto h : {3, 16, 20, 32}) {
for (auto w : {3, 4, 32, 112}) {
for (auto nchw2nhwc : {true, false}) {
DDim dim_in({n, c, h, w});
if (nchw2nhwc) {
LOG(INFO) << "NCHW2NHWC int8";
test_layout_int8_nchw(
dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode});
} else {
LOG(INFO) << "NHWC2NCHW int8";
test_layout_int8_nhwc(
dim_in, nchw2nhwc, {1, 2, 4}, {FLAGS_power_mode});
}
}
}
}
}
}
}
}
#endif
#if 1 /// custom
TEST(TestLayoutCustom, test_Layout_custom_size) {
test_layout_fp32_nchw(
{DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})},
true,
{FLAGS_threads},
{FLAGS_power_mode});
}
#endif // custom