Unverified commit d4739621, authored by HappyAngel, committed by GitHub

[lite][arm] add conv relu6 and leaky_relu in conv_dw_3x3s2, test=develop (#2618)

* fix conv 2-pad to 4-pad

* fix compute conv shape

* fix pad, test=develop

* rename conv_depthwise_3x3s1_fp.cc to conv3x3s1p01_depthwise_fp32.cc to distinguish it from conv3x3s1_depthwise_fp32.cc

* delete debug printf in conv3x3s1, test=develop

* delete debug printf, test=develop

* delete gem_sdot.h, test=develop

it was copied from __gemm_sdot_meta_.h

* update compute padding, test=develop

* fix padding size: it must be 2 or 4, test=develop

* fix format in operators/conv_op.cc, test=develop

* change #if 0 to #if 1, test=develop

* move the 2-pad to 4-pad conversion into AttachImpl, test=develop

* fix clang-format error in tests/math/conv_compute_test, test=develop

* fix x86 test result error, test=develop

* add asymmetric padding test case in lite/tests/math/conv_compute.cc, test=develop

* change paddings type to support dynamic modification, test=develop

* fix x86 build error in conv_compute_test, test=develop

* fix opencl build error, test=develop

* fix opencl build error, test=develop

* fix opencl/conv_compute build error, test=develop

* fix opencl/conv_compute build error, test=develop

* fix format in kernels/opencl/conv_compute_test, test=develop

* fix build error, test=develop

fix build error in kernels/x86/conv_compute.h

* fix compute shape error in operators/conv_op.h, test=develop

* add conv_relu6 and conv_leaky_relu in conv_3x3s1_direct

* add conv_relu6 in c1, c2, c4, test=develop

* fix conflict in conv_block_utils.h, test=develop

* add relu6 and leaky_relu in conv_3x3s1_dw

* add conv_3x3s1px_dw relu6 and leaky_relu fusion, test=develop

* fix conflict in tests/math/conv_compute_arm, test=develop

* fix build error in winograd arm, test=develop

* change act_param to a pointer in conv_block_utils.h, test=develop

* fix winograd compute error when the 4-pad values are not equal, test=develop

* add conv relu6 and leaky_relu in conv_dw_3x3s2, test=develop

* fix format, test=develop

* fix format in conv_block_utils, test=develop

* move updatePadding from conv_op.cc to conv_op.h, test=develop

* fix format conv_op.h, test=develop

* fix build error in conv_op.h, test=develop

* remove flag_relu parameter in conv_3x3_depthwise, test=develop
Parent 1b74fded
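Summary of the interface change visible in the hunks below: the ARM depthwise 3x3 kernels drop the boolean flag_relu argument and instead take an operators::ActivationParam, so relu, relu6, and leaky_relu can all be fused. A minimal caller-side sketch (parameter lists abbreviated; only the names that appear in this diff are real):

operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu6;
act_param.Relu_clipped_coef = 6.f;  // the "six" the relu6 asm loads into v22
// before: conv_depthwise_3x3s2_fp32(din, dout, ..., flag_bias, flag_relu, ctx);
// after:  conv_depthwise_3x3s2_fp32(din, dout, ..., flag_bias, act_param, ctx);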
@@ -25,7 +25,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -40,7 +39,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -55,7 +53,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -70,7 +67,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -93,7 +89,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
const float *bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext *ctx) {
if (pad == 0) {
@@ -103,7 +98,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
@@ -118,7 +112,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
@@ -136,7 +129,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
@@ -151,7 +143,6 @@ void conv_depthwise_3x3s1_fp32(const float *din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
@@ -163,7 +154,7 @@ void conv_depthwise_3x3s1_fp32(const float *din,
}
}
}
// clang-format on
#ifdef __aarch64__
#define INIT_S1 \
"PRFM PLDL1KEEP, [%[din_ptr0]] \n" \
@@ -2318,7 +2309,6 @@ void act_switch_3x3s1p1(const float *din_ptr0,
}
}
#endif
// clang-format on
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
@@ -2328,7 +2318,6 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -2857,7 +2846,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -3443,7 +3431,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -3579,129 +3566,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
}
int cnt = tile_w;
/*
if (flag_relu) {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0)
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0)
"ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234
"ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345
"ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0)
"ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0)
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" // vld1q_f32(din_ptr0)
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" // vld1q_f32(din_ptr0)
"ext v16.16b, v0.16b, v1.16b, #4 \n" // v16 = 1234
"ext v17.16b, v0.16b, v1.16b, #8 \n" // v17 = 2345
"ld1 {v9.4s}, [%[din_ptr4]] \n" // vld1q_f32(din_ptr0)
"ld1 {v11.4s}, [%[din_ptr5]] \n" // vld1q_f32(din_ptr0)
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
*/
act_switch_3x3s1p0(din_ptr0,
din_ptr1,
din_ptr2,
@@ -3760,90 +3624,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
/*
if (flag_relu) {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}*/
act_switch_3x3s1p0(din_ptr0,
din_ptr1,
din_ptr2,
@@ -4174,7 +3954,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
@@ -4213,14 +3992,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
// #ifdef __aarch64__
// float32x4_t wbias;
// if (flag_bias) {
// wbias = vdupq_n_f32(bias[i]);
// } else {
// wbias = vdupq_n_f32(0.f);
// }
// #endif // __aarch64__
float32x4_t wbias;
float bias_val = 0.f;
if (flag_bias) {
@@ -4261,137 +4032,6 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
break;
}
}
/*
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
*/
unsigned int *vmask_ptr = vmask;
act_switch_3x3s1p0_s(dr0,
dr1,
...
@@ -836,7 +836,6 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
ctx->ExtendWorkspace(sizeof(float) * workspace_size);
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
/// get workspace
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle {
@@ -24,13 +25,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s(float* dout,
@@ -38,13 +39,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias(float* dout,
@@ -52,13 +53,13 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s(float* dout,
@@ -66,13 +67,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2_fp32(const float* din,
@@ -88,7 +89,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx) {
if (pad == 0) {
if (w_in > 7) {
@@ -97,13 +98,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s(dout,
@@ -111,13 +112,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
}
}
@@ -128,13 +129,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s(dout,
@@ -142,13 +143,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
}
}
@@ -412,6 +413,83 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define LEFT_RESULT_S2_RELU6 \
"fmax v16.4s, v16.4s, %[vzero].4s \n" \
"ld1 {v22.4s}, [%[six_ptr]] \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
"fmin v16.4s, v16.4s, v22.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"fmax v17.4s, v17.4s, %[vzero].4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"fmin v17.4s, v17.4s, v22.4s \n" \
\
"cmp %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define LEFT_RESULT_S2_LEAKY_RELU \
"ld1 {v22.4s}, [%[scale_ptr]] \n" \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"cmp %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
@@ -438,6 +516,58 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"bne 2b \n"
#define MID_RESULT_S2_RELU6 \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"fmin v16.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"fmin v17.4s, v17.4s, v22.4s \n" \
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define MID_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
@@ -456,6 +586,47 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define RIGHT_RESULT_S2_RELU6 \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"fmin v16.4s, v16.4s, v22.4s \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"fmin v17.4s, v17.4s, v22.4s \n" \
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define RIGHT_RESULT_S2_LEAKY_RELU \
"cmhs v11.4s, v16.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"cmhs v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"bif v17.16b, v12.16b, v11.16b \n" /* choose*/ \
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define COMPUTE_S_S2 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
@@ -500,7 +671,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"fmax v4.4s, v4.4s, v9.4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
@@ -537,7 +707,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"fadd v4.4s, v4.4s, v16.4s \n"
#define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"fmax v4.4s, v4.4s, v9.4s \n" \
"st1 {v4.4s}, [%[out]] \n"
@@ -682,7 +851,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"cmp %[cnt], #1 \n" \
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"subs %[cnt], #1 \n" \
@@ -739,7 +907,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu\n" \
\
@@ -787,13 +954,233 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#endif
#ifdef __aarch64__
void act_switch_3x3s2p1(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
}
#endif
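A note on how the scalar coefficients reach the asm above: act_switch_3x3s2p1 copies Relu_clipped_coef / Leaky_relu_alpha into four-element stack arrays (vsix, vscale) and passes their addresses as "r" operands (six_ptr / scale_ptr), so one ld1 fills the whole v22 register. Intrinsics-level equivalent (sketch; variable names ours):

#include <arm_neon.h>
float coef = 6.f;                          // Relu_clipped_coef or Leaky_relu_alpha
float vsix[4] = {coef, coef, coef, coef};  // what the asm's ld1 reads
float32x4_t v22 = vld1q_f32(vsix);         // same effect as vdupq_n_f32(coef)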
/**
* \brief depthwise convolution kernel 3x3, stride 2
* w_in > 7
@@ -803,13 +1190,13 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
@@ -821,7 +1208,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in;
@@ -935,96 +1322,24 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
if (flag_relu) {
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
act_switch_3x3s2p1(din0_ptr,
din1_ptr,
din2_ptr,
din3_ptr,
din4_ptr,
doutr0_ptr,
doutr1_ptr,
wr0,
wr1,
wr2,
vmask_rp1,
vmask_rp2,
wmask,
wbias,
vzero,
cnt,
cnt_remain,
act_param);
doutr0 = doutr0 + 2 * w_out;
}
#else
@@ -1061,65 +1376,37 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
if (flag_relu) {
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
// do act
if (act_param.has_active) {
act_switch_process(doutr0, doutr0, w_out, &act_param);
}
doutr0 = doutr0 + w_out;
}
@@ -1136,13 +1423,13 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
@@ -1198,108 +1485,59 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
// do act
if (act_param.has_active) {
act_switch_process(out_buf, out_buf, w_out, &act_param);
}
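On paths that keep the activation out of the asm (the armv7 loops and the small-width kernels), the new code falls back to act_switch_process, applying the activation elementwise over the just-written row. A hedged sketch of what such a post-pass does (the real helper lives in conv_block_utils.h; this reimplementation is ours):

// Elementwise activation over `size` floats; in place when din == dout.
void act_post_pass(const float* din, float* dout, int size,
                   const operators::ActivationParam* act) {
  for (int i = 0; i < size; ++i) {
    float x = din[i];
    switch (act->active_type) {
      case lite_api::ActivationType::kRelu:
        x = x > 0.f ? x : 0.f;
        break;
      case lite_api::ActivationType::kRelu6:
        x = x > 0.f ? x : 0.f;
        x = x < act->Relu_clipped_coef ? x : act->Relu_clipped_coef;
        break;
      case lite_api::ActivationType::kLeakyRelu:
        x = x > 0.f ? x : x * act->Leaky_relu_alpha;
        break;
      default:
        break;
    }
    dout[i] = x;
  }
}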
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
@@ -1310,6 +1548,269 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
}
}
#ifdef __aarch64__
void act_switch_3x3s2p0(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2
"ld1 {v22.4s}, [%[six_ptr]] \n" /* load six; MID/RIGHT relu6 read v22 */
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S2
"ld1 {v22.4s}, [%[scale_ptr]] \n" /* load scale; MID/RIGHT leaky read v22 */
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 "4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
}
#endif
/**
* \brief depthwise convolution kernel 3x3, stride 2
*/
@@ -1319,13 +1820,13 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
@@ -1438,117 +1939,24 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
if (flag_relu) {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
act_switch_3x3s2p0(din0_ptr,
din1_ptr,
din2_ptr,
din3_ptr,
din4_ptr,
doutr0_ptr,
doutr1_ptr,
wr0,
wr1,
wr2,
vmask_rp1,
vmask_rp2,
wmask,
wbias,
vzero,
cnt,
cnt_remain,
act_param);
doutr0 = doutr0 + 2 * w_out;
}
#else
@@ -1576,64 +1984,36 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
if (flag_relu) {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
if (act_param.has_active) {
act_switch_process(doutr0, doutr0, w_out, &act_param);
}
doutr0 = doutr0 + w_out;
}
@@ -1650,13 +2030,13 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
@@ -1718,114 +2098,62 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
}
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
if (act_param.has_active) {
act_switch_process(out_buf, out_buf, w_out, &act_param);
}
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
...
@@ -25,6 +25,511 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
#define COMPUTE \
"ldr q8, [%[bias]]\n" /* load bias */ \
"ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/ \
"and v19.16b, v8.16b, v8.16b\n" \
"ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/ \
"and v20.16b, v8.16b, v8.16b\n" \
"ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/ \
"and v21.16b, v8.16b, v8.16b\n" \
"ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/ \
"and v22.16b, v8.16b, v8.16b\n" \
"ldr q8, [%[inr0]]\n" /* load input r0*/ \
"fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/ \
"fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/ \
"fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/ \
"fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/ \
"fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/ \
"ldp q0, q1, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/ \
"fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/ \
"fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/ \
"fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/ \
"ldp q2, q3, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/ \
"ldp q4, q5, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/ \
"ldp q6, q7, [%[inr1]], #32\n" /* load input r1*/ \
"fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/ \
"ldr q8, [%[inr1]]\n" /* load input r1*/ \
"fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/ \
"fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/ \
"fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/ \
"fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/ \
"fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/ \
"ldp q0, q1, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/ \
"fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/ \
"fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/ \
"fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/ \
"ldp q2, q3, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/ \
"ldp q4, q5, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/ \
"ldp q6, q7, [%[inr2]], #32\n" /* load input r2*/ \
"fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/ \
"ldr q8, [%[inr2]]\n" /* load input r2*/ \
"fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/ \
"fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/ \
"fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/ \
"fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/ \
"fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/ \
"fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/ \
"fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/ \
"fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/ \
"fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/ \
"fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/ \
"fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/ \
"fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/ \
"trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/ \
"trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/ \
"trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/ \
"trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/ \
"trn1 v19.2d, v0.2d, v2.2d\n" /* r0: a0a1a2a3*/ \
"trn2 v21.2d, v0.2d, v2.2d\n" /* r0: c0c1c2c3*/ \
"trn1 v20.2d, v1.2d, v3.2d\n" /* r0: b0b1b2b3*/ \
"trn2 v22.2d, v1.2d, v3.2d\n" /* r0: d0d1d2d3*/
#define RELU /* relu */ \
"movi v0.4s, #0\n" /* for relu */ \
"fmax v19.4s, v19.4s, v0.4s\n" \
"fmax v20.4s, v20.4s, v0.4s\n" \
"fmax v21.4s, v21.4s, v0.4s\n" \
"fmax v22.4s, v22.4s, v0.4s\n"
#define RELU6 /* relu6 */ \
"fmin v19.4s, v19.4s, %[vsix].4s\n" \
"fmin v20.4s, v20.4s, %[vsix].4s\n" \
"fmin v21.4s, v21.4s, %[vsix].4s\n" \
"fmin v22.4s, v22.4s, %[vsix].4s\n"
#define LEAKY_RELU /* LeakyRelu */ \
"movi v0.4s, #0\n" /* for relu */ \
"cmhs v1.4s, v19.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v2.4s, v19.4s, %[vscale].4s \n" /* mul */ \
"cmhs v3.4s, v20.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v4.4s, v20.4s, %[vscale].4s \n" /* mul */ \
"cmhs v5.4s, v21.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v6.4s, v21.4s, %[vscale].4s \n" /* mul */ \
"cmhs v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v19.16b, v8.16b, v7.16b \n" /* choose*/
#define STORE /* save result */ \
"str q19, [%[outc0]], #16\n" \
"str q20, [%[outc1]], #16\n" \
"str q21, [%[outc2]], #16\n" \
"str q22, [%[outc3]], #16\n"
#else
#define COMPUTE \
/* fill with bias */ \
"vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */ /* load weights */ \
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */ \
"vld1.32 {d0-d3}, [%[r0]]!\n" /* load input r0, 0,1*/ \
"vand.i32 q12, q8, q8\n" \
"vld1.32 {d4-d7}, [%[r0]]!\n" /* load input r0, 2,3*/ \
"vand.i32 q13, q8, q8\n" \
"vld1.32 {d8-d11}, [%[r0]]!\n" /* load input r0, 4,5*/ \
"vand.i32 q14, q8, q8\n" \
"vld1.32 {d12-d15}, [%[r0]]!\n" /* load input r0, 6,7*/ \
"vand.i32 q15, q8, q8\n" \
"vld1.32 {d16-d17}, [%[r0]]\n" /* load input r0, 8*/ \
"vmla.f32 q12, q9, q0 @ w0 * inr0\n" \
"vmla.f32 q13, q9, q2 @ w0 * inr2\n" \
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */ \
"vmla.f32 q14, q9, q4 @ w0 * inr4\n" \
"vmla.f32 q15, q9, q6 @ w0 * inr6\n" \
"vmla.f32 q12, q10, q1 @ w1 * inr1\n" \
"vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n" \
"vmla.f32 q13, q10, q3 @ w1 * inr3\n" \
"vmla.f32 q14, q10, q5 @ w1 * inr5\n" \
"vmla.f32 q15, q10, q7 @ w1 * inr7\n" \
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */ \
"vmla.f32 q12, q11, q2 @ w2 * inr2\n" \
"vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n" \
"vmla.f32 q13, q11, q4 @ w2 * inr4\n" \
"vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n" \
"vmla.f32 q14, q11, q6 @ w2 * inr6\n" \
"vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n" \
"vmla.f32 q15, q11, q8 @ w2 * inr8\n" /* mul r1 with w3, w4*/ \
"vmla.f32 q12, q9, q0 @ w3 * inr0\n" \
"vmla.f32 q13, q9, q2 @ w3 * inr2\n" \
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */ \
"vmla.f32 q14, q9, q4 @ w3 * inr4\n" \
"vmla.f32 q15, q9, q6 @ w3 * inr6\n" \
"vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/ \
"vmla.f32 q12, q10, q1 @ w4 * inr1\n" \
"vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n" \
"vmla.f32 q13, q10, q3 @ w4 * inr3\n" \
"vmla.f32 q14, q10, q5 @ w4 * inr5\n" \
"vmla.f32 q15, q10, q7 @ w4 * inr7\n" \
"vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */ \
"vmla.f32 q12, q11, q2 @ w5 * inr2\n" \
"vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n" \
"vmla.f32 q13, q11, q4 @ w5 * inr4\n" \
"vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n" \
"vmla.f32 q14, q11, q6 @ w5 * inr6\n" \
"vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n" \
"vmla.f32 q15, q11, q8 @ w5 * inr8\n" /* mul r2 with w6, w7*/ \
"vmla.f32 q12, q9, q0 @ w6 * inr0\n" \
"vmla.f32 q13, q9, q2 @ w6 * inr2\n" \
"vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */ \
"vmla.f32 q14, q9, q4 @ w6 * inr4\n" \
"vmla.f32 q15, q9, q6 @ w6 * inr6\n" \
"vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/ \
"vmla.f32 q12, q10, q1 @ w7 * inr1\n" \
"vmla.f32 q13, q10, q3 @ w7 * inr3\n" \
"vmla.f32 q14, q10, q5 @ w7 * inr5\n" \
"vmla.f32 q15, q10, q7 @ w7 * inr7\n" \
"sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n" \
"vmla.f32 q12, q11, q2 @ w8 * inr2\n" \
"vmla.f32 q13, q11, q4 @ w8 * inr4\n" \
"vmla.f32 q14, q11, q6 @ w8 * inr6\n" \
"vmla.f32 q15, q11, q8 @ w8 * inr8\n" /* transpose */ \
"vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/ \
"vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/ \
"vswp d25, d28\n" /* a0a1a2a3, c0c1c2c3*/ \
"vswp d27, d30\n" /* b0b1b2b3, d0d1d2d3*/
#define RELU /* relu */ \
"vmov.u32 q0, #0\n" \
"vld1.32 {d2-d3}, [%[six_ptr]]\n" \
"vmax.f32 q12, q12, q0\n" \
"vmax.f32 q13, q13, q0\n" \
"vmax.f32 q14, q14, q0\n" \
"vmax.f32 q15, q15, q0\n"
#define RELU6 /* relu6 */ \
"vmin.f32 q12, q12, q1\n" \
"vmin.f32 q13, q13, q1\n" \
"vmin.f32 q14, q14, q1\n" \
"vmin.f32 q15, q15, q1\n"
#define LEAKY_RELU /* LeakyRelu */ \
"vmov.u32 q0, #0\n" \
"vld1.32 {d2-d3}, [%[scale_ptr]]\n" \
"vcge.f32 q2, q12, q0 @ q0 > 0 \n" \
"vcge.f32 q4, q13, q0 @ q0 > 0 \n" \
"vcge.f32 q6, q14, q0 @ q0 > 0 \n" \
"vcge.f32 q8, q15, q0 @ q0 > 0 \n" \
"vmul.f32 q3, q12, q1 @ mul \n" \
"vmul.f32 q5, q13, q1 @ mul \n" \
"vmul.f32 q7, q14, q1 @ mul \n" \
"vmul.f32 q9, q15, q1 @ mul \n" \
"vbif q12, q3, q2 @ choose \n" \
"vbif q13, q5, q4 @ choose \n" \
"vbif q14, q7, q6 @ choose \n" \
"vbif q15, q9, q8 @ choose \n"
#define STORE /* save result */ \
"vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/ \
"vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/ \
"vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/ \
"vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/
#endif
void act_switch_3x3s2(const float* inr0,
const float* inr1,
const float* inr2,
float* outc0,
float* outc1,
float* outc2,
float* outc3,
const float* weight_c,
float* bias_local,
float32x4_t w0,
float32x4_t w1,
float32x4_t w2,
float32x4_t w3,
float32x4_t w4,
float32x4_t w5,
float32x4_t w6,
float32x4_t w7,
float32x4_t w8,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(tmp);
float32x4_t vscale = vdupq_n_f32(ss);
#else
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local), [six_ptr] "r"(vsix)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(COMPUTE RELU RELU6 STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local),
[vsix] "w"(vsix)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE RELU RELU6 STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local), [six_ptr] "r"(vsix)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(COMPUTE LEAKY_RELU STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local),
[vscale] "w"(vscale)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE LEAKY_RELU STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local), [scale_ptr] "r"(vscale)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE STORE
: [inr0] "+r"(inr0),
[inr1] "+r"(inr1),
[inr2] "+r"(inr2),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [w0] "w"(w0),
[w1] "w"(w1),
[w2] "w"(w2),
[w3] "w"(w3),
[w4] "w"(w4),
[w5] "w"(w5),
[w6] "w"(w6),
[w7] "w"(w7),
[w8] "w"(w8),
[bias] "r"(bias_local)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v19",
"v20",
"v21",
"v22");
#else
asm volatile(COMPUTE STORE
: [r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[wc0] "+r"(weight_c),
[outc0] "+r"(outc0),
[outc1] "+r"(outc1),
[outc2] "+r"(outc2),
[outc3] "+r"(outc3)
: [bias] "r"(bias_local)
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
}
}
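For reference, a hedged sketch of the ActivationParam a caller would hand to act_switch_3x3s2; the field and enum names are taken from the code above, while the concrete values are illustrative:

operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu6;
act_param.Relu_clipped_coef = 6.f;  // broadcast into vsix by the switch
// For leaky relu, set instead:
//   act_param.active_type = lite_api::ActivationType::kLeakyRelu;
//   act_param.Leaky_relu_alpha = 0.1f;  // broadcast into vscale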
void conv_3x3s2_depthwise_fp32(const float* i_data,
                               float* o_data,
@@ -38,6 +543,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
                               const float* weights,
                               const float* bias,
                               const operators::ConvParam& param,
+                              const operators::ActivationParam act_param,
                               ARMContext* ctx) {
  auto paddings = *param.paddings;
  int threads = ctx->threads();
@@ -51,11 +557,9 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
  const int win_round = ROUNDUP(win_ext, 4);
  const int hin_round = oh * 2 + 1;
  const int prein_size = win_round * hin_round * out_c_block;
-  auto workspace_size =
-      threads * prein_size + win_round /*tmp zero*/ + ow_round /*tmp writer*/;
+  auto workspace_size = threads * prein_size + win_round + ow_round;
  ctx->ExtendWorkspace(sizeof(float) * workspace_size);
-  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;
  /// get workspace
@@ -77,6 +581,8 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
  remain = remain > 0 ? remain : 0;
  int row_len = win_round * out_c_block;
+  float32x4_t vzero = vdupq_n_f32(0.f);
  for (int n = 0; n < bs; ++n) {
    const float* din_batch = i_data + n * ic * size_in_channel;
    float* dout_batch = o_data + n * oc * size_out_channel;
@@ -147,201 +653,47 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
          outc2 = pre_out + 8;
          outc3 = pre_out + 12;
        }
+        // clang-format off
 #ifdef __aarch64__
-        asm volatile(
-            "ldr q8, [%[bias]]\n"          /* load bias */
-            "ldp q0, q1, [%[inr0]], #32\n" /* load input r0*/
-            "and v19.16b, v8.16b, v8.16b\n"
-            "ldp q2, q3, [%[inr0]], #32\n" /* load input r0*/
-            "and v20.16b, v8.16b, v8.16b\n"
-            "ldp q4, q5, [%[inr0]], #32\n" /* load input r0*/
-            "and v21.16b, v8.16b, v8.16b\n"
-            "ldp q6, q7, [%[inr0]], #32\n" /* load input r0*/
-            "and v22.16b, v8.16b, v8.16b\n"
-            "ldr q8, [%[inr0]]\n"          /* load input r0*/
-            /* r0 mul w0-w2, get out */
-            "fmla v19.4s , %[w0].4s, v0.4s\n" /* outr0 = w0 * r0, 0*/
-            "fmla v20.4s , %[w0].4s, v2.4s\n" /* outr1 = w0 * r0, 2*/
-            "fmla v21.4s , %[w0].4s, v4.4s\n" /* outr2 = w0 * r0, 4*/
-            "fmla v22.4s , %[w0].4s, v6.4s\n" /* outr3 = w0 * r0, 6*/
-            "fmla v19.4s , %[w1].4s, v1.4s\n" /* outr0 = w1 * r0, 1*/
-            "ldp q0, q1, [%[inr1]], #32\n"    /* load input r1*/
-            "fmla v20.4s , %[w1].4s, v3.4s\n" /* outr1 = w1 * r0, 3*/
-            "fmla v21.4s , %[w1].4s, v5.4s\n" /* outr2 = w1 * r0, 5*/
-            "fmla v22.4s , %[w1].4s, v7.4s\n" /* outr3 = w1 * r0, 7*/
-            "fmla v19.4s , %[w2].4s, v2.4s\n" /* outr0 = w0 * r0, 2*/
-            "ldp q2, q3, [%[inr1]], #32\n"    /* load input r1*/
-            "fmla v20.4s , %[w2].4s, v4.4s\n" /* outr1 = w0 * r0, 4*/
-            "ldp q4, q5, [%[inr1]], #32\n"    /* load input r1*/
-            "fmla v21.4s , %[w2].4s, v6.4s\n" /* outr2 = w0 * r0, 6*/
-            "ldp q6, q7, [%[inr1]], #32\n"    /* load input r1*/
-            "fmla v22.4s , %[w2].4s, v8.4s\n" /* outr3 = w0 * r0, 8*/
-            "ldr q8, [%[inr1]]\n"             /* load input r1*/
-            /* r1, mul w3-w5, get out */
-            "fmla v19.4s , %[w3].4s, v0.4s\n" /* outr0 = w3 * r1, 0*/
-            "fmla v20.4s , %[w3].4s, v2.4s\n" /* outr1 = w3 * r1, 2*/
-            "fmla v21.4s , %[w3].4s, v4.4s\n" /* outr2 = w3 * r1, 4*/
-            "fmla v22.4s , %[w3].4s, v6.4s\n" /* outr3 = w3 * r1, 6*/
-            "fmla v19.4s , %[w4].4s, v1.4s\n" /* outr0 = w4 * r1, 1*/
-            "ldp q0, q1, [%[inr2]], #32\n"    /* load input r2*/
-            "fmla v20.4s , %[w4].4s, v3.4s\n" /* outr1 = w4 * r1, 3*/
-            "fmla v21.4s , %[w4].4s, v5.4s\n" /* outr2 = w4 * r1, 5*/
-            "fmla v22.4s , %[w4].4s, v7.4s\n" /* outr3 = w4 * r1, 7*/
-            "fmla v19.4s , %[w5].4s, v2.4s\n" /* outr0 = w5 * r1, 2*/
-            "ldp q2, q3, [%[inr2]], #32\n"    /* load input r2*/
-            "fmla v20.4s , %[w5].4s, v4.4s\n" /* outr1 = w5 * r1, 4*/
-            "ldp q4, q5, [%[inr2]], #32\n"    /* load input r2*/
-            "fmla v21.4s , %[w5].4s, v6.4s\n" /* outr2 = w5 * r1, 6*/
-            "ldp q6, q7, [%[inr2]], #32\n"    /* load input r2*/
-            "fmla v22.4s , %[w5].4s, v8.4s\n" /* outr3 = w5 * r1, 8*/
-            "ldr q8, [%[inr2]]\n"             /* load input r2*/
-            /* r2, mul w6-w8, get out r0, r1 */
-            "fmla v19.4s , %[w6].4s, v0.4s\n" /* outr0 = w6 * r2, 0*/
-            "fmla v20.4s , %[w6].4s, v2.4s\n" /* outr1 = w6 * r2, 2*/
-            "fmla v21.4s , %[w6].4s, v4.4s\n" /* outr2 = w6 * r2, 4*/
-            "fmla v22.4s , %[w6].4s, v6.4s\n" /* outr3 = w6 * r2, 6*/
-            "fmla v19.4s , %[w7].4s, v1.4s\n" /* outr0 = w7 * r2, 1*/
-            "fmla v20.4s , %[w7].4s, v3.4s\n" /* outr1 = w7 * r2, 3*/
-            "fmla v21.4s , %[w7].4s, v5.4s\n" /* outr2 = w7 * r2, 5*/
-            "fmla v22.4s , %[w7].4s, v7.4s\n" /* outr3 = w7 * r2, 7*/
-            "fmla v19.4s , %[w8].4s, v2.4s\n" /* outr0 = w8 * r2, 2*/
-            "fmla v20.4s , %[w8].4s, v4.4s\n" /* outr1 = w8 * r2, 4*/
-            "fmla v21.4s , %[w8].4s, v6.4s\n" /* outr2 = w8 * r2, 6*/
-            "fmla v22.4s , %[w8].4s, v8.4s\n" /* outr3 = w8 * r2, 8*/
-            /* transpose */
-            "trn1 v0.4s, v19.4s, v20.4s\n" /* r0: a0a1c0c1*/
-            "trn2 v1.4s, v19.4s, v20.4s\n" /* r0: b0b1d0d1*/
-            "trn1 v2.4s, v21.4s, v22.4s\n" /* r0: a2a3c2c3*/
-            "trn2 v3.4s, v21.4s, v22.4s\n" /* r0: b2b3d2d3*/
-            "trn1 v19.2d, v0.2d, v2.2d\n"  /* r0: a0a1a2a3*/
-            "trn2 v21.2d, v0.2d, v2.2d\n"  /* r0: c0c1c2c3*/
-            "trn1 v20.2d, v1.2d, v3.2d\n"  /* r0: b0b1b2b3*/
-            "trn2 v22.2d, v1.2d, v3.2d\n"  /* r0: d0d1d2d3*/
-            /* relu */
-            "cbz %w[flag_relu], 0f\n" /* skip relu*/
-            "movi v0.4s, #0\n"        /* for relu */
-            "fmax v19.4s, v19.4s, v0.4s\n"
-            "fmax v20.4s, v20.4s, v0.4s\n"
-            "fmax v21.4s, v21.4s, v0.4s\n"
-            "fmax v22.4s, v22.4s, v0.4s\n"
-            /* save result */
-            "0:\n"
-            "str q19, [%[outc0]], #16\n"
-            "str q20, [%[outc1]], #16\n"
-            "str q21, [%[outc2]], #16\n"
-            "str q22, [%[outc3]], #16\n"
-            :[inr0] "+r"(inr0), [inr1] "+r"(inr1),
-             [inr2] "+r"(inr2),
-             [outc0]"+r"(outc0), [outc1]"+r"(outc1),
-             [outc2]"+r"(outc2), [outc3]"+r"(outc3)
-            :[w0] "w"(w0), [w1] "w"(w1), [w2] "w"(w2),
-             [w3] "w"(w3), [w4] "w"(w4), [w5] "w"(w5),
-             [w6] "w"(w6), [w7] "w"(w7), [w8] "w"(w8),
-             [bias] "r" (bias_local), [flag_relu]"r"(flag_relu)
-            : "cc", "memory",
-              "v0","v1","v2","v3","v4","v5","v6","v7",
-              "v8", "v19","v20","v21","v22"
-            );
+        act_switch_3x3s2(inr0,
+                         inr1,
+                         inr2,
+                         outc0,
+                         outc1,
+                         outc2,
+                         outc3,
+                         weight_c,
+                         bias_local,
+                         w0,
+                         w1,
+                         w2,
+                         w3,
+                         w4,
+                         w5,
+                         w6,
+                         w7,
+                         w8,
+                         act_param);
 #else
-        asm volatile(
-            /* fill with bias */
-            "vld1.32 {d16-d17}, [%[bias]]\n" /* load bias */
-            /* load weights */
-            "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w0-2, to q9-11 */
-            "vld1.32 {d0-d3}, [%[r0]]!\n"    /* load input r0, 0,1*/
-            "vand.i32 q12, q8, q8\n"
-            "vld1.32 {d4-d7}, [%[r0]]!\n"    /* load input r0, 2,3*/
-            "vand.i32 q13, q8, q8\n"
-            "vld1.32 {d8-d11}, [%[r0]]!\n"   /* load input r0, 4,5*/
-            "vand.i32 q14, q8, q8\n"
-            "vld1.32 {d12-d15}, [%[r0]]!\n"  /* load input r0, 6,7*/
-            "vand.i32 q15, q8, q8\n"
-            "vld1.32 {d16-d17}, [%[r0]]\n"   /* load input r0, 8*/
-            /* mul r0 with w0, w1, w2 */
-            "vmla.f32 q12, q9, q0 @ w0 * inr0\n"
-            "vmla.f32 q13, q9, q2 @ w0 * inr2\n"
-            "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w2, to q11 */
-            "vmla.f32 q14, q9, q4 @ w0 * inr4\n"
-            "vmla.f32 q15, q9, q6 @ w0 * inr6\n"
-            "vmla.f32 q12, q10, q1 @ w1 * inr1\n"
-            "vld1.32 {d0-d3}, [%[r1]]! @ load r1, 0, 1\n"
-            "vmla.f32 q13, q10, q3 @ w1 * inr3\n"
-            "vmla.f32 q14, q10, q5 @ w1 * inr5\n"
-            "vmla.f32 q15, q10, q7 @ w1 * inr7\n"
-            "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w3-4, to q9-10 */
-            "vmla.f32 q12, q11, q2 @ w2 * inr2\n"
-            "vld1.32 {d4-d7}, [%[r1]]! @ load r1, 2, 3\n"
-            "vmla.f32 q13, q11, q4 @ w2 * inr4\n"
-            "vld1.32 {d8-d11}, [%[r1]]! @ load r1, 4, 5\n"
-            "vmla.f32 q14, q11, q6 @ w2 * inr6\n"
-            "vld1.32 {d12-d15}, [%[r1]]! @ load r1, 6, 7\n"
-            "vmla.f32 q15, q11, q8 @ w2 * inr8\n"
-            /* mul r1 with w3, w4, w5 */
-            "vmla.f32 q12, q9, q0 @ w3 * inr0\n"
-            "vmla.f32 q13, q9, q2 @ w3 * inr2\n"
-            "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w5, to q11 */
-            "vmla.f32 q14, q9, q4 @ w3 * inr4\n"
-            "vmla.f32 q15, q9, q6 @ w3 * inr6\n"
-            "vld1.32 {d16-d17}, [%[r1]]\n" /* load input r1, 8*/
-            "vmla.f32 q12, q10, q1 @ w4 * inr1\n"
-            "vld1.32 {d0-d3}, [%[r2]]! @ load r2, 0, 1\n"
-            "vmla.f32 q13, q10, q3 @ w4 * inr3\n"
-            "vmla.f32 q14, q10, q5 @ w4 * inr5\n"
-            "vmla.f32 q15, q10, q7 @ w4 * inr7\n"
-            "vld1.32 {d18-d21}, [%[wc0]]!\n" /* load w6-7, to q9-10 */
-            "vmla.f32 q12, q11, q2 @ w5 * inr2\n"
-            "vld1.32 {d4-d7}, [%[r2]]! @ load r2, 2, 3\n"
-            "vmla.f32 q13, q11, q4 @ w5 * inr4\n"
-            "vld1.32 {d8-d11}, [%[r2]]! @ load r2, 4, 5\n"
-            "vmla.f32 q14, q11, q6 @ w5 * inr6\n"
-            "vld1.32 {d12-d15}, [%[r2]]! @ load r2, 6, 7\n"
-            "vmla.f32 q15, q11, q8 @ w5 * inr8\n"
-            /* mul r2 with w6, w7, w8 */
-            "vmla.f32 q12, q9, q0 @ w6 * inr0\n"
-            "vmla.f32 q13, q9, q2 @ w6 * inr2\n"
-            "vld1.32 {d22-d23}, [%[wc0]]!\n" /* load w8, to q11 */
-            "vmla.f32 q14, q9, q4 @ w6 * inr4\n"
-            "vmla.f32 q15, q9, q6 @ w6 * inr6\n"
-            "vld1.32 {d16-d17}, [%[r2]]\n" /* load input r2, 8*/
-            "vmla.f32 q12, q10, q1 @ w7 * inr1\n"
-            "vmla.f32 q13, q10, q3 @ w7 * inr3\n"
-            "vmla.f32 q14, q10, q5 @ w7 * inr5\n"
-            "vmla.f32 q15, q10, q7 @ w7 * inr7\n"
-            "sub %[wc0], %[wc0], #144 @ wc0 - 144 to start address\n"
-            "vmla.f32 q12, q11, q2 @ w8 * inr2\n"
-            "vmla.f32 q13, q11, q4 @ w8 * inr4\n"
-            "vmla.f32 q14, q11, q6 @ w8 * inr6\n"
-            "vmla.f32 q15, q11, q8 @ w8 * inr8\n"
-            /* transpose */
-            "vtrn.32 q12, q13\n" /* a0a1c0c1, b0b1d0d1*/
-            "vtrn.32 q14, q15\n" /* a2a3c2c3, b2b3d2d3*/
-            "vswp d25, d28\n"    /* a0a1a2a3, c0c1c2c3*/
-            "vswp d27, d30\n"    /* b0b1b2b3, d0d1d2d3*/
-            "cmp %[flag_relu], #0\n"
-            "beq 0f\n" /* skip relu*/
-            "vmov.u32 q0, #0\n"
-            "vmax.f32 q12, q12, q0\n"
-            "vmax.f32 q13, q13, q0\n"
-            "vmax.f32 q14, q14, q0\n"
-            "vmax.f32 q15, q15, q0\n"
-            "0:\n"
-            "vst1.32 {d24-d25}, [%[outc0]]!\n" /* save outc0*/
-            "vst1.32 {d26-d27}, [%[outc1]]!\n" /* save outc1*/
-            "vst1.32 {d28-d29}, [%[outc2]]!\n" /* save outc2*/
-            "vst1.32 {d30-d31}, [%[outc3]]!\n" /* save outc3*/
-            :[r0] "+r"(inr0), [r1] "+r"(inr1),
-             [r2] "+r"(inr2), [wc0] "+r" (weight_c),
-             [outc0]"+r"(outc0), [outc1]"+r"(outc1),
-             [outc2]"+r"(outc2), [outc3]"+r"(outc3)
-            :[bias] "r" (bias_local),
-             [flag_relu]"r"(flag_relu)
-            :"cc", "memory",
-             "q0","q1","q2","q3","q4","q5","q6","q7",
-             "q8", "q9","q10","q11","q12","q13","q14","q15"
-            );
+        act_switch_3x3s2(inr0,
+                         inr1,
+                         inr2,
+                         outc0,
+                         outc1,
+                         outc2,
+                         outc3,
+                         weight_c,
+                         bias_local,
+                         vzero,
+                         vzero,
+                         vzero,
+                         vzero,
+                         vzero,
+                         vzero,
+                         vzero,
+                         vzero,
+                         vzero,
+                         act_param);
-#endif  // __arch64__
+#endif
+        // clang-format on
        if (flag_mask) {
          for (int i = 0; i < remain; ++i) {
            c0[i] = pre_out[i];
@@ -350,6 +702,13 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
            c3[i] = pre_out[i + 12];
          }
        }
+        inr0 += 32;
+        inr1 += 32;
+        inr2 += 32;
+        outc0 += 4;
+        outc1 += 4;
+        outc2 += 4;
+        outc3 += 4;
      }
    }
  }
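A hedged reading of the pointer bumps added at the bottom of the loop: each iteration emits four output columns, and at stride 2 with the packed 4-channel layout that consumes eight input columns of four floats each. The constants below are assumptions derived from that layout, not names from the patch:

constexpr int kOutColsPerStep = 4;  // output columns per iteration
constexpr int kStride = 2;          // conv stride
constexpr int kChannelBlock = 4;    // floats per packed channel group
static_assert(kOutColsPerStep * kStride * kChannelBlock == 32,
              "inr0/inr1/inr2 advance 32 floats per step");
static_assert(kOutColsPerStep == 4,
              "outc0..outc3 advance 4 floats per step");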
......
@@ -2151,6 +2151,210 @@ inline void act_switch_c8_fp32(const float* din_ptr,
  }
}
#ifdef __aarch64__
#define LOAD_DATA \
"1: \n" \
"ld1 {v0.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v1.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v2.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v3.4s}, [%[din_ptr]], #16 \n" /*vld1q_f32(din_ptr0)*/
#define DO_RELU \
"fmax v0.4s, v0.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v1.4s, v1.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v2.4s, v2.4s, %[vzero].4s \n" /* vmaxq_f32() */ \
"fmax v3.4s, v3.4s, %[vzero].4s \n" /* vmaxq_f32() */
#define DO_RELU6 \
"fmin v0.4s, v0.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v1.4s, v1.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v2.4s, v2.4s, %[vsix].4s \n" /* vmaxq_f32() */ \
"fmin v3.4s, v3.4s, %[vsix].4s \n" /* vmaxq_f32() */
#define DO_LEAKY_RELU \
"cmhs v4.4s, v0.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v5.4s, v0.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v6.4s, v1.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v7.4s, v1.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v8.4s, v2.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v9.4s, v2.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"cmhs v10.4s, v3.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v11.4s, v3.4s, %[vscale].4s \n" /* vmulq_f32 */ \
"bif v0.16b, v5.16b, v4.16b \n" /* choose*/ \
"bif v1.16b, v7.16b, v6.16b \n" /* choose*/ \
"bif v2.16b, v9.16b, v8.16b \n" /* choose*/ \
"bif v3.16b, v11.16b, v10.16b \n" /* choose*/
#define DO_STORE \
"subs %w[cnt], %w[cnt], #1 \n" \
"st1 {v0.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v1.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v2.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"st1 {v3.4s}, [%[dout_ptr]], #16 \n" /* vst1q_f32() */ \
"bne 1b \n"
#else
#define LOAD_DATA \
"1: \n" \
"vld1.32 {d6-d7}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d8-d9}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d10-d11}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n" \
"vld1.32 {d12-d13}, [%[din_ptr]]! @ vld1q_f32(din_ptr) \n"
#define DO_RELU \
"vmax.f32 q3, q3, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q4, q4, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q5, q5, %q[vzero] @ vmaxq_f32() \n" \
"vmax.f32 q6, q6, %q[vzero] @ vmaxq_f32() \n"
#define DO_RELU6 \
"vmin.f32 q3, q3, %q[vsix] @ vminq_f32() \n" \
"vmin.f32 q4, q4, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q5, q5, %q[vsix] @ vmaxq_f32() \n" \
"vmin.f32 q6, q6, %q[vsix] @ vmaxq_f32() \n"
#define DO_LEAKY_RELU \
"vcge.f32 q7, q3, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q8, q3, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q9, q4, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q10, q4, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q11, q5, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q12, q5, %q[vscale] @ vmulq_f32 \n" \
"vcge.f32 q13, q6, %q[vzero] @ vcgeq_u32 \n" \
"vmul.f32 q14, q6, %q[vscale] @ vmulq_f32 \n" \
"vbif q3, q8, q7 @ choose \n" \
"vbif q4, q10, q9 @ choose \n" \
"vbif q5, q12, q11 @ choose \n" \
"vbif q6, q13, q13 @ choose \n"
#define DO_STORE \
"subs %[cnt], #1 \n" \
"vst1.32 {d6-d7}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d8-d9}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d10-d11}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"vst1.32 {d12-d13}, [%[dout_ptr]]! @ vst1q_f32() \n" \
"bne 1b \n"
#endif
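DO_LEAKY_RELU is the assembly form of a per-lane compare/multiply/select. A hedged NEON-intrinsics equivalent for one vector (illustrative, not part of the patch; vbslq_f32 plays the role of bif/vbif):

#include <arm_neon.h>

static inline float32x4_t leaky_relu_f32x4(float32x4_t x,
                                           float32x4_t vscale) {
  uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // x >= 0 mask
  float32x4_t scaled = vmulq_f32(x, vscale);            // x * alpha
  return vbslq_f32(ge_zero, x, scaled);  // keep x where the mask is set
}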
/*
 * Apply the fused activation to the data.
 * Currently supports the relu, relu6 and leaky_relu activations.
 */
inline void act_switch_process(float* src,
float* dst,
int size,
const operators::ActivationParam* act_param) {
int cnt = size >> 4;
int remain = size % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
if (act_param != nullptr && act_param->has_active) {
float32x4_t vsix = vdupq_n_f32(act_param->Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param->Leaky_relu_alpha);
if (cnt > 0) {
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
LOAD_DATA DO_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kRelu6:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_RELU DO_RELU6 DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "memory", "cc", "v0", "v1", "v2", "v3");
#else
asm volatile(
LOAD_DATA DO_RELU DO_RELU6 DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vsix] "w"(vsix)
: "memory", "cc", "q3", "q4", "q5", "q6");
#endif
break;
case lite_api::ActivationType::kLeakyRelu:
#ifdef __aarch64__
asm volatile(
LOAD_DATA DO_LEAKY_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale)
: "memory",
"cc",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11");
#else
asm volatile(
LOAD_DATA DO_LEAKY_RELU DO_STORE
: [din_ptr] "+r"(src), [dout_ptr] "+r"(dst), [cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vscale] "w"(vscale)
: "memory",
"cc",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14");
#endif
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
// remain
switch (act_param->active_type) {
case lite_api::ActivationType::kRelu:
for (int i = 0; i < remain; i++) {
*dst = *src >= 0.f ? *src : 0.f;
src++;
dst++;
        }
        break;
case lite_api::ActivationType::kRelu6:
for (int i = 0; i < remain; i++) {
float tmp = *src >= 0.f ? *src : 0.f;
*dst = tmp <= act_param->Relu_clipped_coef
? tmp
: act_param->Relu_clipped_coef;
src++;
dst++;
        }
        break;
case lite_api::ActivationType::kLeakyRelu:
for (int i = 0; i < remain; i++) {
if (*src >= 0.f) {
*dst = *src;
} else {
*dst = *src * act_param->Leaky_relu_alpha;
}
src++;
dst++;
}
break;
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param->active_type)
<< " fuse not support";
}
}
}
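A hedged usage sketch of act_switch_process on a flat buffer; the buffer size and values are made up, and src/dst may alias for in-place activation:

#include <vector>

std::vector<float> buf(132, -1.f);  // 8 vector groups of 16 plus a 4-wide tail
operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu;
act_switch_process(buf.data(), buf.data(),
                   static_cast<int>(buf.size()), &act_param);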
/* write result to outputs
 * input din: [n, c / 8, h, w * 8], output dout: [n, c, h, w]
 */
......
@@ -52,6 +52,7 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
                               const float* weights,
                               const float* bias,
                               const operators::ConvParam& param,
+                              const operators::ActivationParam act_param,
                               ARMContext* ctx);
void conv_depthwise_3x3s1_fp32(const float* din,
@@ -67,7 +68,6 @@ void conv_depthwise_3x3s1_fp32(const float* din,
                               const float* bias,
                               int pad,
                               bool flag_bias,
-                              bool flag_relu,
                               const operators::ActivationParam act_param,
                               ARMContext* ctx);
@@ -84,7 +84,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
                               const float* bias,
                               int pad,
                               bool flag_bias,
-                              bool flag_relu,
+                              const operators::ActivationParam act_param,
                               ARMContext* ctx);
template <typename Dtype>
......
@@ -584,7 +584,6 @@ void conv_depthwise_3x3_fp32(const void* din,
  const int pad_w = paddings[2];
  int stride = param.strides[1];
  int pad = pad_w;
-  bool flag_relu = param.fuse_relu;
  bool flag_bias = param.bias != nullptr;
  bool pads_equal =
      ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
@@ -603,7 +602,6 @@ void conv_depthwise_3x3_fp32(const void* din,
                              bias,
                              pad,
                              flag_bias,
-                              flag_relu,
                              act_param,
                              ctx);
    } else {
@@ -638,7 +636,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              bias,
                              pad,
                              flag_bias,
-                              flag_relu,
+                              act_param,
                              ctx);
    } else {
      conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din),
@@ -653,6 +651,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                                reinterpret_cast<const float*>(weights),
                                bias,
                                param,
+                                act_param,
                                ctx);
    }
  } else {
......
@@ -52,12 +52,12 @@ inline int ConvOutputSize(int input_size,
  return output_size;
}
-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
+void UpdatePaddingAndDilation(std::vector<int>* paddings,
                              std::vector<int>* dilations,
                              const std::vector<int>& strides,
                              const std::string padding_algorithm,
                              const lite::DDim data_dims,
                              const lite::DDim& ksize) {
  // when padding_desc is "VALID" or "SAME"
  if (padding_algorithm == "SAME") {
    for (size_t i = 0; i < strides.size(); ++i) {
......
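The loop body above is elided by the diff view; for context, a hedged sketch of the usual "SAME" padding rule such a loop implements (the NCHW indexing and variable names here are assumptions, not taken from the patch):

// Assumed shape of the elided body: ceil-divide for the output size,
// then split the total padding between the two sides of each dim i.
int64_t out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int64_t pad_sum =
    std::max((out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
             static_cast<int64_t>(0));
(*paddings)[2 * i] = static_cast<int>(pad_sum / 2);
(*paddings)[2 * i + 1] = static_cast<int>(pad_sum - pad_sum / 2);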
@@ -136,7 +136,13 @@ class ConvOpLite : public OpLite {
  mutable ConvParam param_;
  std::string padding_algorithm_{""};
};
+// update padding dilation
+void UpdatePaddingAndDilation(std::vector<int>* paddings,
+                              std::vector<int>* dilations,
+                              const std::vector<int>& strides,
+                              const std::string padding_algorithm,
+                              const lite::DDim data_dims,
+                              const lite::DDim& ksize);
}  // namespace operators
}  // namespace lite
}  // namespace paddle