Unverified commit 9e38adc8 authored by: H HappyAngel, committed by: GitHub

[arm] improve conv3x3_dw (#4063)


* optimize conv_dw profiler

* fix build conv_dw_3x3s1 bug

* update conv_dw_3x3s2

* fix format test=develop
Parent ee4cb1dc
@@ -20,61 +20,117 @@ namespace lite {
namespace arm {
namespace math {
void conv_depthwise_3x3s1p0_bias(float *dout,
                                 const float *din,
                                 const float *weights,
                                 const float *bias,
                                 bool flag_bias,
                                 const int num,
                                 const int ch_in,
                                 const int h_in,
                                 const int w_in,
                                 const int h_out,
                                 const int w_out,
                                 const operators::ActivationParam act_param,
                                 ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_s(float *dout,
                                   const float *din,
                                   const float *weights,
                                   const float *bias,
                                   bool flag_bias,
                                   const int num,
                                   const int ch_in,
                                   const int h_in,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out,
                                   const operators::ActivationParam act_param,
                                   ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias(float *dout,
                                 const float *din,
                                 const float *weights,
                                 const float *bias,
                                 bool flag_bias,
                                 const int num,
                                 const int ch_in,
                                 const int h_in,
                                 const int w_in,
                                 const int h_out,
                                 const int w_out,
                                 const operators::ActivationParam act_param,
                                 ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_s(float *dout,
                                   const float *din,
                                   const float *weights,
                                   const float *bias,
                                   bool flag_bias,
                                   const int num,
                                   const int ch_in,
                                   const int h_in,
                                   const int w_in,
                                   const int h_out,
                                   const int w_out,
                                   const operators::ActivationParam act_param,
                                   ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_relu6(float *dout,
                                       const float *din,
                                       const float *weights,
                                       const float *bias,
                                       const float *six,
                                       bool flag_bias,
                                       const int num,
                                       const int ch_in,
                                       const int h_in,
                                       const int w_in,
                                       const int h_out,
                                       const int w_out,
                                       ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout,
                                         const float *din,
                                         const float *weights,
                                         const float *bias,
                                         const float *six,
                                         bool flag_bias,
                                         const int num,
                                         const int ch_in,
                                         const int h_in,
                                         const int w_in,
                                         const int h_out,
                                         const int w_out,
                                         ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_relu6(float *dout,
                                       const float *din,
                                       const float *weights,
                                       const float *bias,
                                       const float *six,
                                       bool flag_bias,
                                       const int num,
                                       const int ch_in,
                                       const int h_in,
                                       const int w_in,
                                       const int h_out,
                                       const int w_out,
                                       ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout,
                                         const float *din,
                                         const float *weights,
                                         const float *bias,
                                         const float *six,
                                         bool flag_bias,
                                         const int num,
                                         const int ch_in,
                                         const int h_in,
                                         const int w_in,
                                         const int h_out,
                                         const int w_out,
                                         ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx);
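// Scalar reference for the fused activations that the new *_relu6 and
// *_leakyRelu variants bake into the conv kernels -- a minimal sketch for
// orientation, not code from this patch (relu6_ref/leaky_relu_ref are
// illustrative names). The NEON kernels below apply the same math four
// lanes at a time via vmax/vmin (relu6) and compare+select (leaky relu).
inline float relu6_ref(float x, float six) {
  // kRelu6: clamp to [0, six]; `six` is broadcast from the pointer the
  // *_relu6 kernels receive.
  x = x > 0.f ? x : 0.f;     // relu lower clamp
  return x < six ? x : six;  // relu6 upper clamp
}
inline float leaky_relu_ref(float x, float scale) {
  // kLeakyRelu: identity for x >= 0, x * scale otherwise; `scale` comes
  // from the pointer the *_leakyRelu kernels receive.
  return x >= 0.f ? x : x * scale;
}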
void conv_depthwise_3x3s1_fp32(const float *din,
                               float *dout,
@@ -92,138 +148,270 @@ void conv_depthwise_3x3s1_fp32(const float *din,
                               const operators::ActivationParam act_param,
                               ARMContext *ctx) {
  bool has_active = act_param.has_active;
  bool flag_relu = false;
  bool relu6 = false;
  auto act_type = act_param.active_type;
  float tmp = act_param.Relu_clipped_coef;
  float ss = act_param.Leaky_relu_alpha;
  float vsix[4] = {tmp, tmp, tmp, tmp};
  float vscale[4] = {ss, ss, ss, ss};
  if (has_active) {
    if (act_param.active_type == lite_api::ActivationType::kRelu) {
      flag_relu = true;
    } else {
      relu6 = true;
    switch (act_type) {
      case lite_api::ActivationType::kRelu:
        if (pad == 0) {
          if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kRelu6:
if (pad == 0) {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kLeakyRelu:
if (pad == 0) {
if (w_in > 5) {
conv_depthwise_3x3s1p0_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p0_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 4) {
conv_depthwise_3x3s1p1_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s1p1_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
<< " fuse not support";
    }
  } else {
    if (pad == 0) {
      if (w_in > 5) {
        conv_depthwise_3x3s1p0_bias_no_relu(dout,
                                            din,
                                            weights,
                                            bias,
                                            flag_bias,
                                            false,
                                            num,
                                            ch_in,
                                            h_in,
                                            w_in,
                                            h_out,
                                            w_out,
                                            ctx);
      } else {
        conv_depthwise_3x3s1p0_bias_s_no_relu(dout,
                                              din,
                                              weights,
                                              bias,
                                              flag_bias,
                                              false,
                                              num,
                                              ch_in,
                                              h_in,
                                              w_in,
                                              h_out,
                                              w_out,
                                              ctx);
      }
    }
    if (pad == 1) {
      if (w_in > 4) {
        conv_depthwise_3x3s1p1_bias_no_relu(dout,
                                            din,
                                            weights,
                                            bias,
                                            flag_bias,
                                            false,
                                            num,
                                            ch_in,
                                            h_in,
                                            w_in,
                                            h_out,
                                            w_out,
                                            ctx);
      } else {
        conv_depthwise_3x3s1p1_bias_s_no_relu(dout,
                                              din,
                                              weights,
                                              bias,
                                              flag_bias,
                                              false,
                                              num,
                                              ch_in,
                                              h_in,
                                              w_in,
                                              h_out,
                                              w_out,
                                              ctx);
      }
    }
  }
    }
  }
  if (pad == 0) {
    if (w_in > 5) {
      if (relu6) {
        conv_depthwise_3x3s1p0_bias(dout,
                                    din,
                                    weights,
                                    bias,
                                    flag_bias,
                                    num,
                                    ch_in,
                                    h_in,
                                    w_in,
                                    h_out,
                                    w_out,
                                    act_param,
                                    ctx);
      } else {
        conv_depthwise_3x3s1p0_bias_relu(dout,
                                         din,
                                         weights,
                                         bias,
                                         flag_bias,
                                         flag_relu,
                                         num,
                                         ch_in,
                                         h_in,
                                         w_in,
                                         h_out,
                                         w_out,
                                         ctx);
      }
    } else {
      if (relu6) {
        conv_depthwise_3x3s1p0_bias_s(dout,
                                      din,
                                      weights,
                                      bias,
                                      flag_bias,
                                      num,
                                      ch_in,
                                      h_in,
                                      w_in,
                                      h_out,
                                      w_out,
                                      act_param,
                                      ctx);
      } else {
        conv_depthwise_3x3s1p0_bias_s_relu(dout,
                                           din,
                                           weights,
                                           bias,
                                           flag_bias,
                                           flag_relu,
                                           num,
                                           ch_in,
                                           h_in,
                                           w_in,
                                           h_out,
                                           w_out,
                                           ctx);
      }
    }
  }
  if (pad == 1) {
    if (w_in > 4) {
      if (relu6) {
        conv_depthwise_3x3s1p1_bias(dout,
                                    din,
                                    weights,
                                    bias,
                                    flag_bias,
                                    num,
                                    ch_in,
                                    h_in,
                                    w_in,
                                    h_out,
                                    w_out,
                                    act_param,
                                    ctx);
      } else {
        conv_depthwise_3x3s1p1_bias_relu(dout,
                                         din,
                                         weights,
                                         bias,
                                         flag_bias,
                                         flag_relu,
                                         num,
                                         ch_in,
                                         h_in,
                                         w_in,
                                         h_out,
                                         w_out,
                                         ctx);
      }
    } else {
      if (relu6) {
        conv_depthwise_3x3s1p1_bias_s(dout,
                                      din,
                                      weights,
                                      bias,
                                      flag_bias,
                                      num,
                                      ch_in,
                                      h_in,
                                      w_in,
                                      h_out,
                                      w_out,
                                      act_param,
                                      ctx);
      } else {
        conv_depthwise_3x3s1p1_bias_s_relu(dout,
                                           din,
                                           weights,
                                           bias,
                                           flag_bias,
                                           flag_relu,
                                           num,
                                           ch_in,
                                           h_in,
                                           w_in,
                                           h_out,
                                           w_out,
                                           ctx);
      }
    }
  }
@@ -1978,338 +2166,19 @@ void conv_depthwise_3x3s1_fp32(const float *din,
#endif
#ifdef __aarch64__
void act_switch_3x3s1p1(const float *din_ptr0,
                        const float *din_ptr1,
                        const float *din_ptr2,
                        const float *din_ptr3,
                        const float *din_ptr4,
                        const float *din_ptr5,
                        float *doutr0,
                        float *doutr1,
                        float *doutr2,
                        float *doutr3,
                        float32x4_t wr0,
                        float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask,
unsigned int *rmask,
float32x4_t vzero,
float *vbias,
int cnt,
const operators::ActivationParam act_param) {
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#else
void act_switch_3x3s1p1(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask_ptr,
unsigned int *rmask_ptr,
float32x4_t vzero,
float bias_val,
int cnt,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[six_ptr] "r"(vsix),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[scale_ptr] "r"(vscale),
[vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
void conv_depthwise_3x3s1p1_bias(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
void conv_depthwise_3x3s1p1_bias_relu6(float *dout,
                                       const float *din,
                                       const float *weights,
                                       const float *bias,
                                       const float *six,
                                       bool flag_bias,
                                       const int num,
                                       const int ch_in,
                                       const int h_in,
                                       const int w_in,
                                       const int h_out,
                                       const int w_out,
                                       ARMContext *ctx) {
  //! pad is done implicit
  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
  //! for 4x6 convolution window
@@ -2355,7 +2224,9 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
  vst1q_u32(rmask, vmask_result);
  float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
@@ -2458,25 +2329,56 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
      }
      int cnt = cnt_col;
      act_switch_3x3s1p1(din_ptr0,
                         din_ptr1,
                         din_ptr2,
                         din_ptr3,
                         din_ptr4,
                         din_ptr5,
                         doutr0,
                         doutr1,
                         doutr2,
                         doutr3,
                         wr0,
                         wr1,
                         wr2,
                         vmask,
                         rmask,
                         vzero,
                         vbias,
                         cnt,
                         act_param);
      asm volatile(
          INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
              MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
          : [cnt] "+r"(cnt),
            [din_ptr0] "+r"(din_ptr0),
            [din_ptr1] "+r"(din_ptr1),
            [din_ptr2] "+r"(din_ptr2),
            [din_ptr3] "+r"(din_ptr3),
            [din_ptr4] "+r"(din_ptr4),
            [din_ptr5] "+r"(din_ptr5),
            [doutr0] "+r"(doutr0),
            [doutr1] "+r"(doutr1),
            [doutr2] "+r"(doutr2),
            [doutr3] "+r"(doutr3)
          : [w0] "w"(wr0),
            [w1] "w"(wr1),
            [w2] "w"(wr2),
            [vsix] "w"(vsix),
            [bias_val] "r"(vbias),
            [vmask] "r"(vmask),
            [rmask] "r"(rmask),
            [vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
      dout_ptr = dout_ptr + 4 * w_out;
    }
#else
@@ -2525,759 +2427,58 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
      int cnt = cnt_col;
      unsigned int *rmask_ptr = rmask;
      unsigned int *vmask_ptr = vmask;
      act_switch_3x3s1p1(din_ptr0,
                         din_ptr1,
                         din_ptr2,
                         din_ptr3,
                         doutr0,
                         doutr1,
                         wr0,
                         wr1,
                         wr2,
                         vmask_ptr,
                         rmask_ptr,
                         vzero,
                         bias_val,
                         cnt,
                         act_param);
      asm volatile(
          INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU6 MID_COMPUTE_S1
              MID_RESULT_S1_RELU6 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU6
          : [dout_ptr1] "+r"(doutr0),
            [dout_ptr2] "+r"(doutr1),
            [din0_ptr] "+r"(din_ptr0),
            [din1_ptr] "+r"(din_ptr1),
            [din2_ptr] "+r"(din_ptr2),
            [din3_ptr] "+r"(din_ptr3),
            [cnt] "+r"(cnt),
            [rmask] "+r"(rmask_ptr),
            [vmask] "+r"(vmask_ptr)
          : [wr0] "w"(wr0),
            [wr1] "w"(wr1),
            [wr2] "w"(wr2),
            [bias_val] "r"(bias_val),
            [six_ptr] "r"(six),
            [vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
      dout_ptr += 2 * w_out;
    }  //! end of processing mid rows
#endif
    }
  }
}
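// Note on the two paths in the function above: the __aarch64__ body reads
// six input rows and writes four output rows per iteration (it has v0-v25
// to work with), while the armv7 body reads four input rows for two output
// rows (only q4-q15 are free), which is why the row loops step by 4 and 2
// respectively.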
void act_switch_3x3s1p1_s(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp,
float32x4_t vzero,
float32x4_t wbias,
const operators::ActivationParam act_param) {
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[six_ptr] "r"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20");
break;
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[scale_ptr] "r"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p1_bias_s(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
for (int j = 0; j < h_out; j += 2) {
const float *dr0_ptr = dr0;
const float *dr1_ptr = dr1;
const float *dr2_ptr = dr2;
const float *dr3_ptr = dr3;
if (j == 0) {
dr0_ptr = zero;
dr1_ptr = dr0;
dr2_ptr = dr1;
dr3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
dr1_ptr = zero;
case 2:
dr2_ptr = zero;
case 1:
dr3_ptr = zero;
default:
break;
}
}
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
act_switch_3x3s1p1_s(dr0_ptr,
dr1_ptr,
dr2_ptr,
dr3_ptr,
out_buf1,
out_buf2,
wr0,
wr1,
wr2,
vmask_rp,
vzero,
wbias,
act_param);
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
} // end of processing heights
} // end of processing channels
  }    // end of processing batches
}
#ifdef __aarch64__
void act_switch_3x3s1p0(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
const float *din_ptr4,
const float *din_ptr5,
float *doutr0,
float *doutr1,
float *doutr2,
float *doutr3,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask,
unsigned int *rmask,
float32x4_t vzero,
float *vbias,
int cnt,
int remain,
const operators::ActivationParam act_param) {
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU6
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#else
void act_switch_3x3s1p0(const float *din_ptr0,
const float *din_ptr1,
const float *din_ptr2,
const float *din_ptr3,
float *doutr0,
float *doutr1,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
unsigned int *vmask_ptr,
unsigned int *rmask_ptr,
float32x4_t vzero,
float bias_val,
int cnt,
int remain,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
  switch (act_param.active_type) {
    case lite_api::ActivationType::kRelu:
      asm volatile(INIT_S1
                   "sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
                   "vext.32 q6, q8, q9, #1 @ 0012\n"
                   "vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
                   MID_RESULT_S1_RELU
                   "cmp %[remain], #1 \n"
                   "blt 0f \n" RIGHT_COMPUTE_S1
                   RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU6
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(vsix),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_LEAKY_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
void conv_depthwise_3x3s1p0_bias(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
void conv_depthwise_3x3s1p1_bias_leakyRelu(float *dout,
                                           const float *din,
                                           const float *weights,
                                           const float *bias,
                                           const float *scale,
                                           bool flag_bias,
                                           const int num,
                                           const int ch_in,
                                           const int h_in,
                                           const int w_in,
                                           const int h_out,
                                           const int w_out,
                                           ARMContext *ctx) {
  //! pad is done implicit
  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
  //! for 4x6 convolution window
@@ -3293,14 +2494,19 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
  int tile_w = w_out >> 2;
  int remain = w_out % 4;
  unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
  const int remian_idx[4] = {0, 1, 2, 3};
  if (remain == 0 && size_pad_right == 6) {  // w_in == w_out and w_out % 4 == 0
    tile_w -= 1;
    remain = 4;
    size_pad_right = 2;
  }
  int cnt_col = tile_w - 1;
  unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
  const unsigned int remian_idx[4] = {0, 1, 2, 3};
  if (remain == 0 && size_pad_right == 5) {
    size_pad_right = 1;
    cnt_col -= 1;
    remain = 4;
  } else if (remain == 0 && size_pad_right == 6) {
    size_pad_right = 2;
    cnt_col -= 1;
    remain = 4;
  }
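// Worked example of the tiling arithmetic above (illustrative numbers):
// for w_in = w_out = 8, tile_w = 8 >> 2 = 2, remain = 0, cnt_col = 1, and
// size_pad_right = 5 + (2 << 2) - 8 = 5, so the first branch fires: the
// last full tile is peeled off the main column loop (cnt_col -= 1) and
// handled as a masked tail (remain = 4, size_pad_right = 1), which keeps
// the 4-wide vector loads from running past the end of the row.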
  uint32x4_t vmask_rp1 =
@@ -3308,7 +2514,7 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
  uint32x4_t vmask_rp2 =
      vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_result =
      vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
      vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));
  unsigned int vmask[8];
  vst1q_u32(vmask, vmask_rp1);
@@ -3318,7 +2524,9 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
  vst1q_u32(rmask, vmask_result);
  float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
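// How the right-edge masks above are built, as a standalone intrinsics
// sketch (make_right_pad_mask is an illustrative helper, not part of this
// patch): a lane stays valid (all ones) when its entry in right_pad_idx is
// >= size_pad_right, so exactly the trailing padded columns get masked
// off; vmask_result analogously keeps only the first `remain` output
// lanes.
static inline uint32x4_t make_right_pad_mask(const unsigned int *idx,
                                             unsigned int size_pad_right) {
  return vcgeq_u32(vld1q_u32(idx), vdupq_n_u32(size_pad_right));
}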
  for (int n = 0; n < num; ++n) {
    const float *din_batch = din + n * ch_in * size_in_channel;
    float *dout_batch = dout + n * ch_in * size_out_channel;
@@ -3355,7 +2563,6 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
      const float *din_ptr3 = dr3;
      const float *din_ptr4 = dr4;
      const float *din_ptr5 = dr5;
      float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
      for (int i = 0; i < h_out; i += 4) {
@@ -3371,26 +2578,37 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
        doutr1 = doutr0 + w_out;
        doutr2 = doutr1 + w_out;
        doutr3 = doutr2 + w_out;
        dr0 = dr4;
        dr1 = dr5;
        dr2 = dr1 + w_in;
        if (i == 0) {
          din_ptr0 = zero_ptr;
          din_ptr1 = dr0;
          din_ptr2 = dr1;
          din_ptr3 = dr2;
          din_ptr4 = dr3;
          din_ptr5 = dr4;
          dr0 = dr3;
          dr1 = dr4;
          dr2 = dr5;
        } else {
          dr0 = dr4;
          dr1 = dr5;
          dr2 = dr1 + w_in;
        }
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        dr5 = dr4 + w_in;
        //! process bottom pad
        if (i + 5 >= h_in) {
          switch (i + 5 - h_in) {
            case 4:
              din_ptr1 = zero_ptr;
            case 3:
              din_ptr2 = zero_ptr;
            case 2:
              din_ptr3 = zero_ptr;
            case 1:
              din_ptr4 = zero_ptr;
            case 0:
              din_ptr5 = zero_ptr;
            default:
              break;
        if (i + 5 > h_in) {
          switch (i + 5 - h_in) {
            case 5:
              din_ptr1 = zero_ptr;
            case 4:
              din_ptr2 = zero_ptr;
            case 3:
              din_ptr3 = zero_ptr;
            case 2:
              din_ptr4 = zero_ptr;
            case 1:
              din_ptr5 = zero_ptr;
            default:
              break;
@@ -3410,31 +2628,62 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
          }
        }
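// The bottom-pad switch above relies on deliberate case fall-through: for
// an overhang of k = i + 5 - h_in rows past the input, entering at case k
// redirects the last k of din_ptr1..din_ptr5 to the shared zero row, e.g.
// k = 2 zeroes din_ptr4 and din_ptr5. Sketch of the equivalent loop form:
//   for (int r = 5; r > 5 - k; --r) { din_ptr<r> = zero_ptr; }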
        int cnt = tile_w;
        int cnt = cnt_col;
        act_switch_3x3s1p0(din_ptr0,
                           din_ptr1,
                           din_ptr2,
                           din_ptr3,
                           din_ptr4,
                           din_ptr5,
                           doutr0,
                           doutr1,
                           doutr2,
                           doutr3,
                           wr0,
                           wr1,
                           wr2,
                           vmask,
                           rmask,
                           vzero,
                           vbias,
                           cnt,
                           remain,
                           act_param);
        asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
                         MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
                         RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
                     : [cnt] "+r"(cnt),
                       [din_ptr0] "+r"(din_ptr0),
                       [din_ptr1] "+r"(din_ptr1),
                       [din_ptr2] "+r"(din_ptr2),
                       [din_ptr3] "+r"(din_ptr3),
                       [din_ptr4] "+r"(din_ptr4),
                       [din_ptr5] "+r"(din_ptr5),
                       [doutr0] "+r"(doutr0),
                       [doutr1] "+r"(doutr1),
                       [doutr2] "+r"(doutr2),
                       [doutr3] "+r"(doutr3)
                     : [w0] "w"(wr0),
                       [w1] "w"(wr1),
                       [w2] "w"(wr2),
                       [vscale] "w"(vscale),
                       [bias_val] "r"(vbias),
                       [vmask] "r"(vmask),
                       [rmask] "r"(rmask),
                       [vzero] "w"(vzero)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
        dout_ptr = dout_ptr + 4 * w_out;
      }
#else
      for (int i = 0; i < h_out; i += 2) {
        //! process top pad pad_h = 1
        din_ptr0 = dr0;
        din_ptr1 = dr1;
        din_ptr2 = dr2;
@@ -3443,13 +2692,24 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
        doutr0 = dout_ptr;
        doutr1 = dout_ptr + w_out;
        dr0 = dr2;
        dr1 = dr3;
        dr2 = dr1 + w_in;
        dr3 = dr2 + w_in;
        if (i == 0) {
          din_ptr0 = zero_ptr;
          din_ptr1 = dr0;
          din_ptr2 = dr1;
          din_ptr3 = dr2;
          dr0 = dr1;
          dr1 = dr2;
          dr2 = dr3;
          dr3 = dr2 + w_in;
        } else {
          dr0 = dr2;
          dr1 = dr3;
          dr2 = dr1 + w_in;
          dr3 = dr2 + w_in;
        }
        //! process bottom pad
        if (i + 4 > h_in) {
          switch (i + 4 - h_in) {
        if (i + 3 > h_in) {
          switch (i + 3 - h_in) {
            case 3:
              din_ptr1 = zero_ptr;
            case 2:
@@ -3464,292 +2724,1140 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
        if (i + 2 > h_out) {
          doutr1 = write_ptr;
        }
        int cnt = tile_w;
        int cnt = cnt_col;
        unsigned int *rmask_ptr = rmask;
        unsigned int *vmask_ptr = vmask;
        act_switch_3x3s1p0(din_ptr0,
                           din_ptr1,
                           din_ptr2,
                           din_ptr3,
                           doutr0,
                           doutr1,
                           wr0,
                           wr1,
                           wr2,
                           vmask_ptr,
                           rmask_ptr,
                           vzero,
                           bias_val,
                           cnt,
                           remain,
                           act_param);
        asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_LEAKY_RELU
                         MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
                         RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_LEAKY_RELU
                     : [dout_ptr1] "+r"(doutr0),
                       [dout_ptr2] "+r"(doutr1),
                       [din0_ptr] "+r"(din_ptr0),
                       [din1_ptr] "+r"(din_ptr1),
                       [din2_ptr] "+r"(din_ptr2),
                       [din3_ptr] "+r"(din_ptr3),
                       [cnt] "+r"(cnt),
                       [rmask] "+r"(rmask_ptr),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [bias_val] "r"(bias_val),
                       [scale_ptr] "r"(scale),
                       [vzero] "w"(vzero)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
        dout_ptr += 2 * w_out;
      }  //! end of processing mid rows
#endif
    }
  }
}
void act_switch_3x3s1p0_s(const float *din_ptr0,
                          const float *din_ptr1,
                          const float *din_ptr2,
                          const float *din_ptr3,
                          float *doutr0,
                          float *doutr1,
                          float32x4_t wr0,
                          float32x4_t wr1,
                          float32x4_t wr2,
                          uint32x4_t vmask_rp1,
                          uint32x4_t vmask_rp2,
                          float32x4_t vzero,
                          float32x4_t wbias,
                          unsigned int *vmask_ptr,
                          float bias_val,
                          const operators::ActivationParam act_param) {
#ifdef __aarch64__
float32x4_t vsix = vdupq_n_f32(act_param.Relu_clipped_coef);
float32x4_t vscale = vdupq_n_f32(act_param.Leaky_relu_alpha);
#else
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
#endif
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[six_ptr] "r"(vsix),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
break;
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[scale_ptr] "r"(vscale),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
break;
#endif
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p0_bias_s(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
const operators::ActivationParam act_param,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
void conv_depthwise_3x3s1p1_bias_s_relu6(float *dout,
                                         const float *din,
                                         const float *weights,
                                         const float *bias,
                                         const float *six,
                                         bool flag_bias,
                                         const int num,
                                         const int ch_in,
                                         const int h_in,
                                         const int w_in,
                                         const int h_out,
                                         const int w_out,
                                         ARMContext *ctx) {
  const int right_pad_idx[4] = {3, 2, 1, 0};
  const float zero[4] = {0.f, 0.f, 0.f, 0.f};
  float32x4_t vzero = vdupq_n_f32(0.f);
  uint32x4_t vmask_rp1 =
  uint32x4_t vmask_rp =
      vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
for (int j = 0; j < h_out; j += 2) {
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
if (j == 0) {
din_ptr0 = zero;
din_ptr1 = dr0;
din_ptr2 = dr1;
din_ptr3 = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
din_ptr1 = zero;
case 2:
din_ptr2 = zero;
case 1:
din_ptr3 = zero;
default:
break;
}
}
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU6
: [din0] "+r"(din_ptr0),
[din1] "+r"(din_ptr1),
[din2] "+r"(din_ptr2),
[din3] "+r"(din_ptr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[six_ptr] "r"(six),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
} // end of processing heights
} // end of processing channels
  }    // end of processing batches
}
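// What the RESULT_S_S1_RELU6 tail boils down to per output vector, written
// with intrinsics for readers who do not want to trace the asm (fuse_relu6
// is an illustrative helper, not part of this patch; `acc` stands for the
// accumulated convolution sum plus bias):
static inline float32x4_t fuse_relu6(float32x4_t acc,
                                     float32x4_t vsix,
                                     float32x4_t vzero) {
  acc = vmaxq_f32(acc, vzero);  // clamp below at 0 (relu)
  return vminq_f32(acc, vsix);  // clamp above at six (relu6)
}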
void conv_depthwise_3x3s1p1_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
for (int j = 0; j < h_out; j += 2) {
const float *dr0_ptr = dr0;
const float *dr1_ptr = dr1;
const float *dr2_ptr = dr2;
const float *dr3_ptr = dr3;
if (j == 0) {
dr0_ptr = zero;
dr1_ptr = dr0;
dr2_ptr = dr1;
dr3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
dr1_ptr = zero;
case 2:
dr2_ptr = zero;
case 1:
dr3_ptr = zero;
default:
break;
}
}
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(dr0_ptr),
[din1] "+r"(dr1_ptr),
[din2] "+r"(dr2_ptr),
[din3] "+r"(dr3_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[vscale] "w"(vscale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(dr0_ptr),
[din1] "+r"(dr1_ptr),
[din2] "+r"(dr2_ptr),
[din3] "+r"(dr3_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[scale_ptr] "r"(scale),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
} // end of processing heights
} // end of processing channels
  }    // end of processing batches
}
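// The RESULT_S_S1_LEAKY_RELU tail per output vector, as an intrinsics
// sketch (fuse_leaky_relu is an illustrative helper; the asm performs the
// same compare/multiply/bit-select sequence):
static inline float32x4_t fuse_leaky_relu(float32x4_t acc,
                                          float32x4_t vscale,
                                          float32x4_t vzero) {
  uint32x4_t pos = vcgeq_f32(acc, vzero);       // lane mask: acc >= 0
  float32x4_t scaled = vmulq_f32(acc, vscale);  // acc * alpha
  return vbslq_f32(pos, acc, scaled);           // pick acc or scaled per lane
}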
void conv_depthwise_3x3s1p0_bias_relu6(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
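// Workspace layout used above (a sketch of the assumption the code makes,
// namely that ctx->workspace_data<float>() provides at least 2 * w_in
// floats):
//   [0, w_in)       zero_ptr  - an all-zero row substituted for input rows
//                               that fall below the image (bottom padding)
//   [w_in, 2*w_in)  write_ptr - a trash row that surplus output-row
//                               pointers are parked on, so the multi-row
//                               asm body can always store its full set of
//                               rows safely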
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0
tile_w -= 1;
remain = 4;
size_pad_right = 2;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU6
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vsix] "w"(vsix),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU6
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU6 "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(six),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
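// In the p0 variants the right-edge tile is conditional: the asm above
// compares %[remain] with 1 and branches past RIGHT_COMPUTE_S1 /
// RIGHT_RESULT_S1_RELU6 (to the "0:" label) when remain is 0, i.e. when no
// masked tail block of output columns needs to be written.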
void conv_depthwise_3x3s1p0_bias_s_relu6(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
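  // Worked example of the masks above (assumed semantics): for w_in = 5 the
  // threshold is 6 - w_in = 1, so vmask_rp1 ({5,4,3,2} >= 1) keeps all four
  // lanes and vmask_rp2 ({1,0,0,0} >= 1) keeps only the first, masking the
  // 6-wide load window down to the 5 real columns plus implicit zeros.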
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
float bias_val = 0.f;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
bias_val = bias[i];
} else {
wbias = vdupq_n_f32(0.f);
}
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 4 > h_in) {
switch (j + 4 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
default:
break;
}
}
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
unsigned int *vmask_ptr = vmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU6
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[six_ptr] "r"(six),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
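// Hedged reference sketch (not part of this patch): a plain scalar version of
// the width<=4 s1p0 relu6 kernel above, handy for checking the NEON/asm path
// on small shapes. The helper name and its layout assumptions (NCHW, 9
// weights per channel, six[0] as the clamp) are illustrative only.
static void conv_dw_3x3s1p0_relu6_ref(float *dout,
                                      const float *din,
                                      const float *weights,
                                      const float *bias,
                                      const float *six,
                                      bool flag_bias,
                                      int num,
                                      int ch_in,
                                      int h_in,
                                      int w_in) {
  const int h_out = h_in - 2;  // pad 0, stride 1
  const int w_out = w_in - 2;
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < ch_in; ++c) {
      const float *src = din + (n * ch_in + c) * h_in * w_in;
      const float *wgt = weights + c * 9;
      float *dst = dout + (n * ch_in + c) * h_out * w_out;
      float b = flag_bias ? bias[c] : 0.f;
      for (int h = 0; h < h_out; ++h) {
        for (int w = 0; w < w_out; ++w) {
          float acc = b;
          for (int kh = 0; kh < 3; ++kh) {
            for (int kw = 0; kw < 3; ++kw) {
              acc += src[(h + kh) * w_in + (w + kw)] * wgt[kh * 3 + kw];
            }
          }
          acc = acc > 0.f ? acc : 0.f;                       // relu
          dst[h * w_out + w] = acc < six[0] ? acc : six[0];  // clamp at six
        }
      }
    }
  }
}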
void conv_depthwise_3x3s1p0_bias_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0
tile_w -= 1;
remain = 4;
size_pad_right = 2;
}
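  // An assumed trace of the adjustment above: with w_in == w_out == 8,
  // tile_w = 2 and size_pad_right = 6, so the second 4-wide tile would load
  // input columns 4..9, two past the end of the row. Folding that tile into
  // the masked remainder path (tile_w = 1, remain = 4, size_pad_right = 2)
  // keeps every load inside the row.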
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[vscale] "w"(vscale),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_LEAKY_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_LEAKY_RELU
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[scale_ptr] "r"(scale),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
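// For reference, an assumed reading of the LEAKY_RELU result macros above:
// out = x > 0 ? x : x * scale[0] after the 3x3 accumulation, with the slope
// presumably replicated across the four lanes loaded from `scale`.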
void conv_depthwise_3x3s1p0_bias_s_leakyRelu(float *dout,
const float *din,
const float *weights,
const float *bias,
const float *scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
      vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
  uint32x4_t vmask_rp2 =
      vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
#endif
  unsigned int vmask[8];
  vst1q_u32(vmask, vmask_rp1);
  vst1q_u32(vmask + 4, vmask_rp2);
@@ -3808,22 +3916,70 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
          doutr1 = trash_buf;
        }
        unsigned int *vmask_ptr = vmask;
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
                     : [din0] "+r"(dr0),
                       [din1] "+r"(dr1),
                       [din2] "+r"(dr2),
                       [din3] "+r"(dr3)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [vbias] "w"(wbias),
                       [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2),
                       [vzero] "w"(vzero),
                       [vscale] "w"(vscale),
                       [out1] "r"(doutr0),
                       [out2] "r"(doutr1)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_LEAKY_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[scale_ptr] "r"(scale),
[bias_val] "r"(bias_val),
[out1] "r"(doutr0),
[out2] "r"(doutr1)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
        for (int w = 0; w < w_out; ++w) {
          *doutr0++ = out_buf1[w];
          *doutr1++ = out_buf2[w];
...
@@ -1202,19 +1202,19 @@ namespace math {
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
 * width > 4
 */
void conv_depthwise_3x3s1p1_bias_no_relu(float *dout,
                                         const float *din,
                                         const float *weights,
                                         const float *bias,
                                         bool flag_bias,
                                         bool flag_relu,
                                         const int num,
                                         const int ch_in,
                                         const int h_in,
                                         const int w_in,
                                         const int h_out,
                                         const int w_out,
                                         ARMContext *ctx) {
  //! pad is done implicit
  const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
  //! for 4x6 convolution window
@@ -1363,106 +1363,54 @@ void conv_depthwise_3x3s1p1_bias_relu(float *dout,
        }
        int cnt = cnt_col;
        asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
                         MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
                     : [cnt] "+r"(cnt),
                       [din_ptr0] "+r"(din_ptr0),
                       [din_ptr1] "+r"(din_ptr1),
                       [din_ptr2] "+r"(din_ptr2),
                       [din_ptr3] "+r"(din_ptr3),
                       [din_ptr4] "+r"(din_ptr4),
                       [din_ptr5] "+r"(din_ptr5),
                       [doutr0] "+r"(doutr0),
                       [doutr1] "+r"(doutr1),
                       [doutr2] "+r"(doutr2),
                       [doutr3] "+r"(doutr3)
                     : [w0] "w"(wr0),
                       [w1] "w"(wr1),
                       [w2] "w"(wr2),
                       [bias_val] "r"(vbias),
                       [vmask] "r"(vmask),
                       [rmask] "r"(rmask),
                       [vzero] "w"(vzero)
                     : "cc",
                       "memory",
                       "v0",
                       "v1",
                       "v2",
                       "v3",
                       "v4",
                       "v5",
                       "v6",
                       "v7",
                       "v8",
                       "v9",
                       "v10",
                       "v11",
                       "v12",
                       "v13",
                       "v14",
                       "v15",
                       "v16",
                       "v17",
                       "v18",
                       "v19",
                       "v20",
                       "v21",
                       "v22",
                       "v23",
                       "v24",
                       "v25");
        dout_ptr = dout_ptr + 4 * w_out;
      }
#else
@@ -1512,70 +1460,36 @@ void conv_depthwise_3x3s1p1_bias_relu(float *dout,
        int cnt = cnt_col;
        unsigned int *rmask_ptr = rmask;
        unsigned int *vmask_ptr = vmask;
        asm volatile(INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1 MID_COMPUTE_S1
                         MID_RESULT_S1 RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
                     : [dout_ptr1] "+r"(doutr0),
                       [dout_ptr2] "+r"(doutr1),
                       [din0_ptr] "+r"(din_ptr0),
                       [din1_ptr] "+r"(din_ptr1),
                       [din2_ptr] "+r"(din_ptr2),
                       [din3_ptr] "+r"(din_ptr3),
                       [cnt] "+r"(cnt),
                       [rmask] "+r"(rmask_ptr),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [bias_val] "r"(bias_val),
                       [vzero] "w"(vzero)
                     : "cc",
                       "memory",
                       "q4",
                       "q5",
                       "q6",
                       "q7",
                       "q8",
                       "q9",
                       "q10",
                       "q11",
                       "q12",
                       "q13",
                       "q14",
                       "q15");
        dout_ptr += 2 * w_out;
      } //! end of processing mid rows
#endif
@@ -1583,221 +1497,7 @@ void conv_depthwise_3x3s1p1_bias_relu(float *dout,
  }
}
void conv_depthwise_3x3s1p1_bias_relu(float *dout,
                                      const float *din,
                                      const float *weights,
                                      const float *bias,
@@ -1825,16 +1525,27 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
  int tile_w = w_out >> 2;
  int remain = w_out % 4;
  int cnt_col = tile_w - 1;
  unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
  const unsigned int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 5) {
size_pad_right = 1;
cnt_col -= 1;
remain = 4;
} else if (remain == 0 && size_pad_right == 6) {
size_pad_right = 2;
cnt_col -= 1;
remain = 4;
}
  uint32x4_t vmask_rp1 =
      vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_rp2 =
      vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
  uint32x4_t vmask_result =
      vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));
  unsigned int vmask[8];
  vst1q_u32(vmask, vmask_rp1);
@@ -1881,10 +1592,9 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
      const float *din_ptr3 = dr3;
      const float *din_ptr4 = dr4;
      const float *din_ptr5 = dr5;
      float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
      for (int i = 0; i < h_in; i += 4) {
        //! process top pad pad_h = 1
        din_ptr0 = dr0;
        din_ptr1 = dr1;
@@ -1897,26 +1607,37 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
        doutr1 = doutr0 + w_out;
        doutr2 = doutr1 + w_out;
        doutr3 = doutr2 + w_out;
if (i == 0) {
          din_ptr0 = zero_ptr;
          din_ptr1 = dr0;
          din_ptr2 = dr1;
din_ptr3 = dr2;
din_ptr4 = dr3;
din_ptr5 = dr4;
dr0 = dr3;
dr1 = dr4;
dr2 = dr5;
} else {
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
}
        dr3 = dr2 + w_in;
        dr4 = dr3 + w_in;
        dr5 = dr4 + w_in;
        //! process bottom pad
        if (i + 5 > h_in) {
          switch (i + 5 - h_in) {
            case 5:
              din_ptr1 = zero_ptr;
            case 4:
              din_ptr2 = zero_ptr;
            case 3:
              din_ptr3 = zero_ptr;
            case 2:
              din_ptr4 = zero_ptr;
            case 1:
              din_ptr5 = zero_ptr;
            default:
              break;
@@ -1936,132 +1657,61 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
          }
        }
        int cnt = cnt_col;
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
            : [cnt] "+r"(cnt),
              [din_ptr0] "+r"(din_ptr0),
              [din_ptr1] "+r"(din_ptr1),
              [din_ptr2] "+r"(din_ptr2),
              [din_ptr3] "+r"(din_ptr3),
              [din_ptr4] "+r"(din_ptr4),
              [din_ptr5] "+r"(din_ptr5),
              [doutr0] "+r"(doutr0),
              [doutr1] "+r"(doutr1),
              [doutr2] "+r"(doutr2),
              [doutr3] "+r"(doutr3)
            : [w0] "w"(wr0),
              [w1] "w"(wr1),
              [w2] "w"(wr2),
              [bias_val] "r"(vbias),
              [vmask] "r"(vmask),
              [rmask] "r"(rmask),
              [vzero] "w"(vzero)
            : "cc",
              "memory",
              "v0",
              "v1",
              "v2",
              "v3",
              "v4",
              "v5",
              "v6",
              "v7",
              "v8",
              "v9",
              "v10",
              "v11",
              "v12",
              "v13",
              "v14",
              "v15",
              "v16",
              "v17",
              "v18",
              "v19",
              "v20",
              "v21",
              "v22",
              "v23",
              "v24",
              "v25");
        dout_ptr = dout_ptr + 4 * w_out;
      }
#else
      for (int i = 0; i < h_in; i += 2) {
//! process top pad pad_h = 1
        din_ptr0 = dr0;
        din_ptr1 = dr1;
        din_ptr2 = dr2;
@@ -2069,13 +1719,25 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
        doutr0 = dout_ptr;
        doutr1 = dout_ptr + w_out;
// unsigned int* rst_mask = rmask;
        if (i == 0) {
          din_ptr0 = zero_ptr;
          din_ptr1 = dr0;
          din_ptr2 = dr1;
din_ptr3 = dr2;
dr0 = dr1;
dr1 = dr2;
dr2 = dr3;
dr3 = dr2 + w_in;
} else {
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
}
        //! process bottom pad
        if (i + 3 > h_in) {
          switch (i + 3 - h_in) {
            case 3:
              din_ptr1 = zero_ptr;
@@ -2083,8 +1745,6 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
              din_ptr2 = zero_ptr;
            case 1:
              din_ptr3 = zero_ptr;
            default:
              break;
          }
@@ -2093,131 +1753,73 @@ void conv_depthwise_3x3s1p0_bias_relu(float *dout,
        if (i + 2 > h_out) {
          doutr1 = write_ptr;
        }
        int cnt = cnt_col;
        unsigned int *rmask_ptr = rmask;
        unsigned int *vmask_ptr = vmask;
        asm volatile(
            INIT_S1 LEFT_COMPUTE_S1 LEFT_RESULT_S1_RELU MID_COMPUTE_S1
                MID_RESULT_S1_RELU RIGHT_COMPUTE_S1 RIGHT_RESULT_S1_RELU
            : [dout_ptr1] "+r"(doutr0),
              [dout_ptr2] "+r"(doutr1),
              [din0_ptr] "+r"(din_ptr0),
              [din1_ptr] "+r"(din_ptr1),
              [din2_ptr] "+r"(din_ptr2),
              [din3_ptr] "+r"(din_ptr3),
              [cnt] "+r"(cnt),
              [rmask] "+r"(rmask_ptr),
              [vmask] "+r"(vmask_ptr)
            : [wr0] "w"(wr0),
              [wr1] "w"(wr1),
              [wr2] "w"(wr2),
              [bias_val] "r"(bias_val),
              [vzero] "w"(vzero)
            : "cc",
              "memory",
              "q4",
              "q5",
              "q6",
              "q7",
              "q8",
              "q9",
              "q10",
              "q11",
              "q12",
              "q13",
              "q14",
              "q15");
        dout_ptr += 2 * w_out;
      } //! end of processing mid rows
#endif
    }
  }
}
/**
 * \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
 * width <= 4
 */
void conv_depthwise_3x3s1p1_bias_s_no_relu(float *dout,
                                           const float *din,
                                           const float *weights,
                                           const float *bias,
                                           bool flag_bias,
                                           bool flag_relu,
                                           const int num,
                                           const int ch_in,
                                           const int h_in,
                                           const int w_in,
                                           const int h_out,
                                           const int w_out,
                                           ARMContext *ctx) {
  //! 3x3s1 convolution, implemented by direct algorithm
  //! pad is done implicit
  //! for 4x6 convolution window
  const int right_pad_idx[4] = {3, 2, 1, 0};
  const float zero[4] = {0.f, 0.f, 0.f, 0.f};
  float32x4_t vzero = vdupq_n_f32(0.f);
  uint32x4_t vmask_rp =
      vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
  int size_in_channel = w_in * h_in;
  int size_out_channel = w_out * h_out;
  for (int n = 0; n < num; ++n) {
@@ -2231,38 +1833,907 @@ void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
      float32x4_t wr0 = vld1q_f32(weight_ptr);
      float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
      float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
      float32x4_t wbias;
      if (flag_bias) {
        wbias = vdupq_n_f32(bias[i]);
      } else {
        wbias = vdupq_n_f32(0.f);
      }
      int hs = -1;
      int he = 3;
      float out_buf1[4];
      float out_buf2[4];
      float trash_buf[4];
      int h_cnt = (h_out + 1) >> 1;
      float *doutr0 = dout_channel;
      float *doutr1 = dout_channel + w_out;
      for (int j = 0; j < h_cnt; ++j) {
        const float *dr0 = din_channel + hs * w_in;
        const float *dr1 = dr0 + w_in;
        const float *dr2 = dr1 + w_in;
        const float *dr3 = dr2 + w_in;
        if (hs == -1) {
          dr0 = zero;
        }
        switch (he - h_in) {
          case 2:
            dr2 = zero;
            doutr1 = trash_buf;
          case 1:
            dr3 = zero;
default:
break;
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
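// The _s_relu variant below mirrors _s_no_relu above; the only difference is
// the result macro (RESULT_S_S1_RELU instead of RESULT_S_S1), while the
// window bookkeeping is identical.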
void conv_depthwise_3x3s1p1_bias_s_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[4] = {3, 2, 1, 0};
const float zero[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(4 - w_in));
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float *dr0 = din_channel + hs * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
}
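        // hs/he delimit the four input rows [hs, hs + 4) feeding this pair of
        // output rows: hs starts at -1 because pad = 1 supplies a zero top
        // row, and the switch above swaps in the zero row at the bottom edge.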
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[zero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17");
#else
asm volatile(COMPUTE_S_S1 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[mask] "w"(vmask_rp),
[bias] "w"(wbias),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
*doutr1++ = out_buf2[w];
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batches
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width > 4
*/
void conv_depthwise_3x3s1p0_bias_no_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
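  // rmask marks lane k valid iff k < remain, so the right-edge store of a
  // partial 4-wide tile writes only the columns that actually exist.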
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
case 0:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(
INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1 MID_RESULT_S1
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1 RIGHT_RESULT_S1
"0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
void conv_depthwise_3x3s1p0_bias_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! pad is done implicit
const float zero[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
//! for 4x6 convolution window
const unsigned int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
float *zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float *write_ptr = zero_ptr + w_in;
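  // write_ptr is a one-row scratch area placed right after the zero row in
  // the workspace; output-row pointers that fall past h_out are redirected
  // there so the asm can always store a full block of rows.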
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = w_out >> 2;
int remain = w_out % 4;
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_s32(vdupq_n_s32(remain), vld1q_s32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
unsigned int rmask[4];
vst1q_u32(rmask, vmask_result);
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int c = 0; c < ch_in; c++) {
float *dout_ptr = dout_batch + c * size_out_channel;
const float *din_ch_ptr = din_batch + c * size_in_channel;
float bias_val = flag_bias ? bias[c] : 0.f;
float vbias[4] = {bias_val, bias_val, bias_val, bias_val};
const float *wei_ptr = weights + c * w_stride;
float32x4_t wr0 = vld1q_f32(wei_ptr);
float32x4_t wr1 = vld1q_f32(wei_ptr + 3);
float32x4_t wr2 = vld1q_f32(wei_ptr + 6);
float *doutr0 = dout_ptr;
float *doutr1 = doutr0 + w_out;
float *doutr2 = doutr1 + w_out;
float *doutr3 = doutr2 + w_out;
const float *dr0 = din_ch_ptr;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr4 = dr3 + w_in;
const float *dr5 = dr4 + w_in;
const float *din_ptr0 = dr0;
const float *din_ptr1 = dr1;
const float *din_ptr2 = dr2;
const float *din_ptr3 = dr3;
const float *din_ptr4 = dr4;
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
din_ptr4 = dr4;
din_ptr5 = dr5;
doutr0 = dout_ptr;
doutr1 = doutr0 + w_out;
doutr2 = doutr1 + w_out;
doutr3 = doutr2 + w_out;
dr0 = dr4;
dr1 = dr5;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
dr5 = dr4 + w_in;
//! process bottom pad
if (i + 5 >= h_in) {
switch (i + 5 - h_in) {
case 4:
din_ptr1 = zero_ptr;
case 3:
din_ptr2 = zero_ptr;
case 2:
din_ptr3 = zero_ptr;
case 1:
din_ptr4 = zero_ptr;
case 0:
din_ptr5 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 4 > h_out) {
switch (i + 4 - h_out) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
int cnt = tile_w;
asm volatile(
INIT_S1
"ld1 {v8.4s}, [%[din_ptr4]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v10.4s}, [%[din_ptr5]], #16 \n" /*vld1q_f32(din_ptr0)*/
"ext v16.16b, v0.16b, v1.16b, #4 \n" /* v16 = 1234 */
"ext v17.16b, v0.16b, v1.16b, #8 \n" /* v17 = 2345 */
"ld1 {v9.4s}, [%[din_ptr4]] \n" /*vld1q_f32(din_ptr0)*/
"ld1 {v11.4s}, [%[din_ptr5]] \n" /*vld1q_f32(din_ptr0)*/
MID_COMPUTE_S1 MID_RESULT_S1_RELU
"cmp %w[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [cnt] "+r"(cnt),
[din_ptr0] "+r"(din_ptr0),
[din_ptr1] "+r"(din_ptr1),
[din_ptr2] "+r"(din_ptr2),
[din_ptr3] "+r"(din_ptr3),
[din_ptr4] "+r"(din_ptr4),
[din_ptr5] "+r"(din_ptr5),
[doutr0] "+r"(doutr0),
[doutr1] "+r"(doutr1),
[doutr2] "+r"(doutr2),
[doutr3] "+r"(doutr3)
: [w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[bias_val] "r"(vbias),
[vmask] "r"(vmask),
[rmask] "r"(rmask),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22",
"v23",
"v24",
"v25");
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_out; i += 2) {
din_ptr0 = dr0;
din_ptr1 = dr1;
din_ptr2 = dr2;
din_ptr3 = dr3;
doutr0 = dout_ptr;
doutr1 = dout_ptr + w_out;
dr0 = dr2;
dr1 = dr3;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (i + 3 >= h_in) {
switch (i + 3 - h_in) {
case 3:
din_ptr1 = zero_ptr;
case 2:
din_ptr2 = zero_ptr;
case 1:
din_ptr3 = zero_ptr;
case 0:
din_ptr3 = zero_ptr;
default:
break;
}
}
//! process bottom remain
if (i + 2 > h_out) {
doutr1 = write_ptr;
}
int cnt = tile_w;
unsigned int *rmask_ptr = rmask;
unsigned int *vmask_ptr = vmask;
asm volatile(INIT_S1
"sub %[din0_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din1_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din2_ptr], #8 @ 0pad + 2 float data overlap\n"
"sub %[din3_ptr], #8 @ 0pad + 2 float data overlap\n"
"vext.32 q6, q8, q9, #1 @ 0012\n"
"vext.32 q7, q8, q9, #2 @ 1234\n" MID_COMPUTE_S1
MID_RESULT_S1_RELU
"cmp %[remain], #1 \n"
"blt 0f \n" RIGHT_COMPUTE_S1
RIGHT_RESULT_S1_RELU "0: \n"
: [dout_ptr1] "+r"(doutr0),
[dout_ptr2] "+r"(doutr1),
[din0_ptr] "+r"(din_ptr0),
[din1_ptr] "+r"(din_ptr1),
[din2_ptr] "+r"(din_ptr2),
[din3_ptr] "+r"(din_ptr3),
[cnt] "+r"(cnt),
[rmask] "+r"(rmask_ptr),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias_val] "r"(bias_val),
[vzero] "w"(vzero),
[remain] "r"(remain)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
dout_ptr += 2 * w_out;
} //! end of processing mid rows
#endif
}
}
}
/**
* \brief depthwise convolution, kernel size 3x3, stride 1, pad 1, with bias,
* width <= 4
*/
void conv_depthwise_3x3s1p0_bias_s_no_relu(float *dout,
const float *din,
const float *weights,
const float *bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext *ctx) {
//! 3x3s1 convolution, implemented by direct algorithm
//! pad is done implicit
//! for 4x6 convolution window
const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
            case 1:
              dr3 = zero_ptr;
              doutr1 = trash_buf;
@@ -2276,133 +2747,227 @@ void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
          }
        }
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
                     : [din0] "+r"(dr0),
                       [din1] "+r"(dr1),
                       [din2] "+r"(dr2),
                       [din3] "+r"(dr3)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [vbias] "w"(wbias),
                       [mask1] "w"(vmask_rp1),
                       [mask2] "w"(vmask_rp2),
                       [zero] "w"(vzero),
                       [out1] "r"(out_buf1),
                       [out2] "r"(out_buf2)
                     : "cc",
                       "memory",
                       "v0",
                       "v1",
                       "v2",
                       "v3",
                       "v4",
                       "v5",
                       "v6",
                       "v7",
                       "v8",
                       "v9",
                       "v10",
                       "v11",
                       "v12",
                       "v13",
                       "v14",
                       "v15");
#else
        unsigned int *vmask_ptr = vmask;
        float bias_val = flag_bias ? bias[i] : 0.f;
        asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1
                     : [din0] "+r"(dr0),
                       [din1] "+r"(dr1),
                       [din2] "+r"(dr2),
                       [din3] "+r"(dr3),
                       [vmask] "+r"(vmask_ptr)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [vzero] "w"(vzero),
                       [bias_val] "r"(bias_val),
                       [out1] "r"(out_buf1),
                       [out2] "r"(out_buf2)
                     : "cc",
                       "memory",
                       "q4",
                       "q5",
                       "q6",
                       "q7",
                       "q8",
                       "q9",
                       "q10",
                       "q11",
                       "q12",
                       "q13",
                       "q14",
                       "q15");
"q15"); #endif
} else { for (int w = 0; w < w_out; ++w) {
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1 *doutr0++ = out_buf1[w];
: [din0] "+r"(dr0), *doutr1++ = out_buf2[w];
[din1] "+r"(dr1), }
[din2] "+r"(dr2), } // end of processing heights
[din3] "+r"(dr3), } // end of processing channels
[vmask] "+r"(vmask_ptr) } // end of processing batchs
: [wr0] "w"(wr0), }
[wr1] "w"(wr1),
[wr2] "w"(wr2), void conv_depthwise_3x3s1p0_bias_s_relu(float *dout,
[vzero] "w"(vzero), const float *din,
[bias_val] "r"(bias_val), const float *weights,
[out1] "r"(out_buf1), const float *bias,
[out2] "r"(out_buf2) bool flag_bias,
: "cc", bool flag_relu,
"memory", const int num,
"q4", const int ch_in,
"q5", const int h_in,
"q6", const int w_in,
"q7", const int h_out,
"q8", const int w_out,
"q9", ARMContext *ctx) {
"q10", //! 3x3s1 convolution, implemented by direct algorithm
"q11", //! pad is done implicit
"q12", //! for 4x6 convolution window
"q13", const int right_pad_idx[8] = {5, 4, 3, 2, 1, 0, 0, 0};
"q14", const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
"q15");
float32x4_t vzero = vdupq_n_f32(0.f);
uint32x4_t vmask_rp1 =
vcgeq_s32(vld1q_s32(right_pad_idx), vdupq_n_s32(6 - w_in));
uint32x4_t vmask_rp2 =
vcgeq_s32(vld1q_s32(right_pad_idx + 4), vdupq_n_s32(6 - w_in));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
vst1q_u32(vmask + 4, vmask_rp2);
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
for (int n = 0; n < num; ++n) {
const float *din_batch = din + n * ch_in * size_in_channel;
float *dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
float *dout_channel = dout_batch + i * size_out_channel;
const float *din_channel = din_batch + i * size_in_channel;
const float *weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#endif // __aarch64__
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_out; j += 2) {
const float *dr0 = din_channel + j * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
default:
break;
}
}
#ifdef __aarch64__
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vbias] "w"(wbias),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[zero] "w"(vzero),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
unsigned int *vmask_ptr = vmask;
float bias_val = flag_bias ? bias[i] : 0.f;
asm volatile(COMPUTE_S_S1_P0 RESULT_S_S1_RELU
: [din0] "+r"(dr0),
[din1] "+r"(dr1),
[din2] "+r"(dr2),
[din3] "+r"(dr3),
[vmask] "+r"(vmask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[vzero] "w"(vzero),
[bias_val] "r"(bias_val),
[out1] "r"(out_buf1),
[out2] "r"(out_buf2)
: "cc",
"memory",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*doutr0++ = out_buf1[w];
...
@@ -20,61 +20,117 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void conv_depthwise_3x3s2p0_bias_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2_fp32(const float* din,
float* dout,
...
@@ -92,142 +148,275 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const operators::ActivationParam act_param,
ARMContext* ctx) {
bool has_active = act_param.has_active;
auto act_type = act_param.active_type;
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
if (has_active) {
switch (act_type) {
case lite_api::ActivationType::kRelu:
if (pad == 0) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu(dout,
din,
weights,
bias,
flag_bias,
true,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kRelu6:
if (pad == 0) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_relu6(dout,
din,
weights,
bias,
vsix,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
case lite_api::ActivationType::kLeakyRelu:
if (pad == 0) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_leakyRelu(dout,
din,
weights,
bias,
vscale,
flag_bias,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_type)
<< " fuse not support";
}
} else {
if (pad == 0) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias_no_relu(dout,
din,
weights,
bias,
flag_bias,
false,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p0_bias_s_no_relu(dout,
din,
weights,
bias,
flag_bias,
false,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
if (pad == 1) {
if (w_in > 7) {
conv_depthwise_3x3s2p1_bias_no_relu(dout,
din,
weights,
bias,
flag_bias,
false,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
} else {
conv_depthwise_3x3s2p1_bias_s_no_relu(dout,
din,
weights,
bias,
flag_bias,
false,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
ctx);
}
}
}
}
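// For orientation, a plain-C++ reference of the per-channel computation that
// every specialization dispatched above implements. This is a sketch only:
// conv_dw_3x3s2_ref and its parameters are illustrative, not part of the
// PaddleLite API; the activation codes mirror the kRelu/kRelu6/kLeakyRelu
// cases handled by the dispatcher.
static void conv_dw_3x3s2_ref(const float* din,
                              float* dout,
                              const float* w,
                              float bias,
                              int pad,
                              int h_in,
                              int w_in,
                              int h_out,
                              int w_out,
                              float six,
                              float alpha,
                              int act) {
  for (int oh = 0; oh < h_out; ++oh) {
    for (int ow = 0; ow < w_out; ++ow) {
      float sum = bias;  // bias is folded in before the activation
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          int ih = oh * 2 - pad + kh;  // stride 2 plus implicit zero padding
          int iw = ow * 2 - pad + kw;
          if (ih >= 0 && ih < h_in && iw >= 0 && iw < w_in) {
            sum += din[ih * w_in + iw] * w[kh * 3 + kw];
          }
        }
      }
      if (act == 1) sum = sum > 0.f ? sum : 0.f;                      // relu
      if (act == 2) sum = sum < 0.f ? 0.f : (sum > six ? six : sum);  // relu6
      if (act == 3) sum = sum >= 0.f ? sum : sum * alpha;       // leaky relu
      dout[oh * w_out + ow] = sum;
    }
  }
}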
// clang-format off
#ifdef __aarch64__
#define INIT_S2 \
"prfm pldl1keep, [%[inptr0]] \n" \
...
@@ -746,6 +935,18 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"fmax v4.4s, v4.4s, v9.4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_RELU6 \
"fadd v4.4s, v4.4s, %[bias].4s \n" \
"fmax v4.4s, v4.4s, v9.4s \n" \
"fmin v4.4s, v4.4s, %[vsix].4s \n" \
\
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_LEAKY_RELU \
"fadd v4.4s, v4.4s, %[bias].4s \n" \
"fcmge v11.4s, v4.4s, %[vzero].4s \n"/* vcgeq_u32 */\
"fmul v12.4s, v4.4s, %[vscale].4s \n"\
"bif v4.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v4.4s}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"movi v9.4s, #0 \n" \
"ld1 {v6.4s, v7.4s}, [%[mask_ptr]], #32 \n" \
...
@@ -785,6 +986,15 @@ void conv_depthwise_3x3s2_fp32(const float* din,
#define RESULT_S_S2_P0_RELU \
"fmax v4.4s, v4.4s, v9.4s \n" \
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU6 \
"fmax v4.4s, v4.4s, v9.4s \n" \
"fmin v4.4s, v4.4s, %[vsix].4s \n" \
"st1 {v4.4s}, [%[out]] \n"
#define RESULT_S_S2_P0_LEAKY_RELU \
"fcmge v11.4s, v4.4s, %[vzero].4s \n"/* vcgeq_u32 */\
"fmul v12.4s, v4.4s, %[vscale].4s \n"\
"bif v4.16b, v12.16b, v11.16b \n" /* choose*/ \
"st1 {v4.4s}, [%[out]] \n"
#else
#define INIT_S2 \
...
@@ -822,14 +1032,13 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 1, out1\n" \
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 1, out1\n" \
\
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r1\n"
#define LEFT_RESULT_S2 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"cmp %[cnt], #1 \n" \
"blt 1f \n"
#define MID_COMPUTE_S2 \
...
@@ -860,12 +1069,11 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n" \
\
"vld2.32 {d28-d31}, [%[din2_ptr]]! @ load din r2\n"
#define MID_RESULT_S2 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"subs %[cnt], #1 \n" \
\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
...
@@ -910,36 +1118,104 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q6, %f[wr2][0] @ mul weight 2, out0\n"
#define RIGHT_RESULT_S2 \
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vbif.f32 q3, q10, q11 @ write mask\n" \
\
"vst1.32 {d6-d7}, [%[outptr]]! \n" \
"3: \n"
#define LEFT_RESULT_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 \n"\
"cmp %[cnt], #1 \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"blt 1f \n"
#define LEFT_RESULT_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu \n"\
"cmp %[cnt], #1 \n"\
"vmin.f32 q3, q3, q6 @ relu \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"blt 1f \n"
#define LEFT_RESULT_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"cmp %[cnt], #1 \n"\
"vbif q3, q8, q7 @ choose \n"\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"blt 1f \n"
#define MID_RESULT_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"subs %[cnt], #1 \n"\
"vmax.f32 q3, q3, q9 @ relu \n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"bne 2b \n"
#define MID_RESULT_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu \n"\
"subs %[cnt], #1 \n"\
"vmin.f32 q3, q3, q6 @ relu \n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"bne 2b \n"
#define MID_RESULT_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"subs %[cnt], #1 \n"\
"vbif q3, q8, q7 @ choose \n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
"vbif.f32 q3, q10, q11 @ write mask\n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"3: \n"
#define RIGHT_RESULT_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
"vmin.f32 q3, q3, q6 @ relu \n"\
\
"vbif.f32 q3, q10, q11 @ write mask\n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"3: \n"
#define RIGHT_RESULT_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"vbif q3, q8, q7 @ choose \n"\
"vbif.f32 q3, q10, q11 @ write mask\n"\
\
"vst1.32 {d6-d7}, [%[outptr]]! \n"\
"3: \n"
#define COMPUTE_S_S2 \
"vmov.u32 q9, #0 \n" \
"vld1.f32 {d12-d15}, [%[mask_ptr]]! @ load mask\n" \
...
@@ -976,17 +1252,36 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"vmla.f32 q4, q14, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %f[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q3, q8, %e[wr2][0] @ mul weight 2, out0\n"
#define RESULT_S_S2 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
\
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vmax.f32 q3, q3, q9 @ relu\n"\
"vmin.f32 q3, q3, q6 @ relu\n"\
\
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n"\
"vld1.f32 {d12-d13}, [%[scale_ptr]] \n"\
"vadd.f32 q3, q3, q5 @ add \n"\
"vcge.f32 q7, q3, q9 \n"\
"vmul.f32 q8, q3, q6 \n"\
"vbif q3, q8, q7 @ choose \n"\
\
"vst1.32 {d6-d7}, [%[out]] \n"
#define COMPUTE_S_S2_P0 \
"vmov.u32 q9, #0 \n" \
"vld1.f32 {d12-d15}, [%[mask_ptr]] @ load mask\n" \
...
@@ -1023,207 +1318,309 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"vmla.f32 q4, q14, %e[wr2][0] @ mul weight 2, out0\n" \
"vmla.f32 q5, q15, %e[wr2][1] @ mul weight 2, out0\n" \
"vmla.f32 q3, q8, %f[wr2][0] @ mul weight 2, out0\n"
#define RESULT_S_S2_P0 \
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU \
"vadd.f32 q3, q3, q4 @ add \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vmax.f32 q3, q3, q9 @ relu \n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_RELU6 \
"vadd.f32 q3, q3, q4 @ add \n" \
"vld1.f32 {d12-d13}, [%[six_ptr]] @ load six \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vmax.f32 q3, q3, q9 @ relu\n" \
"vmin.f32 q3, q3, q6 @ relu\n" \
"vst1.32 {d6-d7}, [%[out]] \n"
#define RESULT_S_S2_P0_LEAKY_RELU \
"vadd.f32 q3, q3, q4 @ add \n" \
"vld1.f32 {d12-d13}, [%[scale_ptr]] @ load six \n" \
"vadd.f32 q3, q3, q5 @ add \n" \
"vcge.f32 q7, q3, q9 \n" \
"vmul.f32 q8, q3, q6 \n" \
"vbif q3, q8, q7 @ choose \n" \
"vst1.32 {d6-d7}, [%[out]] \n" "vst1.32 {d6-d7}, [%[out]] \n"
#endif
// clang-format on
#ifdef __aarch64__
void act_switch_3x3s2p1(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
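// On both ISAs the leaky-relu tails in the result macros above are the same
// three-instruction, branch-free select. A sketch in NEON intrinsics (the
// helper name is illustrative, not PaddleLite code): vcgeq_f32 is the
// fcmge/vcge compare, vmulq_f32 the fmul/vmul scaling, and vbslq_f32 the
// bif "choose" step.
static inline float32x4_t leaky_relu_f32x4_sketch(float32x4_t v,
                                                  float32x4_t vscale) {
  uint32x4_t ge_zero = vcgeq_f32(v, vdupq_n_f32(0.f));  // lanes where v >= 0
  float32x4_t scaled = vmulq_f32(v, vscale);            // alpha * v branch
  return vbslq_f32(ge_zero, v, scaled);                 // keep v, else scaled
}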
/**
* \brief depthwise convolution kernel 3x3, stride 2
* w_in > 7
*/
void conv_depthwise_3x3s2p1_bias_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
int cnt_col = tile_w - 1;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
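// Worked instance of the tiling above (illustrative numbers): for w_in = 17,
// pad = 1, stride 2, w_out = (17 + 2 - 3) / 2 + 1 = 9, so tile_w = 2 and
// cnt_remain = 1, and size_right_remain = 8 - (7 + 16 - 17) = 2. The final
// vld2 then keeps 2 of its 8 lanes via vmask_rp1/rp2, wmask keeps 1 of the
// 4 outputs, and cnt_col = tile_w - 1 = 1 middle iteration runs unmasked.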
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
} else {
dr0 = dr4;
dr1 = dr0 + w_in;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i * 2 + 4 > h_in) {
switch (i * 2 + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(six),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (i * 2 + 2 > h_in) {
switch (i * 2 + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(
INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU6 MID_COMPUTE_S2
MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU6
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[six_ptr] "r"(six),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
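// Per lane, the RELU6 tails used above reduce to "add bias, clamp to
// [0, six]". A sketch in NEON intrinsics (helper name illustrative, not
// PaddleLite code): vmaxq_f32 against zero is the fmax/vmax step and
// vminq_f32 against the vector loaded from six/six_ptr is the fmin/vmin step.
static inline float32x4_t relu6_f32x4_sketch(float32x4_t acc,
                                             float32x4_t vbias,
                                             float32x4_t vsix) {
  float32x4_t v = vaddq_f32(acc, vbias);  // fadd/vadd: bias first
  v = vmaxq_f32(v, vdupq_n_f32(0.f));     // clamp below at zero
  return vminq_f32(v, vsix);              // clamp above at six
}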
void conv_depthwise_3x3s2p1_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
...
@@ -1350,24 +1747,52 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(scale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out;
}
#else
...
@@ -1404,8 +1829,9 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_LEAKY_RELU
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
...
@@ -1416,6 +1842,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[scale_ptr] "r"(scale),
[bias] "r"(bias_c)
: "cc",
"memory",
...
@@ -1432,10 +1859,6 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
...
@@ -1446,19 +1869,19 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/
void conv_depthwise_3x3s2p1_bias_s_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
...
@@ -1474,7 +1897,9 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
#endif
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
...
@@ -1513,7 +1938,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU6
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
...
@@ -1522,6 +1947,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[vsix] "w"(vsix),
[out] "r"(out_buf)
: "v4",
"v5",
...
@@ -1536,7 +1962,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU6
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
...
@@ -1545,6 +1971,7 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[six_ptr] "r"(six),
[out] "r"(out_buf)
: "cc",
"memory",
...
@@ -1562,10 +1989,6 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
...
@@ -1575,231 +1998,154 @@ void conv_depthwise_3x3s2p1_bias_s(float* dout,
}
}
}
void conv_depthwise_3x3s2p1_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2 RESULT_S_S2_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[vzero] "w"(vzero),
[vscale] "w"(vscale),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S2 RESULT_S_S2_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[scale_ptr] "r"(scale),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
hs += 2;
he += 2;
}
}
}
}

#ifdef __aarch64__
void act_switch_3x3s2p0(const float* din0_ptr,
const float* din1_ptr,
const float* din2_ptr,
const float* din3_ptr,
const float* din4_ptr,
float* doutr0_ptr,
float* doutr1_ptr,
float32x4_t wr0,
float32x4_t wr1,
float32x4_t wr2,
uint32x4_t vmask_rp1,
uint32x4_t vmask_rp2,
uint32x4_t wmask,
float32x4_t wbias,
float32x4_t vzero,
int cnt,
int cnt_remain,
const operators::ActivationParam act_param) {
float tmp = act_param.Relu_clipped_coef;
float ss = act_param.Leaky_relu_alpha;
float vsix[4] = {tmp, tmp, tmp, tmp};
float vscale[4] = {ss, ss, ss, ss};
switch (act_param.active_type) {
case lite_api::ActivationType::kRelu:
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
break;
case lite_api::ActivationType::kRelu6:
/* 0 <= din <= 6 */
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vsix),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
case lite_api::ActivationType::kLeakyRelu:
/*din = din >= 0 ? din : din * scale*/
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
break;
default:
LOG(FATAL) << "this act_type: " << static_cast<int>(act_param.active_type)
<< " fuse not support";
}
}
#endif
/**
* \brief depthwise convolution kernel 3x3, stride 2
*/
// w_in > 7
void conv_depthwise_3x3s2p0_bias_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
...
@@ -1918,24 +2264,63 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(six),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out;
}
#else
...
@@ -1963,8 +2348,8 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU6 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
...
@@ -1972,6 +2357,7 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[six_ptr] "r"(six),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
...
@@ -1991,9 +2377,257 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
void conv_depthwise_3x3s2p0_bias_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i * 2 + 5 > h_in) {
switch (i * 2 + 5 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
"ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
"4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[scale_ptr] "r"(scale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21",
"v22");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i * 2 + 3 > h_in) {
switch (i * 2 + 3 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[scale_ptr] "r"(scale),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
...
@@ -2004,19 +2638,19 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
/**
* \brief depthwise convolution kernel 3x3, stride 2, width <= 4
*/
void conv_depthwise_3x3s2p0_bias_s_relu6(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* six,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
...
@@ -2033,6 +2667,10 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vsix = vld1q_f32(six);
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
...
@@ -2077,7 +2715,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU6
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
...
@@ -2086,6 +2724,8 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[vzero] "w"(vzero),
[vsix] "w"(vsix),
[out] "r"(out_buf) [out] "r"(out_buf)
: "cc", : "cc",
"memory", "memory",
...@@ -2104,7 +2744,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2104,7 +2744,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
"v16"); "v16");
#else #else
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0 asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU6
: [din0_ptr] "+r"(din0_ptr), : [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr), [din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr) [din2_ptr] "+r"(din2_ptr)
...@@ -2113,6 +2753,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2113,6 +2753,7 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
[wr2] "w"(wr2), [wr2] "w"(wr2),
[bias] "r"(bias_c), [bias] "r"(bias_c),
[out] "r"(out_buf), [out] "r"(out_buf),
[six_ptr] "r"(six),
[mask_ptr] "r"(dmask) [mask_ptr] "r"(dmask)
: "cc", : "cc",
"memory", "memory",
...@@ -2130,9 +2771,145 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout, ...@@ -2130,9 +2771,145 @@ void conv_depthwise_3x3s2p0_bias_s(float* dout,
"q14", "q14",
"q15"); "q15");
#endif #endif
if (act_param.has_active) { for (int w = 0; w < w_out; ++w) {
act_switch_process(out_buf, out_buf, w_out, &act_param); *dout_channel++ = out_buf[w];
}
}
}
}
}
void conv_depthwise_3x3s2p0_bias_s_leakyRelu(float* dout,
const float* din,
const float* weights,
const float* bias,
const float* scale,
bool flag_bias,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
#ifdef __aarch64__
float32x4_t vscale = vld1q_f32(scale);
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
for (int j = 0; j < h_out; j++) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
if (j * 2 + 2 >= h_in) {
switch (j * 2 + 2 - h_in) {
case 1:
din1_ptr = zero_ptr;
case 0:
din2_ptr = zero_ptr;
default:
break;
}
}
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[vzero] "w"(vzero),
[vscale] "w"(vscale),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_LEAKY_RELU
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[scale_ptr] "r"(scale),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
for (int w = 0; w < w_out; ++w) {
*dout_channel++ = out_buf[w];
}
......
...@@ -20,6 +20,7 @@ namespace lite { ...@@ -20,6 +20,7 @@ namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// clang-format off
#ifdef __aarch64__ #ifdef __aarch64__
#define INIT_S2 \ #define INIT_S2 \
"prfm pldl1keep, [%[inptr0]] \n" \ "prfm pldl1keep, [%[inptr0]] \n" \
...@@ -683,6 +684,7 @@ namespace math { ...@@ -683,6 +684,7 @@ namespace math {
"vst1.32 {d6-d7}, [%[out]] \n" "vst1.32 {d6-d7}, [%[out]] \n"
#endif #endif
// clang-format on
/** /**
* \brief depthwise convolution kernel 3x3, stride 2 * \brief depthwise convolution kernel 3x3, stride 2
@@ -825,96 +827,50 @@ void conv_depthwise_3x3s2p1_bias_relu(float* dout,
          doutr1_ptr = write_ptr;
        }
        int cnt = cnt_col;
        asm volatile(
            INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
                MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
            : [inptr0] "+r"(din0_ptr),
              [inptr1] "+r"(din1_ptr),
              [inptr2] "+r"(din2_ptr),
              [inptr3] "+r"(din3_ptr),
              [inptr4] "+r"(din4_ptr),
              [outptr0] "+r"(doutr0_ptr),
              [outptr1] "+r"(doutr1_ptr),
              [cnt] "+r"(cnt)
            : [vzero] "w"(vzero),
              [w0] "w"(wr0),
              [w1] "w"(wr1),
              [w2] "w"(wr2),
              [remain] "r"(cnt_remain),
              [mask1] "w"(vmask_rp1),
              [mask2] "w"(vmask_rp2),
              [wmask] "w"(wmask),
              [vbias] "w"(wbias)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
              "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
              "v16", "v17", "v18", "v19", "v20", "v21");
        doutr0 = doutr0 + 2 * w_out;
      }
#else
@@ -951,66 +907,286 @@ void conv_depthwise_3x3s2p1_bias_relu(float* dout,
        }
        int cnt = cnt_col;
        unsigned int* mask_ptr = dmask;
        asm volatile(
            INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2_RELU MID_COMPUTE_S2
                MID_RESULT_S2_RELU RIGHT_COMPUTE_S2 RIGHT_RESULT_S2_RELU
            : [din0_ptr] "+r"(din0_ptr),
              [din1_ptr] "+r"(din1_ptr),
              [din2_ptr] "+r"(din2_ptr),
              [outptr] "+r"(doutr0_ptr),
              [cnt] "+r"(cnt),
              [mask_ptr] "+r"(mask_ptr)
            : [remain] "r"(cnt_remain),
              [wr0] "w"(wr0),
              [wr1] "w"(wr1),
              [wr2] "w"(wr2),
              [bias] "r"(bias_c)
            : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15");
        doutr0 = doutr0 + w_out;
}
#endif
}
}
}
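On aarch64 this pad-1 kernel emits two output rows per iteration from a five-row input window; the i == 0 branch substitutes the zero row for the missing top padding, and the fallthrough switch zeroes whichever lower rows run past h_in. The pointer selection reduces to this scalar rule (a sketch; pick_row is a hypothetical name):

// Row selection behind the five input pointers din0..din4 of the
// stride-2, pad-1 path: output rows (oh, oh+1) read input rows
// 2*oh - 1 ... 2*oh + 3; out-of-range rows read a zeroed scratch row.
static inline const float* pick_row(const float* din, int r, int h_in,
                                    int w_in, const float* zero_row) {
  return (r < 0 || r >= h_in) ? zero_row : din + r * w_in;
}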
void conv_depthwise_3x3s2p1_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
  int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
din3_ptr = dr2;
din4_ptr = dr3;
dr0 = dr3;
dr1 = dr4;
        } else {
          dr0 = dr4;
          dr1 = dr0 + w_in;
        }
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
if (i == 0) {
din0_ptr = zero_ptr;
din1_ptr = dr0;
din2_ptr = dr1;
dr0 = dr1;
dr1 = dr2;
dr2 = dr1 + w_in;
} else {
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
}
//! process bottom pad
if (i + 2 > h_in) {
switch (i + 2 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = cnt_col;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 LEFT_COMPUTE_S2 LEFT_RESULT_S2 MID_COMPUTE_S2
MID_RESULT_S2 RIGHT_COMPUTE_S2 RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
        doutr0 = doutr0 + w_out;
      }
#endif
    }
  }
}
@@ -1088,107 +1264,179 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
        unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
        asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
                     : [din0_ptr] "+r"(din0_ptr),
                       [din1_ptr] "+r"(din1_ptr),
                       [din2_ptr] "+r"(din2_ptr),
                       [mask_ptr] "+r"(mask_ptr)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [bias] "w"(vbias),
                       [out] "r"(out_buf)
                     : "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
                       "v12", "v13", "v14", "v15");
#else
        asm volatile(COMPUTE_S_S2 RESULT_S_S2_RELU
                     : [din0_ptr] "+r"(din0_ptr),
                       [din1_ptr] "+r"(din1_ptr),
                       [din2_ptr] "+r"(din2_ptr),
                       [mask_ptr] "+r"(mask_ptr)
                     : [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [bias] "r"(bias_c),
                       [out] "r"(out_buf)
                     : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8",
                       "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
hs += 2;
he += 2;
}
}
}
}
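The small-width variants above keep an explicit [hs, he) input-row window that starts at -1 (the single row of top padding) and slides by the stride; dr0 and dr2 fall back to the zeros buffer when the window leaves the input. Stripped of the arithmetic, the loop's bookkeeping looks like this (a sketch under the same conventions):

// Sliding row window of the w_out <= 4, stride-2, pad-1 kernels:
// hs = -1 models the top padding; both bounds step by the stride.
static void visit_rows_s2p1(const float* din_channel, int w_in, int h_in,
                            int h_out, const float* zeros) {
  int hs = -1;
  int he = 2;  // exclusive end of the three-row window
  for (int j = 0; j < h_out; ++j) {
    const float* r0 = (hs == -1) ? zeros : din_channel + hs * w_in;
    const float* r1 = din_channel + (hs + 1) * w_in;
    const float* r2 = (he > h_in) ? zeros : din_channel + (hs + 2) * w_in;
    (void)r0;
    (void)r1;
    (void)r2;  // the asm block consumes r0/r1/r2 here
    hs += 2;
    he += 2;
  }
}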
void conv_depthwise_3x3s2p1_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
float zeros[8] = {0.0f};
uint32x4_t vmask_rp1 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
int hs = -1;
int he = 2;
float out_buf[4];
for (int j = 0; j < h_out; ++j) {
const float* dr0 = din_channel + hs * w_in;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
if (hs == -1) {
dr0 = zeros;
}
if (he > h_in) {
dr2 = zeros;
}
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(COMPUTE_S_S2 RESULT_S_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
        hs += 2;
        he += 2;
      }
    }
  }
}
@@ -1334,117 +1582,60 @@ void conv_depthwise_3x3s2p0_bias_relu(float* dout,
          doutr1_ptr = write_ptr;
        }
        int cnt = tile_w;
        asm volatile(
            INIT_S2
            "ld1 {v15.4s}, [%[inptr0]]          \n"
            "ld1 {v18.4s}, [%[inptr1]]          \n"
            "ld1 {v19.4s}, [%[inptr2]]          \n"
            "ld1 {v20.4s}, [%[inptr3]]          \n"
            "ld1 {v21.4s}, [%[inptr4]]          \n"
            "ext v10.16b, v0.16b, v15.16b, #4   \n"  // v10 = {2,4,6,8}
            MID_COMPUTE_S2 MID_RESULT_S2_RELU
            "cmp %w[remain], #1                 \n"
            "blt 4f                             \n" RIGHT_COMPUTE_S2
                RIGHT_RESULT_S2_RELU
            "4:                                 \n"
            : [inptr0] "+r"(din0_ptr),
              [inptr1] "+r"(din1_ptr),
              [inptr2] "+r"(din2_ptr),
              [inptr3] "+r"(din3_ptr),
              [inptr4] "+r"(din4_ptr),
              [outptr0] "+r"(doutr0_ptr),
              [outptr1] "+r"(doutr1_ptr),
              [cnt] "+r"(cnt)
            : [vzero] "w"(vzero),
              [w0] "w"(wr0),
              [w1] "w"(wr1),
              [w2] "w"(wr2),
              [remain] "r"(cnt_remain),
              [mask1] "w"(vmask_rp1),
              [mask2] "w"(vmask_rp2),
              [wmask] "w"(wmask),
              [vbias] "w"(wbias)
            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
              "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15",
              "v16", "v17", "v18", "v19", "v20", "v21");
        doutr0 = doutr0 + 2 * w_out;
      }
#else
@@ -1472,72 +1663,284 @@ void conv_depthwise_3x3s2p0_bias_relu(float* dout,
        }
        int cnt = tile_w;
        unsigned int* mask_ptr = dmask;
        asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2_RELU RIGHT_COMPUTE_S2
                         RIGHT_RESULT_S2_RELU
                     : [din0_ptr] "+r"(din0_ptr),
                       [din1_ptr] "+r"(din1_ptr),
                       [din2_ptr] "+r"(din2_ptr),
                       [outptr] "+r"(doutr0_ptr),
                       [cnt] "+r"(cnt),
                       [mask_ptr] "+r"(mask_ptr)
                     : [remain] "r"(cnt_remain),
                       [wr0] "w"(wr0),
                       [wr1] "w"(wr1),
                       [wr2] "w"(wr2),
                       [bias] "r"(bias_c)
                     : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8",
                       "q9", "q10", "q11", "q12", "q13", "q14", "q15");
        doutr0 = doutr0 + w_out;
      }
#endif
    }
  }
}
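In the pad-0 kernels each main-loop iteration loads eight input columns and stores four outputs; tile_w counts the full tiles and cnt_remain feeds the masked right-edge block. When the width divides exactly, one tile is folded into the tail so the right block's over-reads stay masked. A simplified version of that arithmetic (the shipped code additionally tracks size_right_remain, so treat this as a sketch):

// Simplified column tiling of the stride-2, pad-0 kernels.
// Example: w_in = 19 gives w_out = 9, tile_w = 2, cnt_remain = 1.
static void tile_s2p0(int w_in, int* tile_w, int* cnt_remain) {
  int w_out = (w_in - 3) / 2 + 1;  // valid-mode output width
  *tile_w = w_out >> 2;            // 4-output tiles in the main loop
  *cnt_remain = w_out & 3;         // outputs left for the masked tail
  if (*cnt_remain == 0 && *tile_w > 0) {
    *tile_w -= 1;     // exact fit: push the last tile into the tail
    *cnt_remain = 4;  // so its loads can be masked against w_in
  }
}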
void conv_depthwise_3x3s2p0_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx) {
int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
int out_pad_idx[4] = {0, 1, 2, 3};
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
uint32x4_t vmask_rp2 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
uint32x4_t wmask =
vcgtq_s32(vdupq_n_s32(cnt_remain), vld1q_s32(out_pad_idx)); // 0 1 2 3
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
float* zero_ptr = ctx->workspace_data<float>();
memset(zero_ptr, 0, w_in * sizeof(float));
float* write_ptr = zero_ptr + w_in;
unsigned int dmask[12];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
vst1q_u32(dmask + 8, wmask);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float32x4_t vzero = vdupq_n_f32(0.f);
#ifdef __aarch64__
float32x4_t wbias;
if (flag_bias) {
wbias = vdupq_n_f32(bias[i]);
} else {
wbias = vdupq_n_f32(0.f);
}
#else
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
#endif // __aarch64__
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
const float* dr3 = dr2 + w_in;
const float* dr4 = dr3 + w_in;
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
const float* din3_ptr = dr3;
const float* din4_ptr = dr4;
float* doutr0 = dout_channel;
float* doutr0_ptr = nullptr;
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
din3_ptr = dr3;
din4_ptr = dr4;
doutr0_ptr = doutr0;
doutr1_ptr = doutr0 + w_out;
dr0 = dr4;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
dr4 = dr3 + w_in;
//! process bottom pad
if (i * 2 + 5 > h_in) {
switch (i * 2 + 5 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
din2_ptr = zero_ptr;
case 2:
din3_ptr = zero_ptr;
case 1:
din4_ptr = zero_ptr;
case 0:
din4_ptr = zero_ptr;
default:
break;
}
}
//! process output pad
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = tile_w;
asm volatile(
INIT_S2
"ld1 {v15.4s}, [%[inptr0]] \n"
"ld1 {v18.4s}, [%[inptr1]] \n"
"ld1 {v19.4s}, [%[inptr2]] \n"
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2 "4: \n"
: [inptr0] "+r"(din0_ptr),
[inptr1] "+r"(din1_ptr),
[inptr2] "+r"(din2_ptr),
[inptr3] "+r"(din3_ptr),
[inptr4] "+r"(din4_ptr),
[outptr0] "+r"(doutr0_ptr),
[outptr1] "+r"(doutr1_ptr),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero),
[w0] "w"(wr0),
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
[vbias] "w"(wbias)
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16",
"v17",
"v18",
"v19",
"v20",
"v21");
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
doutr0_ptr = doutr0;
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
//! process bottom pad
if (i * 2 + 3 > h_in) {
switch (i * 2 + 3 - h_in) {
case 2:
din1_ptr = zero_ptr;
case 1:
din2_ptr = zero_ptr;
default:
break;
}
}
int cnt = tile_w;
unsigned int* mask_ptr = dmask;
asm volatile(INIT_S2 MID_COMPUTE_S2 MID_RESULT_S2 RIGHT_COMPUTE_S2
RIGHT_RESULT_S2
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[outptr] "+r"(doutr0_ptr),
[cnt] "+r"(cnt),
[mask_ptr] "+r"(mask_ptr)
: [remain] "r"(cnt_remain),
[wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
doutr0 = doutr0 + w_out;
}
#endif
}
}
}
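Splitting the relu and no_relu bodies into separate symbols moves the activation branch out of the hot loop: it is taken once per call instead of once per row block, and each asm body stays branch-free. A plausible call-site dispatch over the new pair (a hypothetical wrapper; the real caller lives in the conv_depthwise driver and may differ):

// Hypothetical dispatch across the specialized s2p0 kernels; both
// variants keep one signature, so they can share a function pointer.
typedef void (*dw3x3s2_kernel)(float*, const float*, const float*,
                               const float*, bool, bool, int, int, int,
                               int, int, int, ARMContext*);
static dw3x3s2_kernel pick_s2p0_kernel(bool fuse_relu) {
  return fuse_relu ? conv_depthwise_3x3s2p0_bias_relu
                   : conv_depthwise_3x3s2p0_bias_no_relu;
}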
/**
 * \brief depthwise convolution kernel 3x3, stride 2, width <= 4
 */
@@ -1614,113 +2017,189 @@ void conv_depthwise_3x3s2p0_bias_s_relu(float* dout,
      unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
      asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
                   : [din0_ptr] "+r"(din0_ptr),
                     [din1_ptr] "+r"(din1_ptr),
                     [din2_ptr] "+r"(din2_ptr),
                     [mask_ptr] "+r"(mask_ptr)
                   : [wr0] "w"(wr0),
                     [wr1] "w"(wr1),
                     [wr2] "w"(wr2),
                     [bias] "w"(vbias),
                     [out] "r"(out_buf)
                   : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9",
                     "v10", "v11", "v12", "v13", "v14", "v15", "v16");
#else
      asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0_RELU
                   : [din0_ptr] "+r"(din0_ptr),
                     [din1_ptr] "+r"(din1_ptr),
                     [din2_ptr] "+r"(din2_ptr)
                   : [wr0] "w"(wr0),
                     [wr1] "w"(wr1),
                     [wr2] "w"(wr2),
                     [bias] "r"(bias_c),
                     [out] "r"(out_buf),
                     [mask_ptr] "r"(dmask)
                   : "cc", "memory", "q3", "q4", "q5", "q6", "q7", "q8",
                     "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
      }
    }
  }
}
void conv_depthwise_3x3s2p0_bias_s_no_relu(float* dout,
                                           const float* din,
                                           const float* weights,
                                           const float* bias,
                                           bool flag_bias,
                                           bool flag_relu,
                                           const int num,
                                           const int ch_in,
                                           const int h_in,
                                           const int w_in,
                                           const int h_out,
                                           const int w_out,
                                           ARMContext* ctx) {
  int right_pad_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  int out_pad_idx[4] = {0, 1, 2, 3};
  float zeros[8] = {0.0f};
  const float zero_ptr[4] = {0.f, 0.f, 0.f, 0.f};

  uint32x4_t vmask_rp1 =
      vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx));  // 0 2 4 6
uint32x4_t vmask_rp2 =
vcgtq_s32(vdupq_n_s32(w_in), vld1q_s32(right_pad_idx + 4)); // 1 3 5 7
int size_in_channel = w_in * h_in;
int size_out_channel = w_out * h_out;
unsigned int dmask[8];
vst1q_u32(dmask, vmask_rp1);
vst1q_u32(dmask + 4, vmask_rp2);
for (int n = 0; n < num; ++n) {
const float* din_batch = din + n * ch_in * size_in_channel;
float* dout_batch = dout + n * ch_in * size_out_channel;
#pragma omp parallel for
for (int i = 0; i < ch_in; ++i) {
const float* din_channel = din_batch + i * size_in_channel;
float* dout_channel = dout_batch + i * size_out_channel;
const float* weight_ptr = weights + i * 9;
float32x4_t wr0 = vld1q_f32(weight_ptr);
float32x4_t wr1 = vld1q_f32(weight_ptr + 3);
float32x4_t wr2 = vld1q_f32(weight_ptr + 6);
float bias_c = 0.f;
if (flag_bias) {
bias_c = bias[i];
}
float32x4_t vbias = vdupq_n_f32(bias_c);
float out_buf[4];
const float* dr0 = din_channel;
const float* dr1 = dr0 + w_in;
const float* dr2 = dr1 + w_in;
for (int j = 0; j < h_out; j++) {
const float* din0_ptr = dr0;
const float* din1_ptr = dr1;
const float* din2_ptr = dr2;
        //! process bottom pad: rows at or past h_in must read zeros
        if (j * 2 + 2 >= h_in) {
          switch (j * 2 + 2 - h_in) {
case 1:
din1_ptr = zero_ptr;
case 0:
din2_ptr = zero_ptr;
default:
break;
}
        }
dr0 = dr2;
dr1 = dr0 + w_in;
dr2 = dr1 + w_in;
unsigned int* mask_ptr = dmask;
#ifdef __aarch64__
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr),
[mask_ptr] "+r"(mask_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "w"(vbias),
[out] "r"(out_buf)
: "cc",
"memory",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15",
"v16");
#else
asm volatile(COMPUTE_S_S2_P0 RESULT_S_S2_P0
: [din0_ptr] "+r"(din0_ptr),
[din1_ptr] "+r"(din1_ptr),
[din2_ptr] "+r"(din2_ptr)
: [wr0] "w"(wr0),
[wr1] "w"(wr1),
[wr2] "w"(wr2),
[bias] "r"(bias_c),
[out] "r"(out_buf),
[mask_ptr] "r"(dmask)
: "cc",
"memory",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11",
"q12",
"q13",
"q14",
"q15");
#endif
        for (int w = 0; w < w_out; ++w) {
          *dout_channel++ = out_buf[w];
        }
      }
    }
  }
}
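The bottom-pad switches in these small-width kernels rely on deliberate case fallthrough: once one row falls past h_in, every row below it must be redirected to zeros as well. For the pad-0 path just above, the guard is equivalent to this scalar form (a sketch with the stride-2 row indices spelled out):

// Equivalent of the fallthrough bottom-pad switch in the w_out <= 4,
// pad-0 path: output row j reads input rows 2j, 2j+1 and 2j+2.
static inline void guard_bottom_s2p0(int j, int h_in, const float* zero_row,
                                     const float** r1, const float** r2) {
  if (j * 2 + 1 >= h_in) *r1 = zero_row;  // second row past the input
  if (j * 2 + 2 >= h_in) *r2 = zero_row;  // third row past the input
}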
@@ -323,6 +323,118 @@ void conv_depthwise_3x3s2p1_bias_s_relu(float* dout,
                                        const int w_out,
                                        ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p0_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s1p1_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p0_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
void conv_depthwise_3x3s2p1_bias_s_no_relu(float* dout,
const float* din,
const float* weights,
const float* bias,
bool flag_bias,
bool flag_relu,
const int num,
const int ch_in,
const int h_in,
const int w_in,
const int h_out,
const int w_out,
ARMContext* ctx);
}  // namespace math
}  // namespace arm
}  // namespace lite