提交 c35d8e14 编写于 作者: H HappyAngel 提交者: yiicy

[arm] add conv_5x5s2_dw to support any padding (#2770)

1. add conv_5x5s2_dw to support any padding
2. add 1x1s2pooling impl
3. fix conv dw 3x3 s1p01 bug
上级 f926fb40
......@@ -2339,17 +2339,29 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
int size_out_channel = w_out * h_out;
int w_stride = 9;
int tile_w = (w_in + 3) >> 2;
int cnt_col = tile_w - 2;
int tile_w = w_out >> 2;
int remain = w_out % 4;
int cnt_col = tile_w - 1;
unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in);
unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
const unsigned int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 5) {
size_pad_right = 1;
cnt_col -= 1;
remain = 4;
} else if (remain == 0 && size_pad_right == 6) {
size_pad_right = 2;
cnt_col -= 1;
remain = 4;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result =
vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));
unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1);
......@@ -2398,7 +2410,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
......@@ -2484,7 +2496,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
dout_ptr = dout_ptr + 4 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
for (int i = 0; i < h_out; i += 2) {
//! process top pad pad_h = 1
din_ptr0 = dr0;
din_ptr1 = dr1;
......@@ -2883,39 +2895,57 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
wbias = vdupq_n_f32(0.f);
}
int hs = -1;
int he = 3;
float out_buf1[4];
float out_buf2[4];
float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) {
const float *dr0 = din_channel + hs * w_in;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
const float *dr0 = din_channel;
const float *dr1 = dr0 + w_in;
const float *dr2 = dr1 + w_in;
const float *dr3 = dr2 + w_in;
if (hs == -1) {
dr0 = zero;
for (int j = 0; j < h_out; j += 2) {
const float *dr0_ptr = dr0;
const float *dr1_ptr = dr1;
const float *dr2_ptr = dr2;
const float *dr3_ptr = dr3;
if (j == 0) {
dr0_ptr = zero;
dr1_ptr = dr0;
dr2_ptr = dr1;
dr3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
dr1_ptr = zero;
case 2:
dr2_ptr = zero;
case 1:
dr3_ptr = zero;
default:
break;
}
}
switch (he - h_in) {
case 2:
dr2 = zero;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
//! process bottom remain
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
act_switch_3x3s1p1_s(dr0,
dr1,
dr2,
dr3,
act_switch_3x3s1p1_s(dr0_ptr,
dr1_ptr,
dr2_ptr,
dr3_ptr,
out_buf1,
out_buf2,
wr0,
......@@ -2931,8 +2961,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
}
doutr0 = doutr1;
doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights
} // end of processing channels
} // end of processing batchs
......@@ -3458,6 +3486,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0
tile_w -= 1;
remain = 4;
size_pad_right = 2;
}
uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 =
......@@ -4016,22 +4050,21 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) {
switch (j + 3 - h_in) {
if (j + 4 > h_in) {
switch (j + 4 - h_in) {
case 3:
dr1 = zero_ptr;
case 2:
dr2 = zero_ptr;
case 1:
dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
doutr1 = trash_buf;
default:
break;
}
}
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
unsigned int *vmask_ptr = vmask;
act_switch_3x3s1p0_s(dr0,
dr1,
......
......@@ -1202,15 +1202,17 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2;
int size_right_remain = w_in - (7 + cnt_col * 8);
if (size_right_remain >= 9) {
cnt_col++;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
int size_right_pad = w_out * 2 - w_in;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
int cnt_col = tile_w - 1;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
......@@ -1276,7 +1278,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
float* doutr1_ptr = nullptr;
#ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) {
for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
......@@ -1303,8 +1305,8 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
dr4 = dr3 + w_in;
//! process bottom pad
if (i + 4 > h_in) {
switch (i + 4 - h_in) {
if (i * 2 + 4 > h_in) {
switch (i * 2 + 4 - h_in) {
case 4:
din1_ptr = zero_ptr;
case 3:
......@@ -1318,7 +1320,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
}
}
//! process output pad
if (i / 2 + 2 > h_out) {
if (i + 2 > h_out) {
doutr1_ptr = write_ptr;
}
int cnt = cnt_col;
......@@ -1343,7 +1345,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
doutr0 = doutr0 + 2 * w_out;
}
#else
for (int i = 0; i < h_in; i += 2) {
for (int i = 0; i < h_out; i++) {
din0_ptr = dr0;
din1_ptr = dr1;
din2_ptr = dr2;
......@@ -1641,7 +1643,8 @@ void act_switch_3x3s2p0(const float* din0_ptr,
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU6
"ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6
......@@ -1700,7 +1703,8 @@ void act_switch_3x3s2p0(const float* din0_ptr,
"ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU
"ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU
......@@ -1718,7 +1722,7 @@ void act_switch_3x3s2p0(const float* din0_ptr,
[w1] "w"(wr1),
[w2] "w"(wr2),
[remain] "r"(cnt_remain),
[six_ptr] "r"(vscale),
[scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2),
[wmask] "w"(wmask),
......@@ -1834,7 +1838,14 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
int tile_w = w_out >> 2;
int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3));
unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6
......
因为 它太大了无法显示 source diff 。你可以改为 查看blob
......@@ -150,11 +150,26 @@ void conv_depthwise_5x5s2_fp32(const float* din,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx);
// Declaration of a float depthwise 5x5 convolution routine.
// NOTE(review): judging only by the name, this is the stride-2 / pad-2
// specialization; the definition is not visible here, so confirm against it.
// Parameters as named: input/output buffers (din/dout), batch count (num),
// output dims (chout, hout, wout), input dims (chin, hin, win), filter
// weights, optional bias gated by flag_bias, a ReLU switch (flag_relu), the
// padding amount (pad) and the ARM runtime context (ctx).
void conv_depthwise_5x5s2p2_fp32(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
......
......@@ -589,10 +589,9 @@ void conv_depthwise_3x3_fp32(const void* din,
int stride = param.strides[1];
int pad = pad_w;
bool flag_bias = param.bias != nullptr;
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (stride == 1) {
if (pads_equal && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -624,9 +623,8 @@ void conv_depthwise_3x3_fp32(const void* din,
act_param,
ctx);
}
} else if (stride == 2) {
if (pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -678,12 +676,13 @@ void conv_depthwise_5x5_fp32(const void* din,
ARMContext* ctx,
const float* scale) {
auto paddings = *param.paddings;
auto act_param = param.activation_param;
int pad = paddings[0];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
if (pad == 2 && stride == 2) {
if (stride == 2) {
conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -695,9 +694,8 @@ void conv_depthwise_5x5_fp32(const void* din,
w_in,
reinterpret_cast<const float*>(weights),
bias,
pad,
flag_bias,
flag_relu,
param,
act_param,
ctx);
} else if (stride == 1) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<const float*>(din),
......
......@@ -898,6 +898,119 @@ void pooling_global_avg(const float* din,
}
}
/**
 * \brief Max pooling with a 1x1 window, stride 2 and no padding (float).
 *
 * With a 1x1 window the maximum is the sampled element itself, so output
 * element (h, w) of a channel is a copy of input element (2 * h, 2 * w) of
 * the same channel. Output positions whose source row or column lies outside
 * the input are written as 0.f.
 *
 * Both tensors are addressed as contiguous [batch][channel][row][column]
 * arrays: the batch stride is chin * hin * win for the input and
 * chout * hout * wout for the output.
 *
 * \param din   input data
 * \param dout  output data
 * \param num   batch size
 * \param chout number of output channels (also the number of channels looped)
 * \param hout  output height
 * \param wout  output width
 * \param chin  number of input channels (used only for the input batch stride)
 * \param hin   input height
 * \param win   input width
 */
void pooling1x1s2p0_max(const float* din,
                        float* dout,
                        int num,
                        int chout,
                        int hout,
                        int wout,
                        int chin,
                        int hin,
                        int win) {
  int size_channel_out = wout * hout;
  int size_channel_in = win * hin;
  auto data_out = static_cast<float*>(dout);
  auto data_in = static_cast<const float*>(din);

  // Each vector step consumes 8 input floats and produces 4 outputs. The step
  // count is capped at win / 8 so that an 8-float load never runs past the
  // end of an input row (or past the win-sized zero row below); the columns
  // that are left over are handled by the scalar tail loop.
  int w_unroll_size = wout / 4;
  if (w_unroll_size > win / 8) {
    w_unroll_size = win / 8;
  }
  int w_unroll_remian = wout - w_unroll_size * 4;
  int win_ext = w_unroll_size * 8;

  // zero_ptr: a row of zeros standing in for source rows below the input.
  auto zero_ptr =
      static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
  memset(zero_ptr, 0, win * sizeof(float));
  // write_ptr: scratch row that absorbs stores for output rows >= hout.
  // NOTE: it is shared by all OpenMP threads; its contents are never read.
  auto write_ptr =
      static_cast<float*>(TargetMalloc(TARGET(kARM), wout * sizeof(float)));

  for (int n = 0; n < num; ++n) {
    float* data_out_batch = data_out + n * chout * size_channel_out;
    const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
    for (int c = 0; c < chout; c++) {
      float* data_out_channel = data_out_batch + c * size_channel_out;
      const float* data_in_channel = data_in_batch + c * size_channel_in;
      // Four output rows per iteration; they read input rows 2h, 2h + 2,
      // 2h + 4 and 2h + 6.
      for (int h = 0; h < hout; h += 4) {
        const int hin_start = h * 2;
        const float* din0_ptr = data_in_channel + hin_start * win;
        const float* din1_ptr = din0_ptr + 2 * win;
        const float* din2_ptr = din1_ptr + 2 * win;
        const float* din3_ptr = din2_ptr + 2 * win;
        float* doutr0 = data_out_channel + h * wout;
        float* doutr1 = doutr0 + wout;
        float* doutr2 = doutr1 + wout;
        float* doutr3 = doutr2 + wout;

        // Output rows beyond hout are redirected to the scratch row.
        if (h + 1 >= hout) {
          doutr1 = write_ptr;
        }
        if (h + 2 >= hout) {
          doutr2 = write_ptr;
        }
        if (h + 3 >= hout) {
          doutr3 = write_ptr;
        }

        // Every source row is checked on its own, so a row past the bottom of
        // the input is never dereferenced -- including the case where only
        // the fourth row (2h + 6) is out of range.
        if (hin_start >= hin) {
          din0_ptr = zero_ptr;
        }
        if (hin_start + 2 >= hin) {
          din1_ptr = zero_ptr;
        }
        if (hin_start + 4 >= hin) {
          din2_ptr = zero_ptr;
        }
        if (hin_start + 6 >= hin) {
          din3_ptr = zero_ptr;
        }

        // Vector body: the de-interleaving load puts the even-indexed floats
        // in val[0], which are exactly the stride-2 samples.
        for (int i = 0; i < w_unroll_size; i++) {
          float32x4x2_t din0 = vld2q_f32(din0_ptr);
          float32x4x2_t din1 = vld2q_f32(din1_ptr);
          float32x4x2_t din2 = vld2q_f32(din2_ptr);
          float32x4x2_t din3 = vld2q_f32(din3_ptr);
          din0_ptr += 8;
          din1_ptr += 8;
          din2_ptr += 8;
          din3_ptr += 8;
          vst1q_f32(doutr0, din0.val[0]);
          vst1q_f32(doutr1, din1.val[0]);
          vst1q_f32(doutr2, din2.val[0]);
          vst1q_f32(doutr3, din3.val[0]);
          doutr0 += 4;
          doutr1 += 4;
          doutr2 += 4;
          doutr3 += 4;
        }

        // Scalar tail: j is the source column of the next output element.
        int j = win_ext;
        for (int i = 0; i < w_unroll_remian; i++) {
          if (j >= win) {
            // source column is right of the input: emit zeros
            *doutr0++ = 0.f;
            *doutr1++ = 0.f;
            *doutr2++ = 0.f;
            *doutr3++ = 0.f;
          } else {
            *doutr0++ = *din0_ptr;
            *doutr1++ = *din1_ptr;
            *doutr2++ = *din2_ptr;
            *doutr3++ = *din3_ptr;
            din0_ptr += 2;
            din1_ptr += 2;
            din2_ptr += 2;
            din3_ptr += 2;
          }
          j += 2;
        }
      }
    }
  }
  TargetFree(TARGET(kARM), zero_ptr);
  TargetFree(TARGET(kARM), write_ptr);
}
void pooling2x2s2_max(const float* din,
float* dout,
int num,
......
......@@ -64,6 +64,16 @@ void pooling_global_avg(const float* din,
int hin,
int win);
// 1x1-window, stride-2, zero-padding "max" pooling on float data: each output
// element (h, w) is a copy of input element (2 * h, 2 * w) of the same
// channel; positions whose source row/column falls outside the input are 0.f.
// num is the batch size; chout/hout/wout describe the output tensor and
// chin/hin/win the input tensor.
void pooling1x1s2p0_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win);
void pooling2x2s2_max(const float* din,
float* dout,
int num,
......
......@@ -79,6 +79,9 @@ cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetAttr("act_type", act_type_);
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "relu6") {
float alpha = act_op_desc.GetAttr<float>("threshold");
op_desc.SetAttr("fuse_brelu_threshold", alpha);
} else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("leaky_relu_alpha", alpha);
......
......@@ -56,13 +56,12 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2));
bool flag_dw_5x5 = pads_all_equal && ((kw == 5 && stride == 1) ||
(kw == 5 && stride == 2 && pad == 2));
bool flag_dw_5x5 = (paddings[0] == paddings[2]) &&
((kw == 5 && stride == 1) || (kw == 5 && stride == 2));
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl
if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
no_dilation && flag_dw) {
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
/// dw conv impl
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking dw conv";
......
......@@ -28,16 +28,13 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
auto kw = w_dims[3];
auto paddings = *param.paddings;
// select dw conv kernel
if (kw == 3) {
// VLOG(5) << "invoke 3x3 dw conv fp32";
auto paddings = *param.paddings;
bool pads_equal =
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
if (pads_equal && paddings[0] == paddings[2] &&
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) {
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
flag_trans_weights_ = false;
} else {
// trans weights
......@@ -50,11 +47,25 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto w_data_in = param.filter->data<float>();
lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw);
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
flag_trans_weights_ = true;
}
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
} else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32";
if (param.strides[0] == 2) { // conv5x5s2_dw
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
auto cround = ROUNDUP(oc, cblock);
weights_.Resize({cround, 1, kh, kw});
auto w_data = weights_.mutable_data<float>();
auto w_data_in = param.filter->data<float>();
lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw);
flag_trans_weights_ = true;
} else {
flag_trans_weights_ = false;
}
impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else {
LOG(FATAL) << "this type dw conv not impl";
......
......@@ -85,7 +85,22 @@ void PoolCompute::Run() {
return;
}
} else {
if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
if (ksize[0] == 1 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
auto& ctx = this->ctx_->template As<ARMContext>();
if (pooling_type == "max") {
lite::arm::math::pooling1x1s2p0_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
return;
}
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din,
dout,
......
......@@ -85,6 +85,10 @@ class ConvOpLite : public OpLite {
if (act_type == "relu") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu;
param_.fuse_relu = true;
} else if (act_type == "relu6") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu6;
param_.activation_param.Relu_clipped_coef =
op_desc.GetAttr<float>("fuse_brelu_threshold"); // 6.f
} else if (act_type == "leaky_relu") {
param_.activation_param.active_type =
lite_api::ActivationType::kLeakyRelu;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册