提交 e659e4ab 编写于 作者: H HappyAngel 提交者: GitHub

[lite][arm]add conv relu6 and leaky_relu in conv_dw_3x3s2, test=develop (#2618)

* fix conv 2-pad to 4-pad

* fix compute conv shape

* fix pad, test=develop

* change conv_depthwise_3x3s1_fp.cc name to conv3x3s1p01_depthwise_fp32.cc to distinguish between conv3x3s1_depthwise_fp32.cc

* delete printf note in conv3x3s1, test=develop

* delete printf note, test=develop

* delete gem_sdot.h, test=develop

it is coped from __gemm_sdot_meta_.h

* update compute padding, test=develop

* fix padding size, must be 2 or 4. test=develop

* fix format in operators/conv_op.cc, test=develop

* change #if 0 to #if 1, test=develop

* put 2-pad to 4-pad in AttachImpl, test=develop

* fix clang-format error inn tests/math/connv_compute_test, test=develop

* fix x86 test result error, test=develop

* add asymmetric padding test case in liite/tests/math/conv_compute.cc, test=develop

* change paddings type to support dynamically modify, test=develop

* fix x86 build error in connv_compute_test, test=develop

* fix opencl build error, test=develop

* fix oopencl build error, test=develop

* fix  opencl/conv_compute build error, test=develop

* fix  opencl/conv_compute build error, test=develop

* fix format in kernels/opencl/conv_computte_ttest,test=develop

* fix build error, test=develop

fix build error in kernels/x86/conv_compute.h

* fix ccompute shape error in ooperators/conv_op.h, test=develop

* add conv_reelu6 and conv leaky_relu in conv_3x3s1_direct

* add conv_relu6 in c1, c2, c4,test=develop

* fix conflict in conv_bloock_utils.h, test=develop

* add relu6 and leankyrelu in conv_3x3s1_dw

* add conv_3x3s1px_dw relu6 and leaky_relu fusion, test=develop

* fix conflict in tests/math/conv_compute_arm, test=develop

* fix build error in winograd arm, test=develop

* channge act_param as pointer in conv_block_tuils.h, test=develop

* fix winograd in no equal 4-padding compute error, test=develop

* add conv relu6 and leaky_relu in conv_dw_3x3s2, test=develop

* fix format, test=develop

* fix format in conv_block_utils, test=develop

* move updatePadding from conv_op.cc to conv_op.h, test=develop

* fix format conv_op.h, test=develop

* fix buuilde error in conv_oop.h, test=develop

* remove flag_relu parameter in conv_3x3_depthwise, test=develop
上级 04ab34b6
......@@ -25,7 +25,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -40,7 +39,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -55,7 +53,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -70,7 +67,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -93,7 +89,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
const float *bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext *ctx) {
if (pad == 0) {
......@@ -103,7 +98,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
......@@ -118,7 +112,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
......@@ -136,7 +129,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
......@@ -151,7 +143,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
......@@ -163,7 +154,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
}
}
}
// clang-format on
#ifdef __aarch64__
#define INIT_S1 \
"PRFM PLDL1KEEP, [%[din_ptr0]] \n" \
......@@ -2318,7 +2309,6 @@ void act_switch_3x3s1p1(const float *din_ptr0,
}
}
#endif
// clang-format on
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
......@@ -2328,7 +2318,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -2857,7 +2846,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -3443,7 +3431,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -3579,129 +3566,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
}
int cnt = tile_w;
/*
if (flag_relu) {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0)
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0)
"ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234
"ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345
"ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0)
"ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0)
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0)
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0)
"ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234
"ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345
"ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0)
"ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0)
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
*/
act_switch_3x3s1p0(din_ptr0,
din_ptr1,
din_ptr2,
......@@ -3760,90 +3624,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
/*
if (flag_relu) {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}*/
act_switch_3x3s1p0(din_ptr0,
din_ptr1,
din_ptr2,
......@@ -4174,7 +3954,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
......@@ -4213,14 +3992,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
// #ifdef __aarch64__
// float32x4_t wbias;
// if (flag_bias) {
// wbias = vdupq_n_f32(bias[i]);
// } else {
// wbias = vdupq_n_f32(0.f);
// }
// #endif // __aarch64__
float32x4_t wbias;
float bias_val = 0.f;
if (flag_bias) {
......@@ -4261,137 +4032,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
break;
}
}
/*
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
*/
unsigned int *vmask_ptr = vmask;
act_switch_3x3s1p0_s(dr0,
dr1,
......
......@@ -836,7 +836,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
/// get workspace
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle {
......@@ -24,13 +25,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s(float* dout,
......@@ -38,13 +39,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias(float* dout,
......@@ -52,13 +53,13 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s(float* dout,
......@@ -66,13 +67,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2_fp32(const float* din,
......@@ -88,7 +89,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx) {
if (pad == 0) {
if (w_in > 7) {
......@@ -97,13 +98,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s(dout,
......@@ -111,13 +112,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
}
}
......@@ -128,13 +129,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s(dout,
......@@ -142,13 +143,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
}
}
......@@ -412,6 +413,83 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define LEFT_RESULT_S2_RELU6 \
"fmax v16.4s, v16.4s, %[vzero].4s \n" \
"ld1 {v22.4s}, [%[six_ptr]] \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
"fmin v16.4s, v16.4s, v22.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"fmax v17.4s, v17.4s, %[vzero].4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"fmin v17.4s, v17.4s, v22.4s \n" \
\
"cmp %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define LEFT_RESULT_S2_LEAKY_RELU \
"ld1 {v22.4s}, [%[scale_ptr]] \n" \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"cmp %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
......@@ -438,6 +516,58 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"bne 2b \n"
#define MID_RESULT_S2_RELU6 \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"fmin v16.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"fmin v17.4s, v17.4s, v22.4s \n" \
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define MID_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
......@@ -456,6 +586,47 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define RIGHT_RESULT_S2_RELU6 \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"fmin v16.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"fmin v17.4s, v17.4s, v22.4s \n" \
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define RIGHT_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define COMPUTE_S_S2 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
......@@ -500,7 +671,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"fmax v4.4s, v4.4s, v9.4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
......@@ -537,7 +707,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"fadd v4.4s, v4.4s, v16.4s \n"
#define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"fmax v4.4s, v4.4s, v9.4s \n" \
"st1 {v4.4s}, [%[out]] \n"
......@@ -682,7 +851,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"cmp %[cnt], #1 \n" \
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"subs %[cnt], #1 \n" \
......@@ -739,7 +907,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu\n" \
\
......@@ -787,155 +954,38 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#endif
/**
* \brief depthwise convolution kernel 3x3, stride 2
* w_in > 7
*/
void conv_depthwise_3x3s2p1_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4); //
int size_right_pad = w_out * 2 - w_in;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
if (flag_relu) {
void act_switch_3x3s2p1(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
......@@ -980,6 +1030,110 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
......@@ -1025,6 +1179,167 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
"v20",
"v21");
}
}
#endif
/**
* \brief depthwise convolution kernel 3x3, stride 2
* w_in > 7
*/
void conv_depthwise_3x3s2p1_bias(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
act_switch_3x3s2p1(din0_ptr,
din1_ptr,
din2_ptr,
din3_ptr,
din4_ptr,
doutr0_ptr,
doutr1_ptr,
wr0,
wr1,
wr2,
vmask_rp1,
vmask_rp2,
wmask,
wbias,
vzero,
cnt,
cnt_remain,
act_param);
doutr0 = doutr0 + 2 * w_out;
}
#else
......@@ -1061,37 +1376,6 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
if (flag_relu) {
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
......@@ -1120,6 +1404,9 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
"q13",
"q14",
"q15");
// do act
if (act_param.has_active) {
act_switch_process(doutr0, doutr0, w_out, &act_param);
}
doutr0 = doutr0 + w_out;
}
......@@ -1136,13 +1423,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
......@@ -1198,30 +1485,6 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
......@@ -1244,10 +1507,8 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
"v13",
"v14",
"v15");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
......@@ -1272,44 +1533,284 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
"q13",
"q14",
"q15");
#endif
// do act
if (act_param.has_active) {
act_switch_process(out_buf, out_buf, w_out, &act_param);
}
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
hs += 2;
he += 2;
}
}
}
}
#ifdef __aarch64__
void act_switch_3x3s2p0(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 "4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
hs += 2;
he += 2;
}
}
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
}
#endif
/**
* \brief depthwise convolution kernel 3x3, stride 2
*/
......@@ -1319,13 +1820,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
......@@ -1438,117 +1939,24 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
if (flag_relu) {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
act_switch_3x3s2p0(din0_ptr,
din1_ptr,
din2_ptr,
din3_ptr,
din4_ptr,
doutr0_ptr,
doutr1_ptr,
wr0,
wr1,
wr2,
vmask_rp1,
vmask_rp2,
wmask,
wbias,
vzero,
cnt,
cnt_remain,
act_param);
doutr0 = doutr0 + 2 * w_out;
}
#else
......@@ -1576,36 +1984,6 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
if (flag_relu) {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
......@@ -1634,6 +2012,8 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
"q13",
"q14",
"q15");
if (act_param.has_active) {
act_switch_process(doutr0, doutr0, w_out, &act_param);
}
doutr0 = doutr0 + w_out;
}
......@@ -1650,13 +2030,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
......@@ -1718,33 +2098,6 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
......@@ -1770,35 +2123,8 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
"v14",
"v15",
"v16");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
......@@ -1824,8 +2150,10 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
"q13",
"q14",
"q15");
}
#endif
if (act_param.has_active) {
act_switch_process(out_buf, out_buf, w_out, &act_param);
}
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
......
......@@ -25,6 +25,511 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
#define COMPUTE \
"ldr q8, [%[bias]]\n" /* load bias */ \
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \
"and v19.16b, v8.16b, v8.16b\n" \
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \
"and v20.16b, v8.16b, v8.16b\n" \
"ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/ \
"and v21.16b, v8.16b, v8.16b\n" \
"ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/ \
"and v22.16b, v8.16b, v8.16b\n" \
"ldr q8, [%[inr0]]\n" /* load input r0*/ \
"fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \
"fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \
"fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \
"fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \
"fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \
"ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \
"fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \
"fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \
"fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \
"ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \
"ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \
"ldr q8, [%[inr1]]\n" /* load input r1*/ \
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/ \
"fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/ \
"fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/ \
"fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/ \
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/ \
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/ \
"fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/ \
"fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/ \
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/ \
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/ \
"ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/ \
"ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/ \
"ldr q8, [%[inr2]]\n" /* load input r2*/ \
"fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/ \
"fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/ \
"fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/ \
"fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/ \
"fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/ \
"fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/ \
"fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/ \
"fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/ \
"fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/ \
"fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/ \
"fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/ \
"fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/ \
"trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \
"trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \
"trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \
"trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \
"trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \
"trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \
"trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \
"trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
#define RELU /* relu */ \
"movi v0.4s, #0\n" /* for relu */ \
"fmax v19.4s, v19.4s, v0.4s\n" \
"fmax v20.4s, v20.4s, v0.4s\n" \
"fmax v21.4s, v21.4s, v0.4s\n" \
"fmax v22.4s, v22.4s, v0.4s\n"
#define RELU6 /* relu6 */ \
"fmin v19.4s, v19.4s, %[vsix].4s\n" \
"fmin v20.4s, v20.4s, %[vsix].4s\n" \
"fmin v21.4s, v21.4s, %[vsix].4s\n" \
"fmin v22.4s, v22.4s, %[vsix].4s\n"
#define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v19.16b, v8.16b, v7.16b \n" /* choose*/
#define STORE /* save result */ \
"str q19, [%[outc0]], #16\n" \
"str q20, [%[outc1]], #16\n" \
"str q21, [%[outc2]], #16\n" \
"str q22, [%[outc3]], #16\n"
#else
#define COMPUTE \
/* fill with bias */ \
"vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ /* load weights */ \
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */ \
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \
"vand.i32 q12, q8, q8\n" \
"vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \
"vand.i32 q13, q8, q8\n" \
"vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \
"vand.i32 q14, q8, q8\n" \
"vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/ \
"vand.i32 q15, q8, q8\n" \
"vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/ \
"vmla.f32 q12, q9, q0 @ w0 * inr0\n" \
"vmla.f32 q13, q9, q2 @ w0 * inr2\n" \
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */ \
"vmla.f32 q14, q9, q4 @ w0 * inr4\n" \
"vmla.f32 q15, q9, q6 @ w0 * inr6\n" \
"vmla.f32 q12, q10, q1 @ w1 * inr1\n" \
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \
"vmla.f32 q13, q10, q3 @ w1 * inr3\n" \
"vmla.f32 q14, q10, q5 @ w1 * inr5\n" \
"vmla.f32 q15, q10, q7 @ w1 * inr7\n" \
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */ \
"vmla.f32 q12, q11, q2 @ w2 * inr2\n" \
"vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n" \
"vmla.f32 q13, q11, q4 @ w2 * inr4\n" \
"vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n" \
"vmla.f32 q14, q11, q6 @ w2 * inr6\n" \
"vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n" \
"vmla.f32 q15, q11, q8 @ w2 * inr8\n" /* mul r1 with w3, w4*/ \
"vmla.f32 q12, q9, q0 @ w3 * inr0\n" \
"vmla.f32 q13, q9, q2 @ w3 * inr2\n" \
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */ \
"vmla.f32 q14, q9, q4 @ w3 * inr4\n" \
"vmla.f32 q15, q9, q6 @ w3 * inr6\n" \
"vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/ \
"vmla.f32 q12, q10, q1 @ w4 * inr1\n" \
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \
"vmla.f32 q13, q10, q3 @ w4 * inr3\n" \
"vmla.f32 q14, q10, q5 @ w4 * inr5\n" \
"vmla.f32 q15, q10, q7 @ w4 * inr7\n" \
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ \
"vmla.f32 q12, q11, q2 @ w5 * inr2\n" \
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \
"vmla.f32 q13, q11, q4 @ w5 * inr4\n" \
"vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \
"vmla.f32 q14, q11, q6 @ w5 * inr6\n" \
"vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n" \
"vmla.f32 q15, q11, q8 @ w5 * inr8\n" /* mul r2 with w6, w7*/ \
"vmla.f32 q12, q9, q0 @ w6 * inr0\n" \
"vmla.f32 q13, q9, q2 @ w6 * inr2\n" \
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */ \
"vmla.f32 q14, q9, q4 @ w6 * inr4\n" \
"vmla.f32 q15, q9, q6 @ w6 * inr6\n" \
"vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/ \
"vmla.f32 q12, q10, q1 @ w7 * inr1\n" \
"vmla.f32 q13, q10, q3 @ w7 * inr3\n" \
"vmla.f32 q14, q10, q5 @ w7 * inr5\n" \
"vmla.f32 q15, q10, q7 @ w7 * inr7\n" \
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \
"vmla.f32 q12, q11, q2 @ w8 * inr2\n" \
"vmla.f32 q13, q11, q4 @ w8 * inr4\n" \
"vmla.f32 q14, q11, q6 @ w8 * inr6\n" \
"vmla.f32 q15, q11, q8 @ w8 * inr8\n" /* transpose */ \
"vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \
"vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \
"vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \
"vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/
#define RELU /* relu */ \
"vmov.u32 q0, #0\n" \
"vld1.32 {d2-d3}, [%[six_ptr]]\n" \
"vmax.f32 q12, q12, q0\n" \
"vmax.f32 q13, q13, q0\n" \
"vmax.f32 q14, q14, q0\n" \
"vmax.f32 q15, q15, q0\n"
#define RELU6 /* relu6 */ \
"vmin.f32 q12, q12, q1\n" \
"vmin.f32 q13, q13, q1\n" \
"vmin.f32 q14, q14, q1\n" \
"vmin.f32 q15, q15, q1\n"
#define LEAKY_RELU /* LeakyRelu */ \
"vmov.u32 q0, #0\n" \
"vld1.32 {d2-d3}, [%[scale_ptr]]\n" \
"vcge.f32 q2, q12, q0 @ q0 > 0 \n" \
"vcge.f32 q4, q13, q0 @ q0 > 0 \n" \
"vcge.f32 q6, q14, q0 @ q0 > 0 \n" \
"vcge.f32 q8, q15, q0 @ q0 > 0 \n" \
"vmul.f32 q3, q12, q1 @ mul \n" \
"vmul.f32 q5, q13, q1 @ mul \n" \
"vmul.f32 q7, q14, q1 @ mul \n" \
"vmul.f32 q9, q15, q1 @ mul \n" \
"vbif q12, q3, q2 @ choose \n" \
"vbif q13, q5, q4 @ choose \n" \
"vbif q14, q7, q6 @ choose \n" \
"vbif q15, q9, q8 @ choose \n"
#define STORE /* save result */ \
"vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \
"vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \
"vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \
"vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/
#endif
void act_switch_3x3s2(const float* inr0,
const float* inr1,
const float* inr2,
float* outc0,
float* outc1,
float* outc2,
float* outc3,
const float* weight_c,
float* bias_local,
float32x4_t w0,
float32x4_t w1,
float32x4_t w2,
float32x4_t w3,
float32x4_t w4,
float32x4_t w5,
float32x4_t w6,
float32x4_t w7,
float32x4_t w8,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(tmp);
float32x4_t vscale = vdupq_n_f32(ss);
#else
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local), [six_ptr] "r"(vsix)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local),
[vsix] "w"(vsix)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local), [six_ptr] "r"(vsix)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(COMPUTE LEAKY_RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local),
[vscale] "w"(vscale)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE LEAKY_RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local), [scale_ptr] "r"(vscale)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
}
void conv_3x3s2_depthwise_fp32(const float* i_data,
float* o_data,
......@@ -38,6 +543,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx) {
auto paddings = *param.paddings;
int threads = ctx->threads();
......@@ -51,11 +557,9 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
const int win_round = ROUNDUP(win_ext, 4);
const int hin_round = oh * 2 + 1;
const int prein_size = win_round * hin_round * out_c_block;
auto workspace_size =
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
auto workspace_size = threads * prein_size + win_round + ow_round;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
/// get workspace
......@@ -77,6 +581,8 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
remain = remain > 0 ? remain : 0;
int row_len = win_round * out_c_block;
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < bs; ++n) {
const float* din_batch = i_data + n * ic * size_in_channel;
float* dout_batch = o_data + n * oc * size_out_channel;
......@@ -147,201 +653,47 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
outc2 = pre_out + 8;
outc3 = pre_out + 12;
}
// clang-format off
#ifdef __aarch64__
asm volatile(
"ldr q8, [%[bias]]\n" /* load bias */
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/
"and v19.16b, v8.16b, v8.16b\n"
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/
"and v20.16b, v8.16b, v8.16b\n"
"ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/
"and v21.16b, v8.16b, v8.16b\n"
"ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/
"and v22.16b, v8.16b, v8.16b\n"
"ldr q8, [%[inr0]]\n" /* load input r0*/
/* r0 mul w0-w2, get out */
"fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/
"fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/
"fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/
"fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/
"fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/
"ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/
"fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/
"fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/
"fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/
"fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/
"ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/
"fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/
"ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/
"fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/
"fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/
"ldr q8, [%[inr1]]\n" /* load input r1*/
/* r1, mul w3-w5, get out */
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/
"fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/
"fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/
"fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/
"fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/
"fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/
"fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/
"fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/
"ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/
"fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/
"ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/
"fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/
"ldr q8, [%[inr2]]\n" /* load input r2*/
/* r2, mul w6-w8, get out r0, r1 */
"fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/
"fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/
"fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/
"fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/
"fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/
"fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/
"fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/
"fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/
"fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/
"fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/
"fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/
"fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/
/* transpose */
"trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/
"trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/
"trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/
"trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/
"trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/
"trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/
"trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/
"trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
/* relu */
"cbz %w[flag_relu], 0f\n" /* skip relu*/
"movi v0.4s, #0\n" /* for relu */
"fmax v19.4s, v19.4s, v0.4s\n"
"fmax v20.4s, v20.4s, v0.4s\n"
"fmax v21.4s, v21.4s, v0.4s\n"
"fmax v22.4s, v22.4s, v0.4s\n"
/* save result */
"0:\n"
"str q19, [%[outc0]], #16\n"
"str q20, [%[outc1]], #16\n"
"str q21, [%[outc2]], #16\n"
"str q22, [%[outc3]], #16\n"
:[inr0] "+r"(inr0), [inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0]"+r"(outc0), [outc1]"+r"(outc1),
[outc2]"+r"(outc2), [outc3]"+r"(outc3)
:[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
[w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
[w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
[bias] "r" (bias_local), [flag_relu]"r"(flag_relu)
: "cc", "memory",
"v0","v1","v2","v3","v4","v5","v6","v7",
"v8", "v19","v20","v21","v22"
);
act_switch_3x3s2(inr0,
inr1,
inr2,
outc0,
outc1,
outc2,
outc3,
weight_c,
bias_local,
w0,
w1,
w2,
w3,
w4,
w5,
w6,
w7,
w8,
act_param);
#else
asm volatile(
/* fill with bias */
"vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */
/* load weights */
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/
"vand.i32 q12, q8, q8\n"
"vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/
"vand.i32 q13, q8, q8\n"
"vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/
"vand.i32 q14, q8, q8\n"
"vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/
"vand.i32 q15, q8, q8\n"
"vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/
/* mul r0 with w0, w1, w2 */
"vmla.f32 q12, q9, q0 @ w0 * inr0\n"
"vmla.f32 q13, q9, q2 @ w0 * inr2\n"
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */
"vmla.f32 q14, q9, q4 @ w0 * inr4\n"
"vmla.f32 q15, q9, q6 @ w0 * inr6\n"
"vmla.f32 q12, q10, q1 @ w1 * inr1\n"
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n"
"vmla.f32 q13, q10, q3 @ w1 * inr3\n"
"vmla.f32 q14, q10, q5 @ w1 * inr5\n"
"vmla.f32 q15, q10, q7 @ w1 * inr7\n"
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */
"vmla.f32 q12, q11, q2 @ w2 * inr2\n"
"vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n"
"vmla.f32 q13, q11, q4 @ w2 * inr4\n"
"vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n"
"vmla.f32 q14, q11, q6 @ w2 * inr6\n"
"vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n"
"vmla.f32 q15, q11, q8 @ w2 * inr8\n"
/* mul r1 with w3, w4, w5 */
"vmla.f32 q12, q9, q0 @ w3 * inr0\n"
"vmla.f32 q13, q9, q2 @ w3 * inr2\n"
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */
"vmla.f32 q14, q9, q4 @ w3 * inr4\n"
"vmla.f32 q15, q9, q6 @ w3 * inr6\n"
"vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/
"vmla.f32 q12, q10, q1 @ w4 * inr1\n"
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n"
"vmla.f32 q13, q10, q3 @ w4 * inr3\n"
"vmla.f32 q14, q10, q5 @ w4 * inr5\n"
"vmla.f32 q15, q10, q7 @ w4 * inr7\n"
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */
"vmla.f32 q12, q11, q2 @ w5 * inr2\n"
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n"
"vmla.f32 q13, q11, q4 @ w5 * inr4\n"
"vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n"
"vmla.f32 q14, q11, q6 @ w5 * inr6\n"
"vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n"
"vmla.f32 q15, q11, q8 @ w5 * inr8\n"
/* mul r2 with w6, w7, w8 */
"vmla.f32 q12, q9, q0 @ w6 * inr0\n"
"vmla.f32 q13, q9, q2 @ w6 * inr2\n"
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */
"vmla.f32 q14, q9, q4 @ w6 * inr4\n"
"vmla.f32 q15, q9, q6 @ w6 * inr6\n"
"vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/
"vmla.f32 q12, q10, q1 @ w7 * inr1\n"
"vmla.f32 q13, q10, q3 @ w7 * inr3\n"
"vmla.f32 q14, q10, q5 @ w7 * inr5\n"
"vmla.f32 q15, q10, q7 @ w7 * inr7\n"
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n"
"vmla.f32 q12, q11, q2 @ w8 * inr2\n"
"vmla.f32 q13, q11, q4 @ w8 * inr4\n"
"vmla.f32 q14, q11, q6 @ w8 * inr6\n"
"vmla.f32 q15, q11, q8 @ w8 * inr8\n"
/* transpose */
"vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/
"vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/
"vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/
"vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/
"cmp %[flag_relu], #0\n"
"beq 0f\n" /* skip relu*/
"vmov.u32 q0, #0\n"
"vmax.f32 q12, q12, q0\n"
"vmax.f32 q13, q13, q0\n"
"vmax.f32 q14, q14, q0\n"
"vmax.f32 q15, q15, q0\n"
"0:\n"
"vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/
"vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/
"vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/
"vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/
:[r0] "+r"(inr0), [r1] "+r"(inr1),
[r2] "+r"(inr2), [wc0] "+r" (weight_c),
[outc0]"+r"(outc0), [outc1]"+r"(outc1),
[outc2]"+r"(outc2), [outc3]"+r"(outc3)
:[bias] "r" (bias_local),
[flag_relu]"r"(flag_relu)
:"cc", "memory",
"q0","q1","q2","q3","q4","q5","q6","q7",
"q8", "q9","q10","q11","q12","q13","q14","q15"
);
#endif // __arch64__
// clang-format off
act_switch_3x3s2(inr0,
inr1,
inr2,
outc0,
outc1,
outc2,
outc3,
weight_c,
bias_local,
vzero,
vzero,
vzero,
vzero,
vzero,
vzero,
vzero,
vzero,
vzero,
act_param);
#endif
if (flag_mask) {
for (int i = 0; i < remain; ++i) {
c0[i] = pre_out[i];
......@@ -350,6 +702,13 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
c3[i] = pre_out[i + 12];
}
}
inr0 += 32;
inr1 += 32;
inr2 += 32;
outc0 += 4;
outc1 += 4;
outc2 += 4;
outc3 += 4;
}
}
}
......
......@@ -2151,6 +2151,210 @@ inline void act_switch_c8_fp32(const float* din_ptr,
}
}
#ifdef __aarch64__
#define LOAD_DATA \
"1: \n" \
"ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/
#define DO_RELU \
"fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */
#define DO_RELU6 \
"fmin v0.4s, v0.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define DO_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define DO_STORE \
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"bne 1b \n"
#else
#define LOAD_DATA \
"1: \n" \
"vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d12-d13}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n"
#define DO_RELU \
"vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n"
#define DO_RELU6 \
"vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \
"vmin.f32 q4, q4, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q5, q5, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q6, q6, %q[vsix] @ vmaxq_f32() \n"
#define DO_LEAKY_RELU \
"vcge.f32 q7, q3, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q9, q4, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q11, q5, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q13, q6, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \
"vbif q3, q8, q7 @ choose \n" \
"vbif q4, q10, q9 @ choose \n" \
"vbif q5, q12, q11 @ choose \n" \
"vbif q6, q13, q13 @ choose \n"
#define DO_STORE \
"subs %[cnt], #1 \n" \
"vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"bne 1b \n"
#endif
/*
* Data do activation process
* Now support relu relu6 leakyrelu act
*/
inline void act_switch_process(float* src,
float* dst,
int size,
const operators::ActivationParam* act_param) {
int cnt = size >> 4;
int remain = size % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
if (act_param != nullptr && act_param->has_active) {
float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha);
if (cnt > 0) {
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
LOAD_DATA DO_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_RELU DO_RELU6 DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
LOAD_DATA DO_RELU DO_RELU6 DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_LEAKY_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale)
: "memory",
"cc",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11");
#else
asm volatile(
LOAD_DATA DO_LEAKY_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale)
: "memory",
"cc",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
// remain
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (int i = 0; i < remain; i++) {
*dst = *src >= 0.f ? *src : 0.f;
src++;
dst++;
}
case lite_api::ActivationType::kRelu6:
for (int i = 0; i < remain; i++) {
float tmp = *src >= 0.f ? *src : 0.f;
*dst = tmp <= act_param->Relu_clipped_coef
? tmp
: act_param->Relu_clipped_coef;
src++;
dst++;
}
case lite_api::ActivationType::kLeakyRelu:
for (int i = 0; i < remain; i++) {
if (*src >= 0.f) {
*dst = *src;
} else {
*dst = *src * act_param->Leaky_relu_alpha;
}
src++;
dst++;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
}
/*wirte result in outputs
* input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
*/
......
......@@ -52,6 +52,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s1_fp32(const float* din,
......@@ -67,7 +68,6 @@ void conv_depthwise_3x3s1_fp32(const float* din,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx);
......@@ -84,7 +84,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx);
template <typename Dtype>
......
......@@ -584,7 +584,6 @@ void conv_depthwise_3x3_fp32(const void* din,
const int pad_w = paddings[2];
int stride = param.strides[1];
int pad = pad_w;
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
......@@ -603,7 +602,6 @@ void conv_depthwise_3x3_fp32(const void* din,
bias,
pad,
flag_bias,
flag_relu,
act_param,
ctx);
} else {
......@@ -638,7 +636,7 @@ void conv_depthwise_3x3_fp32(const void* din,
bias,
pad,
flag_bias,
flag_relu,
act_param,
ctx);
} else {
conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din),
......@@ -653,6 +651,7 @@ void conv_depthwise_3x3_fp32(const void* din,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
}
} else {
......
......@@ -52,7 +52,7 @@ inline int ConvOutputSize(int input_size,
return output_size;
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
......
......@@ -136,7 +136,13 @@ class ConvOpLite : public OpLite {
mutable ConvParam param_;
std::string padding_algorithm_{""};
};
// update padding dilation
void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize);
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册