Unverified commit 14e43112, authored by H HappyAngel, committed by GitHub

[arm] improve conv3x3_dw performance with relu, relu6 and leaky relu (#3183)


* improve conv_dw performance with relu, relu6, leaky relu, test=develop

* add depthwise, test=develop

* fix ci error, test=develop

* fix cv demo print, test=develop
Parent f7f65134
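For context, the core of this change is a new dispatch in conv_depthwise_3x3s1_fp32 (and its 3x3s2 counterpart): plain relu (and the no-activation case) is routed to the new dedicated *_relu kernels added in conv3x3s1p01_depthwise_fp32_relu.cc / conv3x3s2p01_depthwise_fp32_relu.cc, while relu6 and leaky relu stay on the act_param-switching kernels. The following is a minimal sketch assuming the enclosing kernel's locals (dout, din, weights, bias, pad, w_in, etc.); parameter lists are abbreviated and the pad/w_in branching is omitted, so it illustrates the dispatch rather than reproducing the exact diff.

// Sketch of the new dispatch; assumes the enclosing function's locals.
bool has_active = act_param.has_active;
bool flag_relu = false;
bool relu6 = false;
if (has_active) {
  if (act_param.active_type == lite_api::ActivationType::kRelu) {
    flag_relu = true;  // plain relu: take the dedicated fast path
  } else {
    relu6 = true;      // relu6 / leaky relu: keep the act_param switch
  }
}
if (relu6) {
  // act_param-switching kernel (handles relu6 and leaky relu)
  conv_depthwise_3x3s1p0_bias(dout, din, weights, bias, flag_bias, num,
                              ch_in, h_in, w_in, h_out, w_out, act_param,
                              ctx);
} else {
  // new *_relu kernel (handles both no-activation and fused relu
  // via flag_relu)
  conv_depthwise_3x3s1p0_bias_relu(dout, din, weights, bias, flag_bias,
                                   flag_relu, num, ch_in, h_in, w_in,
                                   h_out, w_out, ctx);
}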
...@@ -68,6 +68,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
gemv_arm_int8.cc
conv3x3s1_direct_fp32.cc
conv3x3s2_direct_fp32.cc
conv3x3s1p01_depthwise_fp32_relu.cc
conv3x3s2p01_depthwise_fp32_relu.cc
conv3x3s1p01_depthwise_fp32.cc
conv3x3s2p01_depthwise_fp32.cc
conv3x3s1px_depthwise_fp32.cc
...
...@@ -91,23 +91,20 @@ void conv_depthwise_3x3s1_fp32(const float *din,
bool flag_bias,
const operators::ActivationParam act_param,
ARMContext *ctx) {
bool has_active = act_param.has_active;
bool flag_relu = false;
bool relu6 = false;
if (has_active) {
if (act_param.active_type == lite_api::ActivationType::kRelu) {
flag_relu = true;
} else {
relu6 = true;
}
}
if (pad == 0) {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias(dout, if (relu6) {
din, conv_depthwise_3x3s1p0_bias(dout,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s(dout,
din,
weights,
bias,
...@@ -120,25 +117,57 @@ void conv_depthwise_3x3s1_fp32(const float *din,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s1p0_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias(dout, if (relu6) {
din, conv_depthwise_3x3s1p1_bias(dout,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s(dout,
din,
weights,
bias,
...@@ -151,6 +180,51 @@ void conv_depthwise_3x3s1_fp32(const float *din,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s1p1_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
}
}
...@@ -1924,223 +1998,169 @@ void act_switch_3x3s1p1(const float *din_ptr0,
float *vbias,
int cnt,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active; float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
if (has_active) { float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
switch (act_param.active_type) { asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
case lite_api::ActivationType::kRelu: MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
asm volatile( : [cnt] "+r"(cnt),
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 [din_ptr0] "+r"(din_ptr0),
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU [din_ptr1] "+r"(din_ptr1),
: [cnt] "+r"(cnt), [din_ptr2] "+r"(din_ptr2),
[din_ptr0] "+r"(din_ptr0), [din_ptr3] "+r"(din_ptr3),
[din_ptr1] "+r"(din_ptr1), [din_ptr4] "+r"(din_ptr4),
[din_ptr2] "+r"(din_ptr2), [din_ptr5] "+r"(din_ptr5),
[din_ptr3] "+r"(din_ptr3), [doutr0] "+r"(doutr0),
[din_ptr4] "+r"(din_ptr4), [doutr1] "+r"(doutr1),
[din_ptr5] "+r"(din_ptr5), [doutr2] "+r"(doutr2),
[doutr0] "+r"(doutr0), [doutr3] "+r"(doutr3)
[doutr1] "+r"(doutr1), : [w0] "w"(wr0),
[doutr2] "+r"(doutr2), [w1] "w"(wr1),
[doutr3] "+r"(doutr3) [w2] "w"(wr2),
: [w0] "w"(wr0), [bias_val] "r"(vbias),
[w1] "w"(wr1), [vmask] "r"(vmask),
[w2] "w"(wr2), [rmask] "r"(rmask),
[bias_val] "r"(vbias), [vzero] "w"(vzero)
[vmask] "r"(vmask), : "cc",
[rmask] "r"(rmask), "memory",
[vzero] "w"(vzero) "v0",
: "cc", "v1",
"memory", "v2",
"v0", "v3",
"v1", "v4",
"v2", "v5",
"v3", "v6",
"v4", "v7",
"v5", "v8",
"v6", "v9",
"v7", "v10",
"v8", "v11",
"v9", "v12",
"v10", "v13",
"v11", "v14",
"v12", "v15",
"v13", "v16",
"v14", "v17",
"v15", "v18",
"v16", "v19",
"v17", "v20",
"v18", "v21",
"v19", "v22",
"v20", "v23",
"v21", "v24",
"v22", "v25");
"v23", break;
"v24", case lite_api::ActivationType::kRelu6:
"v25"); /* 0 <= din <= 6 */
break; asm volatile(
case lite_api::ActivationType::kRelu6: INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
/* 0 <= din <= 6 */ MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
asm volatile( : [cnt] "+r"(cnt),
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 [din_ptr0] "+r"(din_ptr0),
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 [din_ptr1] "+r"(din_ptr1),
: [cnt] "+r"(cnt), [din_ptr2] "+r"(din_ptr2),
[din_ptr0] "+r"(din_ptr0), [din_ptr3] "+r"(din_ptr3),
[din_ptr1] "+r"(din_ptr1), [din_ptr4] "+r"(din_ptr4),
[din_ptr2] "+r"(din_ptr2), [din_ptr5] "+r"(din_ptr5),
[din_ptr3] "+r"(din_ptr3), [doutr0] "+r"(doutr0),
[din_ptr4] "+r"(din_ptr4), [doutr1] "+r"(doutr1),
[din_ptr5] "+r"(din_ptr5), [doutr2] "+r"(doutr2),
[doutr0] "+r"(doutr0), [doutr3] "+r"(doutr3)
[doutr1] "+r"(doutr1), : [w0] "w"(wr0),
[doutr2] "+r"(doutr2), [w1] "w"(wr1),
[doutr3] "+r"(doutr3) [w2] "w"(wr2),
: [w0] "w"(wr0), [vsix] "w"(vsix),
[w1] "w"(wr1), [bias_val] "r"(vbias),
[w2] "w"(wr2), [vmask] "r"(vmask),
[vsix] "w"(vsix), [rmask] "r"(rmask),
[bias_val] "r"(vbias), [vzero] "w"(vzero)
[vmask] "r"(vmask), : "cc",
[rmask] "r"(rmask), "memory",
[vzero] "w"(vzero) "v0",
: "cc", "v1",
"memory", "v2",
"v0", "v3",
"v1", "v4",
"v2", "v5",
"v3", "v6",
"v4", "v7",
"v5", "v8",
"v6", "v9",
"v7", "v10",
"v8", "v11",
"v9", "v12",
"v10", "v13",
"v11", "v14",
"v12", "v15",
"v13", "v16",
"v14", "v17",
"v15", "v18",
"v16", "v19",
"v17", "v20",
"v18", "v21",
"v19", "v22",
"v20", "v23",
"v21", "v24",
"v22", "v25");
"v23", break;
"v24", case lite_api::ActivationType::kLeakyRelu:
"v25"); /*din = din >= 0 ? din : din * scale*/
break; asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
case lite_api::ActivationType::kLeakyRelu: MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
/*din = din >= 0 ? din : din * scale*/ RIGHT_RESULT_S1_LEAKY_RELU
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU : [cnt] "+r"(cnt),
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU [din_ptr0] "+r"(din_ptr0),
RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU [din_ptr1] "+r"(din_ptr1),
: [cnt] "+r"(cnt), [din_ptr2] "+r"(din_ptr2),
[din_ptr0] "+r"(din_ptr0), [din_ptr3] "+r"(din_ptr3),
[din_ptr1] "+r"(din_ptr1), [din_ptr4] "+r"(din_ptr4),
[din_ptr2] "+r"(din_ptr2), [din_ptr5] "+r"(din_ptr5),
[din_ptr3] "+r"(din_ptr3), [doutr0] "+r"(doutr0),
[din_ptr4] "+r"(din_ptr4), [doutr1] "+r"(doutr1),
[din_ptr5] "+r"(din_ptr5), [doutr2] "+r"(doutr2),
[doutr0] "+r"(doutr0), [doutr3] "+r"(doutr3)
[doutr1] "+r"(doutr1), : [w0] "w"(wr0),
[doutr2] "+r"(doutr2), [w1] "w"(wr1),
[doutr3] "+r"(doutr3) [w2] "w"(wr2),
: [w0] "w"(wr0), [vscale] "w"(vscale),
[w1] "w"(wr1), [bias_val] "r"(vbias),
[w2] "w"(wr2), [vmask] "r"(vmask),
[vscale] "w"(vscale), [rmask] "r"(rmask),
[bias_val] "r"(vbias), [vzero] "w"(vzero)
[vmask] "r"(vmask), : "cc",
[rmask] "r"(rmask), "memory",
[vzero] "w"(vzero) "v0",
: "cc", "v1",
"memory", "v2",
"v0", "v3",
"v1", "v4",
"v2", "v5",
"v3", "v6",
"v4", "v7",
"v5", "v8",
"v6", "v9",
"v7", "v10",
"v8", "v11",
"v9", "v12",
"v10", "v13",
"v11", "v14",
"v12", "v15",
"v13", "v16",
"v14", "v17",
"v15", "v18",
"v16", "v19",
"v17", "v20",
"v18", "v21",
"v19", "v22",
"v20", "v23",
"v21", "v24",
"v22", "v25");
"v23", break;
"v24", default:
"v25"); LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
break; << " fuse not support";
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
}
#else
...@@ -2159,153 +2179,117 @@ void act_switch_3x3s1p1(const float *din_ptr0,
float bias_val,
int cnt,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active; float tmp = act_param.Relu_clipped_coef;
if (has_active) { float ss = act_param.Leaky_relu_alpha;
float tmp = act_param.Relu_clipped_coef; float vsix[4] = {tmp, tmp, tmp, tmp};
float ss = act_param.Leaky_relu_alpha; float vscale[4] = {ss, ss, ss, ss};
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss}; switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
switch (act_param.active_type) { asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
case lite_api::ActivationType::kRelu: MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
asm volatile( : [dout_ptr1] "+r"(doutr0),
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1 [dout_ptr2] "+r"(doutr1),
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU [din0_ptr] "+r"(din_ptr0),
: [dout_ptr1] "+r"(doutr0), [din1_ptr] "+r"(din_ptr1),
[dout_ptr2] "+r"(doutr1), [din2_ptr] "+r"(din_ptr2),
[din0_ptr] "+r"(din_ptr0), [din3_ptr] "+r"(din_ptr3),
[din1_ptr] "+r"(din_ptr1), [cnt] "+r"(cnt),
[din2_ptr] "+r"(din_ptr2), [rmask] "+r"(rmask_ptr),
[din3_ptr] "+r"(din_ptr3), [vmask] "+r"(vmask_ptr)
[cnt] "+r"(cnt), : [wr0] "w"(wr0),
[rmask] "+r"(rmask_ptr), [wr1] "w"(wr1),
[vmask] "+r"(vmask_ptr) [wr2] "w"(wr2),
: [wr0] "w"(wr0), [bias_val] "r"(bias_val),
[wr1] "w"(wr1), [vzero] "w"(vzero)
[wr2] "w"(wr2), : "cc",
[bias_val] "r"(bias_val), "memory",
[vzero] "w"(vzero) "q4",
: "cc", "q5",
"memory", "q6",
"q4", "q7",
"q5", "q8",
"q6", "q9",
"q7", "q10",
"q8", "q11",
"q9", "q12",
"q10", "q13",
"q11", "q14",
"q12", "q15");
"q13", break;
"q14", case lite_api::ActivationType::kRelu6:
"q15"); /* 0 <= din <= 6 */
break; asm volatile(
case lite_api::ActivationType::kRelu6: INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
/* 0 <= din <= 6 */ MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
asm volatile( : [dout_ptr1] "+r"(doutr0),
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1 [dout_ptr2] "+r"(doutr1),
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6 [din0_ptr] "+r"(din_ptr0),
: [dout_ptr1] "+r"(doutr0), [din1_ptr] "+r"(din_ptr1),
[dout_ptr2] "+r"(doutr1), [din2_ptr] "+r"(din_ptr2),
[din0_ptr] "+r"(din_ptr0), [din3_ptr] "+r"(din_ptr3),
[din1_ptr] "+r"(din_ptr1), [cnt] "+r"(cnt),
[din2_ptr] "+r"(din_ptr2), [rmask] "+r"(rmask_ptr),
[din3_ptr] "+r"(din_ptr3), [vmask] "+r"(vmask_ptr)
[cnt] "+r"(cnt), : [wr0] "w"(wr0),
[rmask] "+r"(rmask_ptr), [wr1] "w"(wr1),
[vmask] "+r"(vmask_ptr) [wr2] "w"(wr2),
: [wr0] "w"(wr0), [bias_val] "r"(bias_val),
[wr1] "w"(wr1), [six_ptr] "r"(vsix),
[wr2] "w"(wr2), [vzero] "w"(vzero)
[bias_val] "r"(bias_val), : "cc",
[six_ptr] "r"(vsix), "memory",
[vzero] "w"(vzero) "q4",
: "cc", "q5",
"memory", "q6",
"q4", "q7",
"q5", "q8",
"q6", "q9",
"q7", "q10",
"q8", "q11",
"q9", "q12",
"q10", "q13",
"q11", "q14",
"q12", "q15");
"q13", break;
"q14", case lite_api::ActivationType::kLeakyRelu:
"q15"); /*din = din >= 0 ? din : din * scale*/
break; asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
case lite_api::ActivationType::kLeakyRelu: MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
/*din = din >= 0 ? din : din * scale*/ RIGHT_RESULT_S1_LEAKY_RELU
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU : [dout_ptr1] "+r"(doutr0),
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU [dout_ptr2] "+r"(doutr1),
RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU [din0_ptr] "+r"(din_ptr0),
: [dout_ptr1] "+r"(doutr0), [din1_ptr] "+r"(din_ptr1),
[dout_ptr2] "+r"(doutr1), [din2_ptr] "+r"(din_ptr2),
[din0_ptr] "+r"(din_ptr0), [din3_ptr] "+r"(din_ptr3),
[din1_ptr] "+r"(din_ptr1), [cnt] "+r"(cnt),
[din2_ptr] "+r"(din_ptr2), [rmask] "+r"(rmask_ptr),
[din3_ptr] "+r"(din_ptr3), [vmask] "+r"(vmask_ptr)
[cnt] "+r"(cnt), : [wr0] "w"(wr0),
[rmask] "+r"(rmask_ptr), [wr1] "w"(wr1),
[vmask] "+r"(vmask_ptr) [wr2] "w"(wr2),
: [wr0] "w"(wr0), [bias_val] "r"(bias_val),
[wr1] "w"(wr1), [scale_ptr] "r"(vscale),
[wr2] "w"(wr2), [vzero] "w"(vzero)
[bias_val] "r"(bias_val), : "cc",
[scale_ptr] "r"(vscale), "memory",
[vzero] "w"(vzero) "q4",
: "cc", "q5",
"memory", "q6",
"q4", "q7",
"q5", "q8",
"q6", "q9",
"q7", "q10",
"q8", "q11",
"q9", "q12",
"q10", "q13",
"q11", "q14",
"q12", "q15");
"q13", break;
"q14", default:
"q15"); LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
break; << " fuse not support";
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
}
#endif
...@@ -2575,278 +2559,214 @@ void act_switch_3x3s1p1_s(const float *din_ptr0,
float32x4_t vzero,
float32x4_t wbias,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
#ifdef __aarch64__ #ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else #else
float tmp = act_param.Relu_clipped_coef; float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha; float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp}; float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss}; float vscale[4] = {ss, ss, ss, ss};
#endif #endif
switch (act_param.active_type) { switch (act_param.active_type) {
case lite_api::ActivationType::kRelu: case lite_api::ActivationType::kRelu:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[mask] "w"(vmask_rp), [mask] "w"(vmask_rp),
[bias] "w"(wbias), [bias] "w"(wbias),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "v0", : "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16", "v16",
"v17"); "v17");
break; break;
#else #else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[mask] "w"(vmask_rp), [mask] "w"(vmask_rp),
[bias] "w"(wbias), [bias] "w"(wbias),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"q6", "q6",
"q7", "q7",
"q8", "q8",
"q9", "q9",
"q10", "q10",
"q11", "q11",
"q12", "q12",
"q13", "q13",
"q14", "q14",
"q15"); "q15");
break; break;
#endif #endif
case lite_api::ActivationType::kRelu6: case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */ /* 0 <= din <= 6 */
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[mask] "w"(vmask_rp), [mask] "w"(vmask_rp),
[bias] "w"(wbias), [bias] "w"(wbias),
[vsix] "w"(vsix), [vsix] "w"(vsix),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "v0", : "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16", "v16",
"v17"); "v17");
break; break;
#else #else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6 asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[mask] "w"(vmask_rp), [mask] "w"(vmask_rp),
[bias] "w"(wbias), [bias] "w"(wbias),
[six_ptr] "r"(vsix), [six_ptr] "r"(vsix),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"q6", "q6",
"q7", "q7",
"q8", "q8",
"q9", "q9",
"q10", "q10",
"q11", "q11",
"q12", "q12",
"q13", "q13",
"q14", "q14",
"q15"); "q15");
break; break;
#endif #endif
case lite_api::ActivationType::kLeakyRelu: case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/ /*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[mask] "w"(vmask_rp), [mask] "w"(vmask_rp),
[bias] "w"(wbias), [bias] "w"(wbias),
[vscale] "w"(vscale), [vscale] "w"(vscale),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "v0", : "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15", "v15",
"v16", "v16",
"v17", "v17",
"v18", "v18",
"v19", "v19",
"v20"); "v20");
break; break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[scale_ptr] "r"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else #else
asm volatile(COMPUTE_S_S1 RESULT_S_S1 asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[mask] "w"(vmask_rp), [mask] "w"(vmask_rp),
[bias] "w"(wbias), [bias] "w"(wbias),
[out1] "r"(doutr0), [scale_ptr] "r"(vscale),
[out2] "r"(doutr1) [out1] "r"(doutr0),
: "cc", [out2] "r"(doutr1)
"memory", : "cc",
"q6", "memory",
"q7", "q6",
"q8", "q7",
"q9", "q8",
"q10", "q9",
"q11", "q10",
"q12", "q11",
"q13", "q12",
"q14", "q13",
"q15"); "q14",
"q15");
break;
#endif #endif
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
/**
...@@ -2987,262 +2907,198 @@ void act_switch_3x3s1p0(const float *din_ptr0,
int cnt,
int remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active; float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
if (has_active) { float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
switch (act_param.active_type) { asm volatile(
case lite_api::ActivationType::kRelu: INIT_S1
asm volatile( "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
INIT_S1 "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ MID_COMPUTE_S1 MID_RESULT_S1_RELU
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ "cmp %w[remain], #1 \n"
MID_COMPUTE_S1 MID_RESULT_S1_RELU "blt 0f \n" RIGHT_COMPUTE_S1
"cmp %w[remain], #1 \n" RIGHT_RESULT_S1_RELU "0: \n"
"blt 0f \n" RIGHT_COMPUTE_S1 : [cnt] "+r"(cnt),
RIGHT_RESULT_S1_RELU "0: \n" [din_ptr0] "+r"(din_ptr0),
: [cnt] "+r"(cnt), [din_ptr1] "+r"(din_ptr1),
[din_ptr0] "+r"(din_ptr0), [din_ptr2] "+r"(din_ptr2),
[din_ptr1] "+r"(din_ptr1), [din_ptr3] "+r"(din_ptr3),
[din_ptr2] "+r"(din_ptr2), [din_ptr4] "+r"(din_ptr4),
[din_ptr3] "+r"(din_ptr3), [din_ptr5] "+r"(din_ptr5),
[din_ptr4] "+r"(din_ptr4), [doutr0] "+r"(doutr0),
[din_ptr5] "+r"(din_ptr5), [doutr1] "+r"(doutr1),
[doutr0] "+r"(doutr0), [doutr2] "+r"(doutr2),
[doutr1] "+r"(doutr1), [doutr3] "+r"(doutr3)
[doutr2] "+r"(doutr2), : [w0] "w"(wr0),
[doutr3] "+r"(doutr3) [w1] "w"(wr1),
: [w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [bias_val] "r"(vbias),
[w2] "w"(wr2), [vmask] "r"(vmask),
[bias_val] "r"(vbias), [rmask] "r"(rmask),
[vmask] "r"(vmask), [vzero] "w"(vzero),
[rmask] "r"(rmask), [remain] "r"(remain)
[vzero] "w"(vzero), : "cc",
[remain] "r"(remain) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21",
"v20", "v22",
"v21", "v23",
"v22", "v24",
"v23", "v25");
"v24", break;
"v25"); case lite_api::ActivationType::kRelu6:
break; /* 0 <= din <= 6 */
case lite_api::ActivationType::kRelu6: asm volatile(
/* 0 <= din <= 6 */ INIT_S1
asm volatile( "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
INIT_S1 "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ MID_COMPUTE_S1 MID_RESULT_S1_RELU6
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ "cmp %w[remain], #1 \n"
MID_COMPUTE_S1 MID_RESULT_S1_RELU6 "blt 0f \n" RIGHT_COMPUTE_S1
"cmp %w[remain], #1 \n" RIGHT_RESULT_S1_RELU6 "0: \n"
"blt 0f \n" RIGHT_COMPUTE_S1 : [cnt] "+r"(cnt),
RIGHT_RESULT_S1_RELU6 "0: \n" [din_ptr0] "+r"(din_ptr0),
: [cnt] "+r"(cnt), [din_ptr1] "+r"(din_ptr1),
[din_ptr0] "+r"(din_ptr0), [din_ptr2] "+r"(din_ptr2),
[din_ptr1] "+r"(din_ptr1), [din_ptr3] "+r"(din_ptr3),
[din_ptr2] "+r"(din_ptr2), [din_ptr4] "+r"(din_ptr4),
[din_ptr3] "+r"(din_ptr3), [din_ptr5] "+r"(din_ptr5),
[din_ptr4] "+r"(din_ptr4), [doutr0] "+r"(doutr0),
[din_ptr5] "+r"(din_ptr5), [doutr1] "+r"(doutr1),
[doutr0] "+r"(doutr0), [doutr2] "+r"(doutr2),
[doutr1] "+r"(doutr1), [doutr3] "+r"(doutr3)
[doutr2] "+r"(doutr2), : [w0] "w"(wr0),
[doutr3] "+r"(doutr3) [w1] "w"(wr1),
: [w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [vsix] "w"(vsix),
[w2] "w"(wr2), [bias_val] "r"(vbias),
[vsix] "w"(vsix), [vmask] "r"(vmask),
[bias_val] "r"(vbias), [rmask] "r"(rmask),
[vmask] "r"(vmask), [remain] "r"(remain)
[rmask] "r"(rmask), : "cc",
[remain] "r"(remain) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21",
"v20", "v22",
"v21", "v23",
"v22", "v24",
"v23", "v25");
"v24", break;
"v25"); case lite_api::ActivationType::kLeakyRelu:
break; /*din = din >= 0 ? din : din * scale*/
case lite_api::ActivationType::kLeakyRelu: asm volatile(
/*din = din >= 0 ? din : din * scale*/ INIT_S1
asm volatile( "ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
INIT_S1 "ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ "ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ "ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ "ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */ "ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ "cmp %w[remain], #1 \n"
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU "blt 0f \n" RIGHT_COMPUTE_S1
"cmp %w[remain], #1 \n" RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
"blt 0f \n" RIGHT_COMPUTE_S1 : [cnt] "+r"(cnt),
RIGHT_RESULT_S1_LEAKY_RELU "0: \n" [din_ptr0] "+r"(din_ptr0),
: [cnt] "+r"(cnt), [din_ptr1] "+r"(din_ptr1),
[din_ptr0] "+r"(din_ptr0), [din_ptr2] "+r"(din_ptr2),
[din_ptr1] "+r"(din_ptr1), [din_ptr3] "+r"(din_ptr3),
[din_ptr2] "+r"(din_ptr2), [din_ptr4] "+r"(din_ptr4),
[din_ptr3] "+r"(din_ptr3), [din_ptr5] "+r"(din_ptr5),
[din_ptr4] "+r"(din_ptr4), [doutr0] "+r"(doutr0),
[din_ptr5] "+r"(din_ptr5), [doutr1] "+r"(doutr1),
[doutr0] "+r"(doutr0), [doutr2] "+r"(doutr2),
[doutr1] "+r"(doutr1), [doutr3] "+r"(doutr3)
[doutr2] "+r"(doutr2), : [w0] "w"(wr0),
[doutr3] "+r"(doutr3) [w1] "w"(wr1),
: [w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [vscale] "w"(vscale),
[w2] "w"(wr2), [bias_val] "r"(vbias),
[vscale] "w"(vscale), [vmask] "r"(vmask),
[bias_val] "r"(vbias), [rmask] "r"(rmask),
[vmask] "r"(vmask), [remain] "r"(remain)
[rmask] "r"(rmask), : "cc",
[remain] "r"(remain) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21",
"v20", "v22",
"v21", "v23",
"v22", "v24",
"v23", "v25");
"v24", break;
"v25"); default:
break; LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
default: << " fuse not support";
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
}
#else
...@@ -3262,191 +3118,146 @@ void act_switch_3x3s1p0(const float *din_ptr0,
int cnt,
int remain,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active; float tmp = act_param.Relu_clipped_coef;
if (has_active) { float ss = act_param.Leaky_relu_alpha;
float tmp = act_param.Relu_clipped_coef; float vsix[4] = {tmp, tmp, tmp, tmp};
float ss = act_param.Leaky_relu_alpha; float vscale[4] = {ss, ss, ss, ss};
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss}; switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
switch (act_param.active_type) { asm volatile(INIT_S1
case lite_api::ActivationType::kRelu: "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
asm volatile(INIT_S1 "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" "vext.32 q6, q8, q9, #1 @ 0012\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
"vext.32 q6, q8, q9, #1 @ 0012\n" MID_RESULT_S1_RELU
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 "cmp %[remain], #1 \n"
MID_RESULT_S1_RELU "blt 0f \n" RIGHT_COMPUTE_S1
"cmp %[remain], #1 \n" RIGHT_RESULT_S1_RELU "0: \n"
"blt 0f \n" RIGHT_COMPUTE_S1 : [dout_ptr1] "+r"(doutr0),
RIGHT_RESULT_S1_RELU "0: \n" [dout_ptr2] "+r"(doutr1),
: [dout_ptr1] "+r"(doutr0), [din0_ptr] "+r"(din_ptr0),
[dout_ptr2] "+r"(doutr1), [din1_ptr] "+r"(din_ptr1),
[din0_ptr] "+r"(din_ptr0), [din2_ptr] "+r"(din_ptr2),
[din1_ptr] "+r"(din_ptr1), [din3_ptr] "+r"(din_ptr3),
[din2_ptr] "+r"(din_ptr2), [cnt] "+r"(cnt),
[din3_ptr] "+r"(din_ptr3), [rmask] "+r"(rmask_ptr),
[cnt] "+r"(cnt), [vmask] "+r"(vmask_ptr)
[rmask] "+r"(rmask_ptr), : [wr0] "w"(wr0),
[vmask] "+r"(vmask_ptr) [wr1] "w"(wr1),
: [wr0] "w"(wr0), [wr2] "w"(wr2),
[wr1] "w"(wr1), [bias_val] "r"(bias_val),
[wr2] "w"(wr2), [vzero] "w"(vzero),
[bias_val] "r"(bias_val), [remain] "r"(remain)
[vzero] "w"(vzero), : "cc",
[remain] "r"(remain) "memory",
: "cc", "q4",
"memory", "q5",
"q4", "q6",
"q5", "q7",
"q6", "q8",
"q7", "q9",
"q8", "q10",
"q9", "q11",
"q10", "q12",
"q11", "q13",
"q12", "q14",
"q13", "q15");
"q14", break;
"q15"); case lite_api::ActivationType::kRelu6:
break; /* 0 <= din <= 6 */
case lite_api::ActivationType::kRelu6: asm volatile(INIT_S1
/* 0 <= din <= 6 */ "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
asm volatile(INIT_S1 "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" "vext.32 q6, q8, q9, #1 @ 0012\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
"vext.32 q6, q8, q9, #1 @ 0012\n" MID_RESULT_S1_RELU6
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 "cmp %[remain], #1 \n"
MID_RESULT_S1_RELU6 "blt 0f \n" RIGHT_COMPUTE_S1
"cmp %[remain], #1 \n" RIGHT_RESULT_S1_RELU6 "0: \n"
"blt 0f \n" RIGHT_COMPUTE_S1 : [dout_ptr1] "+r"(doutr0),
RIGHT_RESULT_S1_RELU6 "0: \n" [dout_ptr2] "+r"(doutr1),
: [dout_ptr1] "+r"(doutr0), [din0_ptr] "+r"(din_ptr0),
[dout_ptr2] "+r"(doutr1), [din1_ptr] "+r"(din_ptr1),
[din0_ptr] "+r"(din_ptr0), [din2_ptr] "+r"(din_ptr2),
[din1_ptr] "+r"(din_ptr1), [din3_ptr] "+r"(din_ptr3),
[din2_ptr] "+r"(din_ptr2), [cnt] "+r"(cnt),
[din3_ptr] "+r"(din_ptr3), [rmask] "+r"(rmask_ptr),
[cnt] "+r"(cnt), [vmask] "+r"(vmask_ptr)
[rmask] "+r"(rmask_ptr), : [wr0] "w"(wr0),
[vmask] "+r"(vmask_ptr) [wr1] "w"(wr1),
: [wr0] "w"(wr0), [wr2] "w"(wr2),
[wr1] "w"(wr1), [six_ptr] "r"(vsix),
[wr2] "w"(wr2), [bias_val] "r"(bias_val),
[six_ptr] "r"(vsix), [vzero] "w"(vzero),
[bias_val] "r"(bias_val), [remain] "r"(remain)
[vzero] "w"(vzero), : "cc",
[remain] "r"(remain) "memory",
: "cc", "q4",
"memory", "q5",
"q4", "q6",
"q5", "q7",
"q6", "q8",
"q7", "q9",
"q8", "q10",
"q9", "q11",
"q10", "q12",
"q11", "q13",
"q12", "q14",
"q13", "q15");
"q14", break;
"q15"); case lite_api::ActivationType::kLeakyRelu:
break; /*din = din >= 0 ? din : din * scale*/
case lite_api::ActivationType::kLeakyRelu: asm volatile(INIT_S1
/*din = din >= 0 ? din : din * scale*/ "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
asm volatile(INIT_S1 "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n" "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n" "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n" "vext.32 q6, q8, q9, #1 @ 0012\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n" "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
"vext.32 q6, q8, q9, #1 @ 0012\n" MID_RESULT_S1_LEAKY_RELU
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 "cmp %[remain], #1 \n"
MID_RESULT_S1_LEAKY_RELU "blt 0f \n" RIGHT_COMPUTE_S1
"cmp %[remain], #1 \n" RIGHT_RESULT_S1_LEAKY_RELU
"blt 0f \n" RIGHT_COMPUTE_S1 "0: \n"
RIGHT_RESULT_S1_LEAKY_RELU : [dout_ptr1] "+r"(doutr0),
"0: \n" [dout_ptr2] "+r"(doutr1),
: [dout_ptr1] "+r"(doutr0), [din0_ptr] "+r"(din_ptr0),
[dout_ptr2] "+r"(doutr1), [din1_ptr] "+r"(din_ptr1),
[din0_ptr] "+r"(din_ptr0), [din2_ptr] "+r"(din_ptr2),
[din1_ptr] "+r"(din_ptr1), [din3_ptr] "+r"(din_ptr3),
[din2_ptr] "+r"(din_ptr2), [cnt] "+r"(cnt),
[din3_ptr] "+r"(din_ptr3), [rmask] "+r"(rmask_ptr),
[cnt] "+r"(cnt), [vmask] "+r"(vmask_ptr)
[rmask] "+r"(rmask_ptr), : [wr0] "w"(wr0),
[vmask] "+r"(vmask_ptr) [wr1] "w"(wr1),
: [wr0] "w"(wr0), [wr2] "w"(wr2),
[wr1] "w"(wr1), [scale_ptr] "r"(vscale),
[wr2] "w"(wr2), [bias_val] "r"(bias_val),
[scale_ptr] "r"(vscale), [vzero] "w"(vzero),
[bias_val] "r"(bias_val), [remain] "r"(remain)
[vzero] "w"(vzero), : "cc",
[remain] "r"(remain) "memory",
: "cc", "q4",
"memory", "q5",
"q4", "q6",
"q5", "q7",
"q6", "q8",
"q7", "q9",
"q8", "q10",
"q9", "q11",
"q10", "q12",
"q11", "q13",
"q12", "q14",
"q13", "q15");
"q14", break;
"q15"); default:
break; LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
default: << " fuse not support";
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(
INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
}
#endif
...@@ -3694,287 +3505,220 @@ void act_switch_3x3s1p0_s(const float *din_ptr0,
unsigned int *vmask_ptr,
float bias_val,
const operators::ActivationParam act_param) {
bool has_active = act_param.has_active;
if (has_active) {
#ifdef __aarch64__ #ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef); float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha); float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else #else
float tmp = act_param.Relu_clipped_coef; float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha; float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp}; float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss}; float vscale[4] = {ss, ss, ss, ss};
#endif #endif
switch (act_param.active_type) { switch (act_param.active_type) {
case lite_api::ActivationType::kRelu: case lite_api::ActivationType::kRelu:
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vbias] "w"(wbias), [vbias] "w"(wbias),
[mask1] "w"(vmask_rp1), [mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2), [mask2] "w"(vmask_rp2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"v0", "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15"); "v15");
break; break;
#else #else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3), [din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr) [vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[bias_val] "r"(bias_val), [bias_val] "r"(bias_val),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"q4", "q4",
"q5", "q5",
"q6", "q6",
"q7", "q7",
"q8", "q8",
"q9", "q9",
"q10", "q10",
"q11", "q11",
"q12", "q12",
"q13", "q13",
"q14", "q14",
"q15"); "q15");
break; break;
#endif #endif
case lite_api::ActivationType::kRelu6: case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */ /* 0 <= din <= 6 */
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vbias] "w"(wbias), [vbias] "w"(wbias),
[mask1] "w"(vmask_rp1), [mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2), [mask2] "w"(vmask_rp2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[vsix] "w"(vsix), [vsix] "w"(vsix),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"v0", "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15"); "v15");
break; break;
#else #else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6 asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3), [din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr) [vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[six_ptr] "r"(vsix), [six_ptr] "r"(vsix),
[bias_val] "r"(bias_val), [bias_val] "r"(bias_val),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"q4", "q4",
"q5", "q5",
"q6", "q6",
"q7", "q7",
"q8", "q8",
"q9", "q9",
"q10", "q10",
"q11", "q11",
"q12", "q12",
"q13", "q13",
"q14", "q14",
"q15"); "q15");
break; break;
#endif #endif
case lite_api::ActivationType::kLeakyRelu: case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/ /*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3) [din3] "+r"(din_ptr3)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vbias] "w"(wbias), [vbias] "w"(wbias),
[mask1] "w"(vmask_rp1), [mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2), [mask2] "w"(vmask_rp2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[vscale] "w"(vscale), [vscale] "w"(vscale),
[out1] "r"(doutr0), [out1] "r"(doutr0),
[out2] "r"(doutr1) [out2] "r"(doutr1)
: "cc", : "cc",
"memory", "memory",
"v0", "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15"); "v15");
break; break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else #else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0), : [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1), [din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2), [din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3), [din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr) [vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0), : [wr0] "w"(wr0),
[wr1] "w"(wr1), [wr1] "w"(wr1),
[wr2] "w"(wr2), [wr2] "w"(wr2),
[vzero] "w"(vzero), [vzero] "w"(vzero),
[bias_val] "r"(bias_val), [scale_ptr] "r"(vscale),
[out1] "r"(doutr0), [bias_val] "r"(bias_val),
[out2] "r"(doutr1) [out1] "r"(doutr0),
: "cc", [out2] "r"(doutr1)
"memory", : "cc",
"q4", "memory",
"q5", "q4",
"q6", "q5",
"q7", "q6",
"q8", "q7",
"q9", "q8",
"q10", "q9",
"q11", "q10",
"q12", "q11",
"q13", "q12",
"q14", "q13",
"q15"); "q14",
"q15");
break;
#endif #endif
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
} }
} }
/**
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
#define INIT_S1 \
"PRFM PLDL1KEEP, [%[din_ptr0]] \n" \
"PRFM PLDL1KEEP, [%[din_ptr1]] \n" \
"PRFM PLDL1KEEP, [%[din_ptr2]] \n" \
"PRFM PLDL1KEEP, [%[din_ptr3]] \n" \
"PRFM PLDL1KEEP, [%[din_ptr4]] \n" \
"PRFM PLDL1KEEP, [%[din_ptr5]] \n" \
"movi v21.4s, #0x0\n" /* out0 = 0 */ \
\
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/
#define LEFT_COMPUTE_S1 \
"ext v16.16b, %[vzero].16b, v0.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */ /* r0 */ \
"fmla v12.4s, v0.4s, %[w0].s[1]\n" /* outr00 += din0_0123 * w0[1]*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"sub %[din_ptr0], %[din_ptr0], #4 \n" /* din_ptr0-- */ \
"sub %[din_ptr1], %[din_ptr1], #4 \n" /* din_ptr0-- */ \
\
"fmla v12.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din0_0012 * w0[0]*/ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
"sub %[din_ptr2], %[din_ptr2], #4 \n" /* din_ptr0-- */ \
"sub %[din_ptr3], %[din_ptr3], #4 \n" /* din_ptr0-- */ \
\
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_1234 * w0[2]*/ \
\
"ext v16.16b, %[vzero].16b, v2.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234 */ /* r1 */ \
"fmla v13.4s , v2.4s, %[w0].s[1]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v12.4s , v2.4s, %[w1].s[1]\n" /* outr00 += din1_0123 * w1[1]*/ \
"sub %[din_ptr4], %[din_ptr4], #4 \n" /* din_ptr0-- */ \
"sub %[din_ptr5], %[din_ptr5], #4 \n" /* din_ptr0-- */ \
\
"fmla v13.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v12.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v17.16b, v4.16b, v5.16b, #4 \n" /* v16=1234 */ \
"ext v16.16b, %[vzero].16b, v4.16b, #12 \n" /* v16 = 00123*/ \
\
/* r2 */ \
"fmla v14.4s , v4.4s, %[w0].s[1]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v13.4s , v4.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v12.4s , v4.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \
\
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v14.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v12.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v16.16b, %[vzero].16b, v6.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234 */ /* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[1]\n" /*outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v6.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v13.4s , v6.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \
\
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v16.16b, %[vzero].16b, v8.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234 */
#define LEFT_RESULT_S1 \
/* r4 */ \
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ /* r5 */ \
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
"cmp %w[cnt], #1 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"blt 3f \n"
#define MID_COMPUTE_S1 \
"1: \n" /* r0 */ \
"fmla v12.4s , v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v0.4s}, [%[din_ptr0]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v12.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v1.4s}, [%[din_ptr0]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v2.4s}, [%[din_ptr1]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v3.4s}, [%[din_ptr1]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v4.4s}, [%[din_ptr2]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v5.4s}, [%[din_ptr2]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
#define MID_RESULT_S1 \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"subs %w[cnt], %w[cnt], #1 \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"bne 1b \n"
#define RIGHT_COMPUTE_S1 \
"3: \n" \
"ld1 {v18.4s, v19.4s}, [%[vmask]] \n" \
"ld1 {v22.4s}, [%[doutr0]] \n" \
"ld1 {v23.4s}, [%[doutr1]] \n" \
"ld1 {v24.4s}, [%[doutr2]] \n" \
"ld1 {v25.4s}, [%[doutr3]] \n" \
\
"bif v0.16b, %[vzero].16b, v18.16b \n" \
"bif v1.16b, %[vzero].16b, v19.16b \n" \
"bif v2.16b, %[vzero].16b, v18.16b \n" \
"bif v3.16b, %[vzero].16b, v19.16b \n" \
\
"bif v4.16b, %[vzero].16b, v18.16b \n" \
"bif v5.16b, %[vzero].16b, v19.16b \n" \
"bif v6.16b, %[vzero].16b, v18.16b \n" \
"bif v7.16b, %[vzero].16b, v19.16b \n" \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ /* r0 */ \
"fmla v12.4s, v0.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"bif v8.16b, %[vzero].16b, v18.16b \n" \
"bif v9.16b, %[vzero].16b, v19.16b \n" \
"bif v10.16b, %[vzero].16b, v18.16b \n" \
"bif v11.16b, %[vzero].16b, v19.16b \n" \
\
"fmla v12.4s, v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"ld1 {v18.4s}, [%[rmask]] \n" \
\
"fmla v12.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v2.16b, v3.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v2.16b, v3.16b, #8 \n" /* v16 = 2345 */ /* r1 */ \
"fmla v13.4s , v2.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v12.4s , v2.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"fmla v13.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v12.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmla v13.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v12.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v4.16b, v5.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v4.16b, v5.16b, #8 \n" /* v16 = 2345 */ /* r2 */ \
"fmla v14.4s , v4.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v4.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v12.4s , v4.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"fmla v14.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v12.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"fmla v14.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v12.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v6.16b, v7.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v6.16b, v7.16b, #8 \n" /* v16 = 2345 */
#define RIGHT_RESULT_S1 \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"bif v12.16b, v22.16b, v18.16b \n" \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"bif v13.16b, v23.16b, v18.16b \n" \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"st1 {v13.4s}, [%[doutr1]], #16 \n" \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"bif v14.16b, v24.16b, v18.16b \n" \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"bif v15.16b, v25.16b, v18.16b \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n"
#define LEFT_RESULT_S1_RELU \
/* r4 */ \
"fmla v15.4s , v8.4s, %[w1].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
"fmla v14.4s , v8.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w2[1]*/ \
\
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" /* vst1q_f32() */ \
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* vst1q_f32() */ \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w1[1]*/ \
\
"ext v16.16b, %[vzero].16b, v10.16b, #12 \n" /* v16 = 00123*/ \
"ext v17.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234 */ \
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ /* r5*/ \
"fmla v15.4s , v10.4s, %[w2].s[1]\n" /* outr00 += din2_0123 * w1[1]*/ \
\
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \
\
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[0]\n" /* outr00 += din2_0123 * w0[1]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" /* vst1q_f32() */ \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din1_0123 * w0[1]*/ \
\
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" /* vst1q_f32() */ \
"cmp %w[cnt], #1 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
"blt 3f \n"
#define MID_RESULT_S1_RELU \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v6.4s}, [%[din_ptr3]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
\
"ld1 {v7.4s}, [%[din_ptr3]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v12.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"st1 {v13.4s}, [%[doutr1]], #16 \n" \
\
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v13.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
\
/* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/ \
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/ \
"ld1 {v14.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v16 = 2345 */ \
\
"subs %w[cnt], %w[cnt], #1 \n" \
\
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n" \
"ld1 {v15.4s}, [%[bias_val]] \n" /*vdupq_n_f32(bias_val)*/ \
\
"bne 1b \n"
#define RIGHT_RESULT_S1_RELU \
/* r3 */ \
"fmla v15.4s , v6.4s, %[w0].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v6.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v13.4s , v6.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"fmax v12.4s, v12.4s, %[vzero].4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w0].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v13.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v12.16b, v22.16b, v18.16b \n" \
\
"fmla v15.4s , v17.4s, %[w0].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v13.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v8.16b, v9.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v8.16b, v9.16b, #8 \n" /* v16 = 2345 */ /* r3 */ \
"fmla v15.4s , v8.4s, %[w1].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
"fmla v14.4s , v8.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"st1 {v12.4s}, [%[doutr0]], #16 \n" \
"fmax v13.4s, v13.4s, %[vzero].4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w1].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
"fmla v14.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v13.16b, v23.16b, v18.16b \n" \
\
"fmla v15.4s , v17.4s, %[w1].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
"fmla v14.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"ext v16.16b, v10.16b, v11.16b, #4 \n" /* v16 = 1234*/ \
"ext v17.16b, v10.16b, v11.16b, #8 \n" /* v16 = 2345 */ \
\
"st1 {v13.4s}, [%[doutr1]], #16 \n" /* r3 */ \
"fmla v15.4s , v10.4s, %[w2].s[0]\n" /* outr00 += din0_0123 * w0[0]*/ \
\
"fmax v14.4s, v14.4s, %[vzero].4s \n" /*relu*/ \
\
"fmla v15.4s , v16.4s, %[w2].s[1]\n" /* outr00 += din0_1234 * w0[1]*/ \
\
"bif v14.16b, v24.16b, v18.16b \n" \
\
"fmla v15.4s , v17.4s, %[w2].s[2]\n" /* outr00 += din0_2345 * w0[2]*/ \
\
"st1 {v14.4s}, [%[doutr2]], #16 \n" \
\
"fmax v15.4s, v15.4s, %[vzero].4s \n" /*relu*/ \
\
"bif v15.16b, v25.16b, v18.16b \n" \
\
"st1 {v15.4s}, [%[doutr3]], #16 \n"
#define COMPUTE_S_S1 \
"prfm pldl1keep, [%[din0]]\n" \
"prfm pldl1keep, [%[din1]]\n" \
"prfm pldl1keep, [%[din2]]\n" \
"prfm pldl1keep, [%[din3]]\n" \
\
"ld1 {v0.4s}, [%[din0]], #16\n" \
"ld1 {v1.4s}, [%[din1]], #16\n" \
"ld1 {v2.4s}, [%[din2]], #16\n" \
"ld1 {v3.4s}, [%[din3]], #16\n" \
\
"bif v0.16b, %[zero].16b, %[mask].16b\n" \
"bif v1.16b, %[zero].16b, %[mask].16b\n" \
"bif v2.16b, %[zero].16b, %[mask].16b\n" \
"bif v3.16b, %[zero].16b, %[mask].16b\n" \
\
"ext v4.16b, %[zero].16b, v0.16b, #12\n" \
"ext v5.16b, %[zero].16b, v1.16b, #12\n" \
"ext v6.16b, %[zero].16b, v2.16b, #12\n" \
"ext v7.16b, %[zero].16b, v3.16b, #12\n" \
\
"ext v8.16b, v0.16b, %[zero].16b, #4\n" \
"ext v9.16b, v1.16b, %[zero].16b, #4\n" \
"ext v10.16b, v2.16b, %[zero].16b, #4\n" \
"ext v11.16b, v3.16b, %[zero].16b, #4\n" \
\
"fmul v12.4s, v0.4s, %[wr0].s[1]\n" \
"fmul v13.4s, v1.4s, %[wr0].s[1]\n" \
\
"fmul v14.4s, v1.4s, %[wr1].s[1]\n" \
"fmul v15.4s, v2.4s, %[wr1].s[1]\n" \
\
"fmul v16.4s, v2.4s, %[wr2].s[1]\n" \
"fmul v17.4s, v3.4s, %[wr2].s[1]\n" \
\
"fmla v12.4s, v4.4s, %[wr0].s[0]\n" \
"fmla v13.4s, v5.4s, %[wr0].s[0]\n" \
\
"fmla v14.4s, v5.4s, %[wr1].s[0]\n" \
"fmla v15.4s, v6.4s, %[wr1].s[0]\n" \
\
"fmla v16.4s, v6.4s, %[wr2].s[0]\n" \
"fmla v17.4s, v7.4s, %[wr2].s[0]\n" \
\
"fmla v12.4s, v8.4s, %[wr0].s[2]\n" \
"fmla v13.4s, v9.4s, %[wr0].s[2]\n" \
\
"fmla v14.4s, v9.4s, %[wr1].s[2]\n" \
"fmla v15.4s, v10.4s, %[wr1].s[2]\n" \
\
"fmla v16.4s, v10.4s, %[wr2].s[2]\n" \
"fmla v17.4s, v11.4s, %[wr2].s[2]\n" \
\
"fadd v12.4s, v12.4s, v14.4s\n" \
"fadd v12.4s, v12.4s, v16.4s\n" \
\
"fadd v13.4s, v13.4s, v15.4s\n" \
"fadd v13.4s, v13.4s, v17.4s\n" \
\
"fadd v12.4s, v12.4s, %[bias].4s\n" \
"fadd v13.4s, v13.4s, %[bias].4s\n"
#define RESULT_S_S1 \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define RESULT_S_S1_RELU \
"prfm pldl1keep, [%[out1]]\n" \
"prfm pldl1keep, [%[out2]]\n" \
\
"fmax v12.4s, v12.4s, %[zero].4s\n" \
"fmax v13.4s, v13.4s, %[zero].4s\n" \
\
"st1 {v12.4s}, [%[out1]]\n" \
"st1 {v13.4s}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"prfm pldl1keep, [%[din0]]\n" \
"prfm pldl1keep, [%[din1]]\n" \
"prfm pldl1keep, [%[din2]]\n" \
"prfm pldl1keep, [%[din3]]\n" \
\
"ld1 {v0.4s, v1.4s}, [%[din0]]\n" \
"ld1 {v2.4s, v3.4s}, [%[din1]]\n" \
"ld1 {v4.4s, v5.4s}, [%[din2]]\n" \
"ld1 {v6.4s, v7.4s}, [%[din3]]\n" \
\
"bif v0.16b, %[zero].16b, %[mask1].16b\n" \
"bif v1.16b, %[zero].16b, %[mask2].16b\n" \
\
"bif v2.16b, %[zero].16b, %[mask1].16b\n" \
"bif v3.16b, %[zero].16b, %[mask2].16b\n" \
\
"bif v4.16b, %[zero].16b, %[mask1].16b\n" \
"bif v5.16b, %[zero].16b, %[mask2].16b\n" \
\
"bif v6.16b, %[zero].16b, %[mask1].16b\n" \
"bif v7.16b, %[zero].16b, %[mask2].16b\n" \
\
"ext v8.16b, v0.16b, v1.16b, #4\n" \
"ext v9.16b, v0.16b, v1.16b, #8\n" \
\
"and v12.16b, %[vbias].16b, %[vbias].16b \n" \
"and v13.16b, %[vbias].16b, %[vbias].16b \n" /* r0 */ \
"fmul v10.4s, v0.4s, %[wr0].s[0]\n" \
"fmul v11.4s, v8.4s, %[wr0].s[1]\n" \
"fmla v12.4s, v9.4s, %[wr0].s[2]\n" \
\
"ext v8.16b, v2.16b, v3.16b, #4\n" \
"ext v9.16b, v2.16b, v3.16b, #8\n" /* r1 */ \
"fmul v14.4s, v2.4s, %[wr0].s[0]\n" \
"fmla v10.4s, v2.4s, %[wr1].s[0]\n" \
\
"fmul v15.4s, v8.4s, %[wr0].s[1]\n" \
"fmla v11.4s, v8.4s, %[wr1].s[1]\n" \
\
"fmla v13.4s, v9.4s, %[wr0].s[2]\n" \
"fmla v12.4s, v9.4s, %[wr1].s[2]\n" \
\
"ext v8.16b, v4.16b, v5.16b, #4\n" \
"ext v9.16b, v4.16b, v5.16b, #8\n" /* r2 */ \
"fmla v14.4s, v4.4s, %[wr1].s[0]\n" \
"fmla v10.4s, v4.4s, %[wr2].s[0]\n" \
\
"fmla v15.4s, v8.4s, %[wr1].s[1]\n" \
"fmla v11.4s, v8.4s, %[wr2].s[1]\n" \
\
"fmla v13.4s, v9.4s, %[wr1].s[2]\n" \
"fmla v12.4s, v9.4s, %[wr2].s[2]\n" \
\
"ext v8.16b, v6.16b, v7.16b, #4\n" \
"ext v9.16b, v6.16b, v7.16b, #8\n" \
\
"fmla v14.4s, v6.4s, %[wr2].s[0]\n" \
\
"fmla v15.4s, v8.4s, %[wr2].s[1]\n" \
\
"fadd v12.4s, v12.4s, v10.4s\n" \
\
"fmla v13.4s, v9.4s, %[wr2].s[2]\n" \
\
"fadd v12.4s, v12.4s, v11.4s\n" \
"fadd v13.4s, v13.4s, v14.4s\n" \
"fadd v13.4s, v13.4s, v15.4s\n" // \
// "prfm pldl1keep, [%[out1]]\n" \
// "prfm pldl1keep, [%[out2]]\n" \
// \
// "st1 {v12.4s}, [%[out1]]\n" \
// "st1 {v13.4s}, [%[out2]]\n" \
#else
#define INIT_S1 \
"pld [%[din0_ptr]] @ preload data\n" \
"pld [%[din1_ptr]] @ preload data\n" \
"pld [%[din2_ptr]] @ preload data\n" \
"pld [%[din3_ptr]] @ preload data\n" \
\
"vld1.32 {d16-d18}, [%[din0_ptr]]! @ load din r0\n" \
"vld1.32 {d20-d22}, [%[din1_ptr]]! @ load din r1\n" \
"vld1.32 {d24-d26}, [%[din2_ptr]]! @ load din r2\n" \
"vld1.32 {d28-d30}, [%[din3_ptr]]! @ load din r3\n" \
\
"vdup.32 q4, %[bias_val] @ and \n" \
"vdup.32 q5, %[bias_val] @ and \n"
#define LEFT_COMPUTE_S1 \
"vext.32 q6, %q[vzero], q8, #3 @ 0012\n" \
"vext.32 q7, q8, q9, #1 @ 1234\n" /* r0 */ \
"vmla.f32 q4, q8, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
\
"sub %[din0_ptr], #12 @ 1pad + 2 float data overlap\n" \
"sub %[din1_ptr], #12 @ 1pad + 2 float data overlap\n" \
"sub %[din2_ptr], #12 @ 1pad + 2 float data overlap\n" \
"sub %[din3_ptr], #12 @ 1pad + 2 float data overlap\n" \
\
"vmla.f32 q4, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \
\
"pld [%[din0_ptr]] @ preload data\n" \
"pld [%[din1_ptr]] @ preload data\n" \
"pld [%[din2_ptr]] @ preload data\n" \
"pld [%[din3_ptr]] @ preload data\n" \
\
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \
\
"vext.32 q6, %q[vzero], q10, #3 @ 0012\n" \
"vext.32 q7, q10, q11, #1 @ 1234\n" \
\
/* r1 */ \
"vmla.f32 q5, q10, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q10, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \
\
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \
\
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[2]\n" \
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \
\
"vext.32 q6, %q[vzero], q12, #3 @ 0012\n" \
"vext.32 q7, q12, q13, #1 @ 1234\n" \
\
/* r2 */ \
"vmla.f32 q5, q12, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q12, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
\
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \
\
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[2]\n" \
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vext.32 q6, %q[vzero], q14, #3 @ 0012\n" \
"vext.32 q7, q14, q15, #1 @ 1234\n"
#define LEFT_RESULT_S1 \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"cmp %[cnt], #1 @ check whether has mid cols\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define MID_COMPUTE_S1 \
"1: @ right pad entry\n" /* r0 */ \
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \
\
"pld [%[din0_ptr]] @ preload data\n" \
"pld [%[din1_ptr]] @ preload data\n" \
"pld [%[din2_ptr]] @ preload data\n" \
"pld [%[din3_ptr]] @ preload data\n" \
\
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d16-d17}, [%[din0_ptr]]! @ load din r0\n" \
\
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \
\
"vld1.32 {d18}, [%[din0_ptr]] @ load din r0\n" \
\
"vext.32 q6, q10, q11, #1 @ 1234\n" \
"vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d20-d21}, [%[din1_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d22}, [%[din1_ptr]] @ load din r0\n" \
\
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \
\
"vext.32 q6, q12, q13, #1 @ 1234\n" \
"vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d24-d25}, [%[din2_ptr]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d26}, [%[din2_ptr]] @ load din r0\n" \
\
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \
\
"vext.32 q6, q14, q15, #1 @ 1234\n" \
"vext.32 q7, q14, q15, #2 @ 2345\n"
#define MID_RESULT_S1 \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"subs %[cnt], #1 @ loop count minus 1\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
\
"bne 1b @ jump to main loop start point\n"
#define RIGHT_COMPUTE_S1 \
"3: @ right pad entry\n" \
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \
\
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \
"vld1.32 {d31}, [%[vmask]]! @ load din r0\n" \
\
"vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \
\
"vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
\
"vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \
\
"vext.32 q6, q10, q11, #1 @ 1234\n" \
"vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d19}, [%[rmask]]! @ load din r0\n" \
"vld1.32 {d23}, [%[rmask]]! @ load din r0\n" \
\
"vmla.f32 q5, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d16-d17}, [%[dout_ptr1]] @ load din r0\n" \
"vld1.32 {d20-d21}, [%[dout_ptr2]] @ load din r0\n" \
\
"vmla.f32 q5, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \
\
"vext.32 q6, q12, q13, #1 @ 1234\n" \
"vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \
\
"vmla.f32 q5, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmla.f32 q5, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q4, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \
\
"vext.32 q6, q14, q15, #1 @ 1234\n" \
"vext.32 q7, q14, q15, #2 @ 2345\n"
#define RIGHT_RESULT_S1 \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vbif d8, d16, d19 @ bit select, deal with right pad\n" \
"vbif d9, d17, d23 @ bit select, deal with right pad\n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
\
"vbif d10, d20, d19 @ bit select, deal with right pad\n" \
"vbif d11, d21, d23 @ bit select, deal with right pad\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define LEFT_RESULT_S1_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][0] @ q4 += 1234 * wr0[0]\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 1234 * wr0[2]\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
\
"cmp %[cnt], #1 @ check whether has mid cols\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
"blt 3f @ jump to main loop start point\n"
#define MID_RESULT_S1_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vld1.32 {d28-d29}, [%[din3_ptr]]! @ load din r0\n" \
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vld1.32 {d30}, [%[din3_ptr]] @ load din r0\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" \
"vdup.32 q4, %[bias_val] @ and \n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n" \
\
"subs %[cnt], #1 @ loop count minus 1\n" \
\
"vdup.32 q5, %[bias_val] @ and \n" \
\
"bne 1b @ jump to main loop start point\n"
#define RIGHT_RESULT_S1_RELU \
/* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vmax.f32 q4, q4, %q[vzero] @ relu \n" \
\
"vmla.f32 q5, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vbif d8, d16, d19 @ bit select, deal with right pad\n" \
"vbif d9, d17, d23 @ bit select, deal with right pad\n" \
\
"vmla.f32 q5, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vst1.32 {d8-d9}, [%[dout_ptr1]]! @ store result, add pointer\n" \
\
"vmax.f32 q5, q5, %q[vzero] @ relu \n" \
\
"vbif d10, d20, d19 @ bit select, deal with right pad\n" \
"vbif d11, d21, d23 @ bit select, deal with right pad\n" \
\
"vst1.32 {d10-d11}, [%[dout_ptr2]]! @ store result, add pointer\n"
#define COMPUTE_S_S1 \
"pld [%[din0]]\n" \
"pld [%[din1]]\n" \
"pld [%[din2]]\n" \
"pld [%[din3]]\n" \
\
"vld1.32 {d12-d13}, [%[din0]]!\n" \
"vld1.32 {d14-d15}, [%[din1]]!\n" \
"vld1.32 {d16-d17}, [%[din2]]!\n" \
"vld1.32 {d18-d19}, [%[din3]]!\n" \
\
"vbif q6, %q[vzero], %q[mask]\n" \
"vbif q7, %q[vzero], %q[mask]\n" \
"vbif q8, %q[vzero], %q[mask]\n" \
"vbif q9, %q[vzero], %q[mask]\n" \
\
"vmul.f32 q14, q6, %e[wr0][1]\n" \
"vmul.f32 q15, q7, %e[wr0][1]\n" \
\
"vmla.f32 q14, q7, %e[wr1][1]\n" \
"vmla.f32 q15, q8, %e[wr1][1]\n" \
\
"vmla.f32 q14, q8, %e[wr2][1]\n" \
"vmla.f32 q15, q9, %e[wr2][1]\n" \
\
"vext.32 q10, %q[vzero], q6, #3\n" \
"vext.32 q11, %q[vzero], q7, #3\n" \
"vext.32 q12, %q[vzero], q8, #3\n" \
"vext.32 q13, %q[vzero], q9, #3\n" \
\
"vmla.f32 q14, q10, %e[wr0][0]\n" \
"vmla.f32 q15, q11, %e[wr0][0]\n" \
\
"vmla.f32 q14, q11, %e[wr1][0]\n" \
"vmla.f32 q15, q12, %e[wr1][0]\n" \
\
"vmla.f32 q14, q12, %e[wr2][0]\n" \
"vmla.f32 q15, q13, %e[wr2][0]\n" \
\
"vext.32 q10, q6, %q[vzero], #1\n" \
"vext.32 q11, q7, %q[vzero], #1\n" \
"vext.32 q12, q8, %q[vzero], #1\n" \
"vext.32 q13, q9, %q[vzero], #1\n" \
\
"vmla.f32 q14, q10, %f[wr0][0]\n" \
"vmla.f32 q15, q11, %f[wr0][0]\n" \
\
"vmla.f32 q14, q11, %f[wr1][0]\n" \
"vmla.f32 q15, q12, %f[wr1][0]\n" \
\
"vmla.f32 q14, q12, %f[wr2][0]\n" \
"vmla.f32 q15, q13, %f[wr2][0]\n" \
\
"vadd.f32 q14, q14, %q[bias]\n" \
"vadd.f32 q15, q15, %q[bias]\n"
#define RESULT_S_S1 \
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define RESULT_S_S1_RELU \
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vmax.f32 q14, q14, %q[vzero]\n" \
"vmax.f32 q15, q15, %q[vzero]\n" \
\
"vst1.32 {d28-d29}, [%[out1]]\n" \
"vst1.32 {d30-d31}, [%[out2]]\n"
#define COMPUTE_S_S1_P0 \
"pld [%[din0]]\n" \
"pld [%[din1]]\n" \
"pld [%[din2]]\n" \
"pld [%[din3]]\n" \
"vld1.32 {d16-d18}, [%[din0]] @ load din r0\n" \
"vld1.32 {d20-d22}, [%[din1]] @ load din r1\n" \
"vld1.32 {d24-d26}, [%[din2]] @ load din r2\n" \
"vld1.32 {d28-d30}, [%[din3]] @ load din r3\n" \
\
"vdup.32 q4, %[bias_val] @ and \n" \
"vdup.32 q5, %[bias_val] @ and \n" \
\
"vld1.32 {d19}, [%[vmask]]! @ load din r0\n" \
"vld1.32 {d23}, [%[vmask]]! @ load din r0\n" \
\
"vld1.32 {d27}, [%[vmask]]! @ load din r0\n" \
\
"vbif d16, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d20, %e[vzero], d19 @ bit select, deal with right pad\n" \
\
"vbif d17, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d21, %e[vzero], d23 @ bit select, deal with right pad\n" \
\
"vbif d18, %e[vzero], d27 @ bit select, deal with right pad\n" \
"vbif d22, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vext.32 q6, q8, q9, #1 @ 1234\n" \
"vext.32 q7, q8, q9, #2 @ 2345\n" /* r0 */ \
"vmla.f32 q4, q8, %e[wr0][0] @ q4 += 0123 * wr0[0]\n" \
\
"vbif d24, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d25, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d26, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vmla.f32 q4, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
\
"vbif d28, %e[vzero], d19 @ bit select, deal with right pad\n" \
"vbif d29, %e[vzero], d23 @ bit select, deal with right pad\n" \
"vbif d30, %e[vzero], d27 @ bit select, deal with right pad\n" \
\
"vmla.f32 q4, q7, %f[wr0][0] @ q4 += 2345 * wr0[2]\n" \
\
"vext.32 q6, q10, q11, #1 @ 1234\n" \
"vext.32 q7, q10, q11, #2 @ 2345\n" /* r1 */ \
"vmla.f32 q5, q10, %e[wr0][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q10, %e[wr1][0] @ q4 += 1234 * wr0[1]\n" \
\
"vmul.f32 q8, q6, %e[wr0][1] @ q4 += 1234 * wr0[1]\n" \
"vmul.f32 q10, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmul.f32 q9, q7, %f[wr0][0] @ q4 += 1234 * wr0[1]\n" \
"vmul.f32 q11, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \
\
"vext.32 q6, q12, q13, #1 @ 1234\n" \
"vext.32 q7, q12, q13, #2 @ 2345\n" /* r2 */ \
"vmla.f32 q5, q12, %e[wr1][0] @ q4 += 1234 * wr0[0]\n" \
"vmla.f32 q4, q12, %e[wr2][0] @ q4 += 1234 * wr0[1]\n" \
\
"vmla.f32 q8, q6, %e[wr1][1] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q10, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
\
"vmla.f32 q9, q7, %f[wr1][0] @ q4 += 1234 * wr0[1]\n" \
"vmla.f32 q11, q7, %f[wr2][0] @ q4 += 1234 * wr0[1]\n" \
\
"vext.32 q6, q14, q15, #1 @ 1234\n" \
"vext.32 q7, q14, q15, #2 @ 2345\n" /* r3 */ \
"vmla.f32 q5, q14, %e[wr2][0] @ q4 += 0123 * wr0[0]\n" \
\
"vmla.f32 q8, q6, %e[wr2][1] @ q4 += 1234 * wr0[1]\n" \
"vadd.f32 q4, q4, q10 @ q4 += q10 \n" \
\
"pld [%[out1]]\n" \
"pld [%[out2]]\n" \
\
"vmla.f32 q9, q7, %f[wr2][0] @ q4 += 2345 * wr0[2]\n" \
"vadd.f32 q14, q4, q11 @ q4 += q10 \n" \
\
"vadd.f32 q5, q5, q8 @ q4 += q10 \n" \
"vadd.f32 q15, q5, q9 @ q4 += q10 \n"
#endif
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
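// For reference, the per-channel computation is the plain 3x3 depthwise
// convolution sketched below. This is an illustrative scalar sketch only,
// not the code path taken here; the names mirror the locals used in the
// function that follows (wei_ptr, din_ch_ptr, dout_ptr, bias_val, flag_relu).
//
//   for (int h = 0; h < h_out; ++h) {
//     for (int w = 0; w < w_out; ++w) {
//       float sum = bias_val;
//       for (int i = 0; i < 3; ++i) {
//         for (int j = 0; j < 3; ++j) {
//           int hi = h + i - 1;  // pad = 1
//           int wi = w + j - 1;
//           if (hi >= 0 && hi < h_in && wi >= 0 && wi < w_in) {
//             sum += wei_ptr[i * 3 + j] * din_ch_ptr[hi * w_in + wi];
//           }
//         }
//       }
//       dout_ptr[h * w_out + w] = flag_relu ? (sum > 0.f ? sum : 0.f) : sum;
//     }
//   }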
void conv_depthwise_3x3s1p1_bias_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
int cnt_col = tile_w - 1;
unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
const unsigned int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 5) {
size_pad_right = 1;
cnt_col -= 1;
remain = 4;
} else if (remain == 0 && size_pad_right == 6) {
size_pad_right = 2;
cnt_col -= 1;
remain = 4;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
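  // Right-pad bookkeeping example: with w_in == w_out == 10 this gives
  // tile_w = 2, remain = 2, cnt_col = 1 and size_pad_right = 3, so vmask
  // keeps only the first three of the eight columns loaded by the final
  // block, and rmask lets only its first two results overwrite the output
  // row; the remaining lanes are refilled from the previously loaded dout
  // values via "bif".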
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
if (i == 0) {
din_ptr0 = zero_ptr;
din_ptr1 = dr0;
din_ptr2 = dr1;
din_ptr3 = dr2;
din_ptr4 = dr3;
din_ptr5 = dr4;
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
} else {
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
}
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
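        // Each iteration consumes six input rows (rows i - 1 .. i + 4, with
        // missing border rows redirected to zero_ptr) and emits four output
        // rows; the switch statements below rely on deliberate case
        // fall-through to remap out-of-range input rows to zero_ptr and
        // surplus output rows to write_ptr.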
//! process bottom pad
if (i + 5 > h_in) {
switch (i + 5 - h_in) {
case 5:
din_ptr1 = zero_ptr;
case 4:
din_ptr2 = zero_ptr;
case 3:
din_ptr3 = zero_ptr;
case 2:
din_ptr4 = zero_ptr;
case 1:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = cnt_col;
if (flag_relu) {
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
// unsigned int* rst_mask = rmask;
if (i == 0) {
din_ptr0 = zero_ptr;
din_ptr1 = dr0;
din_ptr2 = dr1;
din_ptr3 = dr2;
dr0 = dr1;
dr1 = dr2;
dr2 = dr3;
dr3 = dr2 + w_in;
} else {
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
}
//! process bottom pad
if (i + 3 > h_in) {
switch (i + 3 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = cnt_col;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
if (flag_relu) {
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
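// This narrow variant computes two full output rows per iteration into
// out_buf1/out_buf2 (the load mask vmask_rp zeroes the columns beyond w_in
// after each 4-lane load) and then copies only the first w_out values of
// each buffer back to the output tensor.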
void conv_depthwise_3x3s1p1_bias_s_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float *dr0 = din_channel + hs * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
}
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
} else {
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
  } // end of processing batches
}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width > 5
 */
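// With pad = 0 and stride = 1 the output is h_out = h_in - 2 by
// w_out = w_in - 2, so there is no left border block: the assembly enters
// MID_COMPUTE_S1 directly and only runs the masked RIGHT_COMPUTE_S1 tail
// when `remain` (w_out % 4) is non-zero.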
void conv_depthwise_3x3s1p0_bias_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
        //! pad = 0, so no top padding row; just rotate the row pointers
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
if (flag_relu) {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
} else {
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
}
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
case 0:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
if (flag_relu) {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 0, with bias,
 * width <= 5
 */
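// Narrow pad-0 variant (taken when w_in <= 5): both output rows of a step
// are computed into out_buf1/out_buf2, the masks built from 6 - w_in zero
// the columns beyond the row end, and only the first w_out values are
// copied back to dout_channel.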
void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
default:
break;
}
}
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
if (flag_relu) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
} // end of processing heights
} // end of processing channels
  } // end of processing batches
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -91,23 +91,20 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -91,23 +91,20 @@ void conv_depthwise_3x3s2_fp32(const float* din,
bool flag_bias, bool flag_bias,
const operators::ActivationParam act_param, const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
bool has_active = act_param.has_active;
bool flag_relu = false;
bool relu6 = false;
if (has_active) {
if (act_param.active_type == lite_api::ActivationType::kRelu) {
flag_relu = true;
} else {
relu6 = true;
}
}
if (pad == 0) { if (pad == 0) {
if (w_in > 8) { if (w_in > 8) {
conv_depthwise_3x3s2p0_bias(dout, if (relu6) {
din, conv_depthwise_3x3s2p0_bias(dout,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s(dout,
din, din,
weights, weights,
bias, bias,
...@@ -120,25 +117,57 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -120,25 +117,57 @@ void conv_depthwise_3x3s2_fp32(const float* din,
w_out, w_out,
act_param, act_param,
ctx); ctx);
} else {
conv_depthwise_3x3s2p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s2p0_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} }
} }
if (pad == 1) { if (pad == 1) {
if (w_in > 7) { if (w_in > 7) {
conv_depthwise_3x3s2p1_bias(dout, if (relu6) {
din, conv_depthwise_3x3s2p1_bias(dout,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s(dout,
din, din,
weights, weights,
bias, bias,
...@@ -151,6 +180,51 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -151,6 +180,51 @@ void conv_depthwise_3x3s2_fp32(const float* din,
w_out, w_out,
act_param, act_param,
ctx); ctx);
} else {
conv_depthwise_3x3s2p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} else {
if (relu6) {
conv_depthwise_3x3s2p1_bias_s(dout,
din,
weights,
bias,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
act_param,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
} }
} }
} }
...@@ -978,207 +1052,158 @@ void act_switch_3x3s2p1(const float* din0_ptr, ...@@ -978,207 +1052,158 @@ void act_switch_3x3s2p1(const float* din0_ptr,
int cnt, int cnt,
int cnt_remain, int cnt_remain,
const operators::ActivationParam act_param) { const operators::ActivationParam act_param) {
bool has_active = act_param.has_active; float tmp = act_param.Relu_clipped_coef;
if (has_active) { float ss = act_param.Leaky_relu_alpha;
float tmp = act_param.Relu_clipped_coef; float vsix[4] = {tmp, tmp, tmp, tmp};
float ss = act_param.Leaky_relu_alpha; float vscale[4] = {ss, ss, ss, ss};
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss}; switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
switch (act_param.active_type) { asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
case lite_api::ActivationType::kRelu: MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
asm volatile( : [inptr0] "+r"(din0_ptr),
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2 [inptr1] "+r"(din1_ptr),
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU [inptr2] "+r"(din2_ptr),
: [inptr0] "+r"(din0_ptr), [inptr3] "+r"(din3_ptr),
[inptr1] "+r"(din1_ptr), [inptr4] "+r"(din4_ptr),
[inptr2] "+r"(din2_ptr), [outptr0] "+r"(doutr0_ptr),
[inptr3] "+r"(din3_ptr), [outptr1] "+r"(doutr1_ptr),
[inptr4] "+r"(din4_ptr), [cnt] "+r"(cnt)
[outptr0] "+r"(doutr0_ptr), : [vzero] "w"(vzero),
[outptr1] "+r"(doutr1_ptr), [w0] "w"(wr0),
[cnt] "+r"(cnt) [w1] "w"(wr1),
: [vzero] "w"(vzero), [w2] "w"(wr2),
[w0] "w"(wr0), [remain] "r"(cnt_remain),
[w1] "w"(wr1), [mask1] "w"(vmask_rp1),
[w2] "w"(wr2), [mask2] "w"(vmask_rp2),
[remain] "r"(cnt_remain), [wmask] "w"(wmask),
[mask1] "w"(vmask_rp1), [vbias] "w"(wbias)
[mask2] "w"(vmask_rp2), : "cc",
[wmask] "w"(wmask), "memory",
[vbias] "w"(wbias) "v0",
: "cc", "v1",
"memory", "v2",
"v0", "v3",
"v1", "v4",
"v2", "v5",
"v3", "v6",
"v4", "v7",
"v5", "v8",
"v6", "v9",
"v7", "v10",
"v8", "v11",
"v9", "v12",
"v10", "v13",
"v11", "v14",
"v12", "v15",
"v13", "v16",
"v14", "v17",
"v15", "v18",
"v16", "v19",
"v17", "v20",
"v18", "v21");
"v19", break;
"v20", case lite_api::ActivationType::kRelu6:
"v21"); /* 0 <= din <= 6 */
break; asm volatile(
case lite_api::ActivationType::kRelu6: INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
/* 0 <= din <= 6 */ MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
asm volatile( : [inptr0] "+r"(din0_ptr),
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2 [inptr1] "+r"(din1_ptr),
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6 [inptr2] "+r"(din2_ptr),
: [inptr0] "+r"(din0_ptr), [inptr3] "+r"(din3_ptr),
[inptr1] "+r"(din1_ptr), [inptr4] "+r"(din4_ptr),
[inptr2] "+r"(din2_ptr), [outptr0] "+r"(doutr0_ptr),
[inptr3] "+r"(din3_ptr), [outptr1] "+r"(doutr1_ptr),
[inptr4] "+r"(din4_ptr), [cnt] "+r"(cnt)
[outptr0] "+r"(doutr0_ptr), : [vzero] "w"(vzero),
[outptr1] "+r"(doutr1_ptr), [w0] "w"(wr0),
[cnt] "+r"(cnt) [w1] "w"(wr1),
: [vzero] "w"(vzero), [w2] "w"(wr2),
[w0] "w"(wr0), [remain] "r"(cnt_remain),
[w1] "w"(wr1), [six_ptr] "r"(vsix),
[w2] "w"(wr2), [mask1] "w"(vmask_rp1),
[remain] "r"(cnt_remain), [mask2] "w"(vmask_rp2),
[six_ptr] "r"(vsix), [wmask] "w"(wmask),
[mask1] "w"(vmask_rp1), [vbias] "w"(wbias)
[mask2] "w"(vmask_rp2), : "cc",
[wmask] "w"(wmask), "memory",
[vbias] "w"(wbias) "v0",
: "cc", "v1",
"memory", "v2",
"v0", "v3",
"v1", "v4",
"v2", "v5",
"v3", "v6",
"v4", "v7",
"v5", "v8",
"v6", "v9",
"v7", "v10",
"v8", "v11",
"v9", "v12",
"v10", "v13",
"v11", "v14",
"v12", "v15",
"v13", "v16",
"v14", "v17",
"v15", "v18",
"v16", "v19",
"v17", "v20",
"v18", "v21",
"v19", "v22");
"v20", break;
"v21", case lite_api::ActivationType::kLeakyRelu:
"v22"); /*din = din >= 0 ? din : din * scale*/
break; asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
case lite_api::ActivationType::kLeakyRelu: MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU RIGHT_COMPUTE_S2
/*din = din >= 0 ? din : din * scale*/ RIGHT_RESULT_S2_LEAKY_RELU
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU : [inptr0] "+r"(din0_ptr),
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU [inptr1] "+r"(din1_ptr),
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU [inptr2] "+r"(din2_ptr),
: [inptr0] "+r"(din0_ptr), [inptr3] "+r"(din3_ptr),
[inptr1] "+r"(din1_ptr), [inptr4] "+r"(din4_ptr),
[inptr2] "+r"(din2_ptr), [outptr0] "+r"(doutr0_ptr),
[inptr3] "+r"(din3_ptr), [outptr1] "+r"(doutr1_ptr),
[inptr4] "+r"(din4_ptr), [cnt] "+r"(cnt)
[outptr0] "+r"(doutr0_ptr), : [vzero] "w"(vzero),
[outptr1] "+r"(doutr1_ptr), [w0] "w"(wr0),
[cnt] "+r"(cnt) [w1] "w"(wr1),
: [vzero] "w"(vzero), [w2] "w"(wr2),
[w0] "w"(wr0), [remain] "r"(cnt_remain),
[w1] "w"(wr1), [scale_ptr] "r"(vscale),
[w2] "w"(wr2), [mask1] "w"(vmask_rp1),
[remain] "r"(cnt_remain), [mask2] "w"(vmask_rp2),
[scale_ptr] "r"(vscale), [wmask] "w"(wmask),
[mask1] "w"(vmask_rp1), [vbias] "w"(wbias)
[mask2] "w"(vmask_rp2), : "cc",
[wmask] "w"(wmask), "memory",
[vbias] "w"(wbias) "v0",
: "cc", "v1",
"memory", "v2",
"v0", "v3",
"v1", "v4",
"v2", "v5",
"v3", "v6",
"v4", "v7",
"v5", "v8",
"v6", "v9",
"v7", "v10",
"v8", "v11",
"v9", "v12",
"v10", "v13",
"v11", "v14",
"v12", "v15",
"v13", "v16",
"v14", "v17",
"v15", "v18",
"v16", "v19",
"v17", "v20",
"v18", "v21",
"v19", "v22");
"v20", break;
"v21", default:
"v22"); LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
break; << " fuse not support";
default:
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} }
} }
#endif #endif
...@@ -1570,249 +1595,191 @@ void act_switch_3x3s2p0(const float* din0_ptr, ...@@ -1570,249 +1595,191 @@ void act_switch_3x3s2p0(const float* din0_ptr,
int cnt, int cnt,
int cnt_remain, int cnt_remain,
const operators::ActivationParam act_param) { const operators::ActivationParam act_param) {
bool has_active = act_param.has_active; float tmp = act_param.Relu_clipped_coef;
if (has_active) { float ss = act_param.Leaky_relu_alpha;
float tmp = act_param.Relu_clipped_coef; float vsix[4] = {tmp, tmp, tmp, tmp};
float ss = act_param.Leaky_relu_alpha; float vscale[4] = {ss, ss, ss, ss};
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss}; switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
switch (act_param.active_type) { asm volatile(
case lite_api::ActivationType::kRelu: INIT_S2
asm volatile( "ld1 {v15.4s}, [%[inptr0]] \n"
INIT_S2 "ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v15.4s}, [%[inptr0]] \n" "ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n" "ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n" "ld1 {v21.4s}, [%[inptr4]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n" "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v21.4s}, [%[inptr4]] \n" MID_COMPUTE_S2 MID_RESULT_S2_RELU
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} "cmp %w[remain], #1 \n"
MID_COMPUTE_S2 MID_RESULT_S2_RELU "blt 4f \n" RIGHT_COMPUTE_S2
"cmp %w[remain], #1 \n" RIGHT_RESULT_S2_RELU
"blt 4f \n" RIGHT_COMPUTE_S2 "4: \n"
RIGHT_RESULT_S2_RELU : [inptr0] "+r"(din0_ptr),
"4: \n" [inptr1] "+r"(din1_ptr),
: [inptr0] "+r"(din0_ptr), [inptr2] "+r"(din2_ptr),
[inptr1] "+r"(din1_ptr), [inptr3] "+r"(din3_ptr),
[inptr2] "+r"(din2_ptr), [inptr4] "+r"(din4_ptr),
[inptr3] "+r"(din3_ptr), [outptr0] "+r"(doutr0_ptr),
[inptr4] "+r"(din4_ptr), [outptr1] "+r"(doutr1_ptr),
[outptr0] "+r"(doutr0_ptr), [cnt] "+r"(cnt)
[outptr1] "+r"(doutr1_ptr), : [vzero] "w"(vzero),
[cnt] "+r"(cnt) [w0] "w"(wr0),
: [vzero] "w"(vzero), [w1] "w"(wr1),
[w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [remain] "r"(cnt_remain),
[w2] "w"(wr2), [mask1] "w"(vmask_rp1),
[remain] "r"(cnt_remain), [mask2] "w"(vmask_rp2),
[mask1] "w"(vmask_rp1), [wmask] "w"(wmask),
[mask2] "w"(vmask_rp2), [vbias] "w"(wbias)
[wmask] "w"(wmask), : "cc",
[vbias] "w"(wbias) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21");
"v20", break;
"v21"); case lite_api::ActivationType::kRelu6:
break; /* 0 <= din <= 6 */
case lite_api::ActivationType::kRelu6: asm volatile(
/* 0 <= din <= 6 */ INIT_S2
asm volatile( "ld1 {v15.4s}, [%[inptr0]] \n"
INIT_S2 "ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v15.4s}, [%[inptr0]] \n" "ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n" "ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n" "ld1 {v21.4s}, [%[inptr4]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n" "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v21.4s}, [%[inptr4]] \n" "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} MID_RESULT_S2_RELU6
"ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2 "cmp %w[remain], #1 \n"
MID_RESULT_S2_RELU6 "blt 4f \n" RIGHT_COMPUTE_S2
"cmp %w[remain], #1 \n" RIGHT_RESULT_S2_RELU6
"blt 4f \n" RIGHT_COMPUTE_S2 "4: \n"
RIGHT_RESULT_S2_RELU6 : [inptr0] "+r"(din0_ptr),
"4: \n" [inptr1] "+r"(din1_ptr),
: [inptr0] "+r"(din0_ptr), [inptr2] "+r"(din2_ptr),
[inptr1] "+r"(din1_ptr), [inptr3] "+r"(din3_ptr),
[inptr2] "+r"(din2_ptr), [inptr4] "+r"(din4_ptr),
[inptr3] "+r"(din3_ptr), [outptr0] "+r"(doutr0_ptr),
[inptr4] "+r"(din4_ptr), [outptr1] "+r"(doutr1_ptr),
[outptr0] "+r"(doutr0_ptr), [cnt] "+r"(cnt)
[outptr1] "+r"(doutr1_ptr), : [vzero] "w"(vzero),
[cnt] "+r"(cnt) [w0] "w"(wr0),
: [vzero] "w"(vzero), [w1] "w"(wr1),
[w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [remain] "r"(cnt_remain),
[w2] "w"(wr2), [six_ptr] "r"(vsix),
[remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
[six_ptr] "r"(vsix), [mask2] "w"(vmask_rp2),
[mask1] "w"(vmask_rp1), [wmask] "w"(wmask),
[mask2] "w"(vmask_rp2), [vbias] "w"(wbias)
[wmask] "w"(wmask), : "cc",
[vbias] "w"(wbias) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21",
"v20", "v22");
"v21", break;
"v22"); case lite_api::ActivationType::kLeakyRelu:
break; /*din = din >= 0 ? din : din * scale*/
case lite_api::ActivationType::kLeakyRelu: asm volatile(
/*din = din >= 0 ? din : din * scale*/ INIT_S2
asm volatile( "ld1 {v15.4s}, [%[inptr0]] \n"
INIT_S2 "ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v15.4s}, [%[inptr0]] \n" "ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n" "ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n" "ld1 {v21.4s}, [%[inptr4]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n" "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v21.4s}, [%[inptr4]] \n" "ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} MID_RESULT_S2_LEAKY_RELU
"ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2 "cmp %w[remain], #1 \n"
MID_RESULT_S2_LEAKY_RELU "blt 4f \n" RIGHT_COMPUTE_S2
"cmp %w[remain], #1 \n" RIGHT_RESULT_S2_LEAKY_RELU
"blt 4f \n" RIGHT_COMPUTE_S2 "4: \n"
RIGHT_RESULT_S2_LEAKY_RELU : [inptr0] "+r"(din0_ptr),
"4: \n" [inptr1] "+r"(din1_ptr),
: [inptr0] "+r"(din0_ptr), [inptr2] "+r"(din2_ptr),
[inptr1] "+r"(din1_ptr), [inptr3] "+r"(din3_ptr),
[inptr2] "+r"(din2_ptr), [inptr4] "+r"(din4_ptr),
[inptr3] "+r"(din3_ptr), [outptr0] "+r"(doutr0_ptr),
[inptr4] "+r"(din4_ptr), [outptr1] "+r"(doutr1_ptr),
[outptr0] "+r"(doutr0_ptr), [cnt] "+r"(cnt)
[outptr1] "+r"(doutr1_ptr), : [vzero] "w"(vzero),
[cnt] "+r"(cnt) [w0] "w"(wr0),
: [vzero] "w"(vzero), [w1] "w"(wr1),
[w0] "w"(wr0), [w2] "w"(wr2),
[w1] "w"(wr1), [remain] "r"(cnt_remain),
[w2] "w"(wr2), [scale_ptr] "r"(vscale),
[remain] "r"(cnt_remain), [mask1] "w"(vmask_rp1),
[scale_ptr] "r"(vscale), [mask2] "w"(vmask_rp2),
[mask1] "w"(vmask_rp1), [wmask] "w"(wmask),
[mask2] "w"(vmask_rp2), [vbias] "w"(wbias)
[wmask] "w"(wmask), : "cc",
[vbias] "w"(wbias) "memory",
: "cc", "v0",
"memory", "v1",
"v0", "v2",
"v1", "v3",
"v2", "v4",
"v3", "v5",
"v4", "v6",
"v5", "v7",
"v6", "v8",
"v7", "v9",
"v8", "v10",
"v9", "v11",
"v10", "v12",
"v11", "v13",
"v12", "v14",
"v13", "v15",
"v14", "v16",
"v15", "v17",
"v16", "v18",
"v17", "v19",
"v18", "v20",
"v19", "v21",
"v20", "v22");
"v21", break;
"v22"); default:
break; LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
default: << " fuse not support";
LOG(FATAL) << "this act_type: "
<< static_cast<int>(act_param.active_type)
<< " fuse not support";
}
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 "4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} }
} }
#endif #endif
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "lite/backends/arm/math/conv_depthwise.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
#ifdef __aarch64__
#define INIT_S2 \
"prfm pldl1keep, [%[inptr0]] \n" \
"prfm pldl1keep, [%[inptr1]] \n" \
"prfm pldl1keep, [%[inptr2]] \n" \
"prfm pldl1keep, [%[inptr3]] \n" \
"prfm pldl1keep, [%[inptr4]] \n" \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"and v17.16b, %[vbias].16b, %[vbias].16b \n"
#define LEFT_COMPUTE_S2 \
"ext v10.16b, %[vzero].16b, v1.16b, #12 \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[1] \n" /* {0,2,4,6} * w01 */ \
"fmul v12.4s, v1.4s, %[w0].s[2] \n" /* {1,3,5,7} * w02 */ \
"fmla v16.4s, v10.4s, %[w0].s[0] \n" /* {0,1,3,5} * w00*/ \
\
"ext v10.16b, %[vzero].16b, v3.16b, #12 \n" /* v10 = {0,1,3,5} */ \
\
"sub %[inptr0], %[inptr0], #4 \n" \
"sub %[inptr1], %[inptr1], #4 \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[1] \n" \
"fmla v12.4s, v3.4s, %[w1].s[2] \n" \
"fmla v16.4s, v10.4s, %[w1].s[0] \n" \
\
"ext v10.16b, %[vzero].16b, v5.16b, #12 \n" \
\
"sub %[inptr2], %[inptr2], #4 \n" \
"sub %[inptr3], %[inptr3], #4 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[1] \n" \
"fmla v11.4s, v4.4s, %[w2].s[1] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[2] \n" \
"fmla v12.4s, v5.4s, %[w2].s[2] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[0] \n" \
"fmla v16.4s, v10.4s, %[w2].s[0] \n" \
\
"ext v10.16b, %[vzero].16b, v7.16b, #12 \n" \
\
"sub %[inptr4], %[inptr4], #4 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[1] \n" \
"fmla v14.4s, v7.4s, %[w1].s[2] \n" \
"fmla v17.4s, v10.4s, %[w1].s[0] \n" \
\
"ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n"
#define LEFT_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"cmp %w[cnt], #1 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define MID_COMPUTE_S2 \
"2: \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[0] \n" \
"fmul v12.4s, v1.4s, %[w0].s[1] \n" \
"fmla v16.4s, v10.4s, %[w0].s[2] \n" \
\
"ext v10.16b, v2.16b, v18.16b, #4 \n" \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[0] \n" \
"fmla v12.4s, v3.4s, %[w1].s[1] \n" \
"fmla v16.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v4.16b, v19.16b, #4 \n" \
\
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[0] \n" \
"fmla v11.4s, v4.4s, %[w2].s[0] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[1] \n" \
"fmla v12.4s, v5.4s, %[w2].s[1] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[2] \n" \
"fmla v16.4s, v10.4s, %[w2].s[2] \n" \
\
"ext v10.16b, v6.16b, v20.16b, #4 \n" \
\
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[0] \n" \
"fmla v14.4s, v7.4s, %[w1].s[1] \n" \
"fmla v17.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v8.16b, v21.16b, #4 \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n"
#define MID_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define RIGHT_COMPUTE_S2 \
"1: \n" \
"cmp %w[remain], #1 \n" \
"blt 4f \n" \
"3: \n" \
"bif v0.16b, %[vzero].16b, %[mask1].16b \n" \
"bif v1.16b, %[vzero].16b, %[mask2].16b \n" \
\
"bif v2.16b, %[vzero].16b, %[mask1].16b \n" \
"bif v3.16b, %[vzero].16b, %[mask2].16b \n" \
\
"bif v4.16b, %[vzero].16b, %[mask1].16b \n" \
"bif v5.16b, %[vzero].16b, %[mask2].16b \n" \
\
"ext v10.16b, v0.16b, %[vzero].16b, #4 \n" \
\
"bif v6.16b, %[vzero].16b, %[mask1].16b \n" \
"bif v7.16b, %[vzero].16b, %[mask2].16b \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[0] \n" \
"fmul v12.4s, v1.4s, %[w0].s[1] \n" \
"fmla v16.4s, v10.4s, %[w0].s[2] \n" \
\
"ext v10.16b, v2.16b, %[vzero].16b, #4 \n" \
"bif v8.16b, %[vzero].16b, %[mask1].16b \n" \
"bif v9.16b, %[vzero].16b, %[mask2].16b \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[0] \n" \
"fmla v12.4s, v3.4s, %[w1].s[1] \n" \
"fmla v16.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v4.16b, %[vzero].16b, #4 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[0] \n" \
"fmla v11.4s, v4.4s, %[w2].s[0] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[1] \n" \
"fmla v12.4s, v5.4s, %[w2].s[1] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[2] \n" \
"fmla v16.4s, v10.4s, %[w2].s[2] \n" \
\
"ext v10.16b, v6.16b, %[vzero].16b, #4 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[0] \n" \
"fmla v14.4s, v7.4s, %[w1].s[1] \n" \
"fmla v17.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v8.16b, %[vzero].16b, #4 \n" \
"ld1 {v0.4s}, [%[outptr0]] \n" \
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n" \
"ld1 {v1.4s}, [%[outptr1]] \n"
#define RIGHT_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define LEFT_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n" \
\
"fmax v16.4s, v16.4s, %[vzero].4s \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" \
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
\
"ld1 {v18.4s}, [%[inptr1]] \n" \
"ld1 {v19.4s}, [%[inptr2]] \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"cmp %w[cnt], #1 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"blt 1f \n"
#define MID_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n" \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"ld1 {v19.4s}, [%[inptr2]] \n" \
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"and v16.16b, %[vbias].16b, %[vbias].16b \n" \
"subs %w[cnt], %w[cnt], #1 \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
\
"and v17.16b, %[vbias].16b, %[vbias].16b \n" \
\
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"fmax v17.4s, v17.4s, %[vzero].4s \n" /* relu */ \
\
"bif v17.16b, v1.16b, %[wmask].16b \n" \
\
"st1 {v17.4s}, [%[outptr1]], #16 \n" \
"4: \n"
#define COMPUTE_S_S2 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
\
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \
\
"bif v10.16b, v9.16b, v6.16b \n" \
"bif v11.16b, v9.16b, v7.16b \n" \
"bif v12.16b, v9.16b, v6.16b \n" \
"bif v13.16b, v9.16b, v7.16b \n" \
"bif v14.16b, v9.16b, v6.16b \n" \
"bif v15.16b, v9.16b, v7.16b \n" \
\
"ext v6.16b, v9.16b, v11.16b, #12 \n" \
"ext v7.16b, v9.16b, v13.16b, #12 \n" \
"ext v8.16b, v9.16b, v15.16b, #12 \n" \
\
"fmul v4.4s, v10.4s, %[wr0].s[1] \n" \
"fmul v5.4s, v11.4s, %[wr0].s[2] \n" \
"fmul v6.4s, v6.4s, %[wr0].s[0] \n" \
\
"fmla v4.4s, v12.4s, %[wr1].s[1] \n" \
"fmla v5.4s, v13.4s, %[wr1].s[2] \n" \
"fmla v6.4s, v7.4s, %[wr1].s[0] \n" \
\
"fmla v4.4s, v14.4s, %[wr2].s[1] \n" \
"fmla v5.4s, v15.4s, %[wr2].s[2] \n" \
"fmla v6.4s, v8.4s, %[wr2].s[0] \n" \
\
"fadd v4.4s, v4.4s, v5.4s \n" \
"fadd v4.4s, v4.4s, v6.4s \n"
#define RESULT_S_S2 \
"fadd v4.4s, v4.4s, %[bias].4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_RELU \
"fadd v4.4s, v4.4s, %[bias].4s \n" \
"fmax v4.4s, v4.4s, v9.4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
\
"ld2 {v10.4s, v11.4s}, [%[din0_ptr]], #32 \n" \
"ld2 {v12.4s, v13.4s}, [%[din1_ptr]], #32 \n" \
"ld2 {v14.4s, v15.4s}, [%[din2_ptr]], #32 \n" \
"and v4.16b, %[bias].16b, %[bias].16b \n" \
\
"bif v10.16b, v9.16b, v6.16b \n" \
"bif v11.16b, v9.16b, v7.16b \n" \
"bif v12.16b, v9.16b, v6.16b \n" \
"bif v13.16b, v9.16b, v7.16b \n" \
"bif v14.16b, v9.16b, v6.16b \n" \
"bif v15.16b, v9.16b, v7.16b \n" \
\
"ext v6.16b, v10.16b, v9.16b, #4 \n" \
"ext v7.16b, v12.16b, v9.16b, #4 \n" \
"ext v8.16b, v14.16b, v9.16b, #4 \n" \
\
"fmla v4.4s, v10.4s, %[wr0].s[0] \n" \
"fmul v5.4s, v11.4s, %[wr0].s[1] \n" \
"fmul v16.4s, v6.4s, %[wr0].s[2] \n" \
\
"fmla v4.4s, v12.4s, %[wr1].s[0] \n" \
"fmla v5.4s, v13.4s, %[wr1].s[1] \n" \
"fmla v16.4s, v7.4s, %[wr1].s[2] \n" \
\
"fmla v4.4s, v14.4s, %[wr2].s[0] \n" \
"fmla v5.4s, v15.4s, %[wr2].s[1] \n" \
"fmla v16.4s, v8.4s, %[wr2].s[2] \n" \
\
"fadd v4.4s, v4.4s, v5.4s \n" \
"fadd v4.4s, v4.4s, v16.4s \n"
#define RESULT_S_S2_P0 "st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"fmax v4.4s, v4.4s, v9.4s \n" \
"st1 {v4.4s}, [%[out]] \n"
#else
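// 32-bit armv7 variants of the same macros: q registers instead of v
// registers, one output row per pass, and the bias broadcast from a scalar
// via vdup.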
#define INIT_S2 \
"vmov.u32 q9, #0 \n" \
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r1\n" \
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \
"pld [%[din0_ptr]] @ preload data\n" \
"pld [%[din1_ptr]] @ preload data\n" \
"pld [%[din2_ptr]] @ preload data\n" \
\
"vdup.32 q3, %[bias] @ and \n"
#define LEFT_COMPUTE_S2 \
"vext.32 q6, q9, q11, #3 @ shift right 1 data\n" \
"vext.32 q7, q9, q13, #3 @ shift right 1 data\n" \
"vext.32 q8, q9, q15, #3 @ shift right 1 data\n" \
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 1, out0\n" \
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 1, out0\n" \
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 1, out0\n" \
\
"sub %[din0_ptr], #4 @ inpitr0 - 1\n" \
"sub %[din1_ptr], #4 @ inpitr1 - 1\n" \
"sub %[din2_ptr], #4 @ inpitr2 - 1\n" \
\
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \
\
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \
\
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \
\
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 1, out1\n" \
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \
\
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n" \
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define LEFT_RESULT_S2 \
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"cmp %[cnt], #1 \n" \
"blt 1f \n"
#define MID_COMPUTE_S2 \
"2: \n" \
"vld1.32 {d16}, [%[din0_ptr]] @ load din r0\n" \
"vdup.32 q3, %[bias] @ and \n" \
"vext.32 q6, q10, q8, #1 @ shift left 1 \n" \
"vld1.32 {d16}, [%[din1_ptr]] @ load din r1\n" \
\
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \
\
"vext.32 q7, q12, q8, #1 @ shift left 1 \n" \
"vld1.32 {d16}, [%[din2_ptr]] @ load din r1\n" \
\
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \
\
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \
\
"vext.32 q6, q14, q8, #1 @ shift left 1 \n" \
\
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \
\
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \
\
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define MID_RESULT_S2 \
"subs %[cnt], #1 \n" \
\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"bne 2b \n"
#define RIGHT_COMPUTE_S2 \
"1: \n" \
"cmp %[remain], #1 \n" \
"blt 3f \n" \
\
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \
"vdup.32 q3, %[bias] @ and \n" \
\
"vbif q10, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q11, q9, q7 @ bit select, deal with " \
"right pad\n" \
"vbif q12, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q13, q9, q7 @ bit select, deal with " \
"right pad\n" \
"vbif q14, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q15, q9, q7 @ bit select, deal with " \
"right pad\n" \
\
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" \
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" \
\
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \
\
"vext.32 q6, q14, q9, #1 @ shift left 1 \n" \
"vld1.f32 {d20-d21}, [%[outptr]] @ load output\n" \
\
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \
\
"vld1.f32 {d22-d23}, [%[mask_ptr]] @ load mask\n" \
\
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define RIGHT_RESULT_S2 \
"vbif.f32 q3, q10, q11 @ write mask\n" \
\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"3: \n"
#define LEFT_RESULT_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"cmp %[cnt], #1 \n" \
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"subs %[cnt], #1 \n" \
\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"vbif.f32 q3, q10, q11 @ write mask\n" \
\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"3: \n"
#define COMPUTE_S_S2 \
"vmov.u32 q9, #0 \n" \
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \
"vdup.32 q3, %[bias] @ and \n" \
\
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \
\
"vbif q10, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q11, q9, q7 @ bit select, deal with " \
"right pad\n" \
"vbif q12, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q13, q9, q7 @ bit select, deal with " \
"right pad\n" \
"vbif q14, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q15, q9, q7 @ bit select, deal with " \
"right pad\n" \
\
"vext.32 q6, q9, q11, #3 @ shift left 1 \n" \
"vext.32 q7, q9, q13, #3 @ shift left 1 \n" \
"vext.32 q8, q9, q15, #3 @ shift left 1 \n" \
\
"vmul.f32 q4, q10, %e[wr0][1] @ mul weight 0, out0\n" \
"vmul.f32 q5, q11, %f[wr0][0] @ mul weight 0, out0\n" \
"vmla.f32 q3, q6, %e[wr0][0] @ mul weight 0, out0\n" \
\
"vmla.f32 q4, q12, %e[wr1][1] @ mul weight 1, out0\n" \
"vmla.f32 q5, q13, %f[wr1][0] @ mul weight 1, out0\n" \
"vmla.f32 q3, q7, %e[wr1][0] @ mul weight 1, out0\n" \
\
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n" \
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2 "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU \
"vmax.f32 q3, q3, q9 @ relu\n" \
\
"vst1.32 {d6-d7}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"vmov.u32 q9, #0 \n" \
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \
"vdup.32 q3, %[bias] @ and \n" \
\
"vld2.32 {d20-d23}, [%[din0_ptr]]! @ load din r0\n" \
"vld2.32 {d24-d27}, [%[din1_ptr]]! @ load din r1\n" \
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n" \
\
"vbif q10, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q11, q9, q7 @ bit select, deal with " \
"right pad\n" \
"vbif q12, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q13, q9, q7 @ bit select, deal with " \
"right pad\n" \
"vbif q14, q9, q6 @ bit select, deal with " \
"right pad\n" \
"vbif q15, q9, q7 @ bit select, deal with " \
"right pad\n" \
\
"vext.32 q6, q10, q9, #1 @ shift left 1 \n" \
"vext.32 q7, q12, q9, #1 @ shift left 1 \n" \
"vext.32 q8, q14, q9, #1 @ shift left 1 \n" \
\
"vmul.f32 q4, q10, %e[wr0][0] @ mul weight 0, out0\n" \
"vmul.f32 q5, q11, %e[wr0][1] @ mul weight 0, out0\n" \
"vmla.f32 q3, q6, %f[wr0][0] @ mul weight 0, out0\n" \
\
"vmla.f32 q4, q12, %e[wr1][0] @ mul weight 1, out0\n" \
"vmla.f32 q5, q13, %e[wr1][1] @ mul weight 1, out0\n" \
"vmla.f32 q3, q7, %f[wr1][0] @ mul weight 1, out0\n" \
\
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n" \
\
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n"
#define RESULT_S_S2_P0 "vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"vmax.f32 q3, q3, q9 @ relu \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#endif
/**
* \brief depthwise convolution kernel 3x3, stride 2
* w_in > 7
*/
void conv_depthwise_3x3s2p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
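  //! tiling: every main-loop step consumes 8 input columns and emits 4
  //! outputs (stride 2); vmask_rp1/vmask_rp2 mark the valid even/odd input
  //! lanes of the rightmost block and wmask marks the valid output lanes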
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
  int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
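  //! the masks are also spilled to dmask so the 32-bit path can reload them
  //! through [mask_ptr] when it handles the right-edge block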
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
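        //! pad == 1: the first row block feeds its top padding row from zero_ptr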
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
if (flag_relu) {
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (i + 2 > h_in) {
switch (i + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
if (flag_relu) {
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/
void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
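      //! hs/he bracket the input row window; hs starts at -1 so the first row
      //! reads the top padding, and both advance by the stride (2) per output row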
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
hs += 2;
he += 2;
}
}
}
}
/**
 * \brief depthwise convolution kernel 3x3, stride 2
 * w_in > 7
 */
void conv_depthwise_3x3s2p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
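  //! tile_w full 4-output tiles run in the main loop; the last block always
  //! goes through the masked tail, so an exactly aligned width moves one
  //! complete tile into the remainder path (cnt_remain = 4)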
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
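        //! output rows i and i + 1 are computed from input rows 2i ... 2i + 4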
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i * 2 + 5 > h_in) {
switch (i * 2 + 5 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
if (flag_relu) {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
} else {
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
}
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i * 2 + 3 > h_in) {
switch (i * 2 + 3 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
if (flag_relu) {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/
void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
  // the asm loads 8 floats per padded row, so the zero row needs 8 entries
  const float zero_ptr[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
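      //! small-width path: one masked 4-wide output row per step; the input
      //! window advances by the stride (dr0 = dr2) after every output row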
for (int j = 0; j < h_out; j++) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
if (j * 2 + 2 >= h_in) {
          switch (j * 2 + 2 - h_in) {
case 1:
din1_ptr = zero_ptr;
case 0:
din2_ptr = zero_ptr;
default:
break;
}
}
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
if (flag_relu) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
}
#else
if (flag_relu) {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
} else {
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
}
#endif
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -207,6 +207,118 @@ void conv_depthwise_5x5s2_int8(Dtype* dout, ...@@ -207,6 +207,118 @@ void conv_depthwise_5x5s2_int8(Dtype* dout,
int padh, int padh,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -60,17 +60,18 @@ void fill_with_mat(cv::Mat& mat, uint8_t* src, int num) { // NOLINT ...@@ -60,17 +60,18 @@ void fill_with_mat(cv::Mat& mat, uint8_t* src, int num) { // NOLINT
} }
} }
double compare_diff(uint8_t* data1, uint8_t* data2, int size) { double compare_diff(uint8_t* data1, uint8_t* data2, int size, uint8_t* diff_v) {
double diff = 0.0; double diff = 0.0;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
double val = abs(data1[i] - data2[i]); double val = abs(data1[i] - data2[i]);
diff_v[i] = val;
diff = val > diff ? val : diff; diff = val > diff ? val : diff;
} }
return diff; return diff;
} }
void print_data(const uint8_t* data, int size) { void print_data(const uint8_t* data, int size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
std::cout << data[i] << " "; printf("%d ", data[i]);
if ((i + 1) % 10 == 0) { if ((i + 1) % 10 == 0) {
std::cout << std::endl; std::cout << std::endl;
} }
...@@ -139,7 +140,8 @@ bool test_convert(bool cv_run, ...@@ -139,7 +140,8 @@ bool test_convert(bool cv_run,
if (cv_run) { if (cv_run) {
resize_cv = im_resize.data; resize_cv = im_resize.data;
double diff = compare_diff(resize_cv, resize_lite, out_size); uint8_t* diff_v = new uint8_t[out_size];
double diff = compare_diff(resize_cv, resize_lite, out_size, diff_v);
if (diff > 1) { if (diff > 1) {
std::cout << "din: " << std::endl; std::cout << "din: " << std::endl;
print_data(src, in_size); print_data(src, in_size);
...@@ -147,6 +149,8 @@ bool test_convert(bool cv_run, ...@@ -147,6 +149,8 @@ bool test_convert(bool cv_run,
print_data(resize_cv, out_size); print_data(resize_cv, out_size);
std::cout << "lite out: " << std::endl; std::cout << "lite out: " << std::endl;
print_data(resize_lite, out_size); print_data(resize_lite, out_size);
std::cout << "lite out: " << std::endl;
print_data(diff_v, out_size);
return false; return false;
} else { } else {
// save_img // save_img
...@@ -225,7 +229,8 @@ bool test_flip(bool cv_run, ...@@ -225,7 +229,8 @@ bool test_flip(bool cv_run,
if (cv_run) { if (cv_run) {
resize_cv = im_resize.data; resize_cv = im_resize.data;
double diff = compare_diff(resize_cv, resize_lite, out_size); uint8_t* diff_v = new uint8_t[out_size];
double diff = compare_diff(resize_cv, resize_lite, out_size, diff_v);
if (diff > 1) { if (diff > 1) {
std::cout << "din: " << std::endl; std::cout << "din: " << std::endl;
print_data(src, in_size); print_data(src, in_size);
...@@ -233,6 +238,8 @@ bool test_flip(bool cv_run, ...@@ -233,6 +238,8 @@ bool test_flip(bool cv_run,
print_data(resize_cv, out_size); print_data(resize_cv, out_size);
std::cout << "lite out: " << std::endl; std::cout << "lite out: " << std::endl;
print_data(resize_lite, out_size); print_data(resize_lite, out_size);
std::cout << "diff out: " << std::endl;
print_data(diff_v, out_size);
return false; return false;
} else { } else {
// save_img // save_img
...@@ -316,7 +323,8 @@ bool test_rotate(bool cv_run, ...@@ -316,7 +323,8 @@ bool test_rotate(bool cv_run,
std::cout << "compare diff: " << std::endl; std::cout << "compare diff: " << std::endl;
if (cv_run) { if (cv_run) {
resize_cv = im_resize.data; resize_cv = im_resize.data;
double diff = compare_diff(resize_cv, resize_lite, out_size); uint8_t* diff_v = new uint8_t[out_size];
double diff = compare_diff(resize_cv, resize_lite, out_size, diff_v);
if (diff > 1) { if (diff > 1) {
std::cout << "din: " << std::endl; std::cout << "din: " << std::endl;
print_data(src, in_size); print_data(src, in_size);
...@@ -324,6 +332,8 @@ bool test_rotate(bool cv_run, ...@@ -324,6 +332,8 @@ bool test_rotate(bool cv_run,
print_data(resize_cv, out_size); print_data(resize_cv, out_size);
std::cout << "lite out: " << std::endl; std::cout << "lite out: " << std::endl;
print_data(resize_lite, out_size); print_data(resize_lite, out_size);
std::cout << "diff out: " << std::endl;
print_data(diff_v, out_size);
return false; return false;
} else { } else {
// save_img // save_img
...@@ -401,14 +411,17 @@ bool test_resize(bool cv_run, ...@@ -401,14 +411,17 @@ bool test_resize(bool cv_run,
if (cv_run) { if (cv_run) {
resize_cv = im_resize.data; resize_cv = im_resize.data;
double diff = compare_diff(resize_cv, resize_lite, out_size); uint8_t* diff_v = new uint8_t[out_size];
if (diff > 1) { double diff = compare_diff(resize_cv, resize_lite, out_size, diff_v);
if (diff > 10) {
std::cout << "din: " << std::endl; std::cout << "din: " << std::endl;
print_data(src, in_size); print_data(src, in_size);
std::cout << "cv out: " << std::endl; std::cout << "cv out: " << std::endl;
print_data(resize_cv, out_size); print_data(resize_cv, out_size);
std::cout << "lite out: " << std::endl; std::cout << "lite out: " << std::endl;
print_data(resize_lite, out_size); print_data(resize_lite, out_size);
std::cout << "diff out: " << std::endl;
print_data(diff_v, out_size);
return false; return false;
} else { } else {
// save_img // save_img
...@@ -545,7 +558,7 @@ void test_custom(bool has_img, // input is image ...@@ -545,7 +558,7 @@ void test_custom(bool has_img, // input is image
tparam1.rotate_param = rotate; tparam1.rotate_param = rotate;
ImagePreprocess image_preprocess(srcFormat, dstFormat, tparam); ImagePreprocess image_preprocess(srcFormat, dstFormat, tparam);
std::cout << "image convert testing"; std::cout << "image convert testing" << std::endl;
bool re = test_convert(cv_run, bool re = test_convert(cv_run,
src, src,
img, img,
...@@ -561,7 +574,7 @@ void test_custom(bool has_img, // input is image ...@@ -561,7 +574,7 @@ void test_custom(bool has_img, // input is image
if (!re) { if (!re) {
return; return;
} }
std::cout << "image resize testing"; std::cout << "image resize testing" << std::endl;
tparam.oh = dsth; tparam.oh = dsth;
tparam.ow = dstw; tparam.ow = dstw;
ImagePreprocess image_preprocess1(srcFormat, srcFormat, tparam1); ImagePreprocess image_preprocess1(srcFormat, srcFormat, tparam1);
...@@ -580,7 +593,7 @@ void test_custom(bool has_img, // input is image ...@@ -580,7 +593,7 @@ void test_custom(bool has_img, // input is image
return; return;
} }
std::cout << "image rotate testing"; std::cout << "image rotate testing" << std::endl;
if (rotate == 90 || rotate == 270) { if (rotate == 90 || rotate == 270) {
tparam.oh = srcw; tparam.oh = srcw;
tparam.ow = srch; tparam.ow = srch;
...@@ -611,7 +624,7 @@ void test_custom(bool has_img, // input is image ...@@ -611,7 +624,7 @@ void test_custom(bool has_img, // input is image
tparam.oh = srch; tparam.oh = srch;
tparam.ow = srcw; tparam.ow = srcw;
ImagePreprocess image_preprocess3(srcFormat, srcFormat, tparam); ImagePreprocess image_preprocess3(srcFormat, srcFormat, tparam);
std::cout << "image flip testing"; std::cout << "image flip testing" << std::endl;
re = test_flip(cv_run, re = test_flip(cv_run,
src, src,
img, img,
......