提交 74f4a312 编写于 作者: H HappyAngel 提交者: yiicy

[arm] add conv_5x5s2_dw to support any padding (#2770)

1. add conv_5x5s2_dw to support any padding
2. add 1x1s2pooling impl
3. fix conv dw 3x3 s1p01 bug
上级 e1da1af9
...@@ -2339,17 +2339,29 @@ void conv_depthwise_3x3s1p1_bias(float *dout, ...@@ -2339,17 +2339,29 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
int size_out_channel = w_out * h_out; int size_out_channel = w_out * h_out;
int w_stride = 9; int w_stride = 9;
int tile_w = (w_in + 3) >> 2; int tile_w = w_out >> 2;
int cnt_col = tile_w - 2; int remain = w_out % 4;
int cnt_col = tile_w - 1;
unsigned int size_pad_right = (unsigned int)(1 + (tile_w << 2) - w_in); unsigned int size_pad_right = (unsigned int)(5 + (tile_w << 2) - w_in);
const unsigned int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 5) {
size_pad_right = 1;
cnt_col -= 1;
remain = 4;
} else if (remain == 0 && size_pad_right == 6) {
size_pad_right = 2;
cnt_col -= 1;
remain = 4;
}
uint32x4_t vmask_rp1 = uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 = uint32x4_t vmask_rp2 =
vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right)); vcgeq_u32(vld1q_u32(right_pad_idx + 4), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_result = uint32x4_t vmask_result =
vcgtq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); vcgtq_u32(vdupq_n_u32(remain), vld1q_u32(remian_idx));
unsigned int vmask[8]; unsigned int vmask[8];
vst1q_u32(vmask, vmask_rp1); vst1q_u32(vmask, vmask_rp1);
...@@ -2398,7 +2410,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, ...@@ -2398,7 +2410,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
const float *din_ptr5 = dr5; const float *din_ptr5 = dr5;
float *ptr_zero = const_cast<float *>(zero); float *ptr_zero = const_cast<float *>(zero);
#ifdef __aarch64__ #ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) { for (int i = 0; i < h_out; i += 4) {
//! process top pad pad_h = 1 //! process top pad pad_h = 1
din_ptr0 = dr0; din_ptr0 = dr0;
din_ptr1 = dr1; din_ptr1 = dr1;
...@@ -2484,7 +2496,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout, ...@@ -2484,7 +2496,7 @@ void conv_depthwise_3x3s1p1_bias(float *dout,
dout_ptr = dout_ptr + 4 * w_out; dout_ptr = dout_ptr + 4 * w_out;
} }
#else #else
for (int i = 0; i < h_in; i += 2) { for (int i = 0; i < h_out; i += 2) {
//! process top pad pad_h = 1 //! process top pad pad_h = 1
din_ptr0 = dr0; din_ptr0 = dr0;
din_ptr1 = dr1; din_ptr1 = dr1;
...@@ -2883,39 +2895,57 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, ...@@ -2883,39 +2895,57 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
wbias = vdupq_n_f32(0.f); wbias = vdupq_n_f32(0.f);
} }
int hs = -1;
int he = 3;
float out_buf1[4]; float out_buf1[4];
float out_buf2[4]; float out_buf2[4];
float trash_buf[4]; float trash_buf[4];
int h_cnt = (h_out + 1) >> 1;
float *doutr0 = dout_channel; float *doutr0 = dout_channel;
float *doutr1 = dout_channel + w_out; float *doutr1 = dout_channel + w_out;
for (int j = 0; j < h_cnt; ++j) { const float *dr0 = din_channel;
const float *dr0 = din_channel + hs * w_in; const float *dr1 = dr0 + w_in;
const float *dr1 = dr0 + w_in; const float *dr2 = dr1 + w_in;
const float *dr2 = dr1 + w_in; const float *dr3 = dr2 + w_in;
const float *dr3 = dr2 + w_in;
if (hs == -1) { for (int j = 0; j < h_out; j += 2) {
dr0 = zero; const float *dr0_ptr = dr0;
const float *dr1_ptr = dr1;
const float *dr2_ptr = dr2;
const float *dr3_ptr = dr3;
if (j == 0) {
dr0_ptr = zero;
dr1_ptr = dr0;
dr2_ptr = dr1;
dr3_ptr = dr2;
dr0 = dr1;
dr1 = dr2;
} else {
dr0 = dr2;
dr1 = dr3;
}
dr2 = dr1 + w_in;
dr3 = dr2 + w_in;
//! process bottom pad
if (j + 3 > h_in) {
switch (j + 3 - h_in) {
case 3:
dr1_ptr = zero;
case 2:
dr2_ptr = zero;
case 1:
dr3_ptr = zero;
default:
break;
}
} }
switch (he - h_in) { //! process bottom remain
case 2: if (j + 2 > h_out) {
dr2 = zero; doutr1 = trash_buf;
doutr1 = trash_buf;
case 1:
dr3 = zero;
default:
break;
} }
act_switch_3x3s1p1_s(dr0, act_switch_3x3s1p1_s(dr0_ptr,
dr1, dr1_ptr,
dr2, dr2_ptr,
dr3, dr3_ptr,
out_buf1, out_buf1,
out_buf2, out_buf2,
wr0, wr0,
...@@ -2931,8 +2961,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout, ...@@ -2931,8 +2961,6 @@ void conv_depthwise_3x3s1p1_bias_s(float *dout,
} }
doutr0 = doutr1; doutr0 = doutr1;
doutr1 += w_out; doutr1 += w_out;
hs += 2;
he += 2;
} // end of processing heights } // end of processing heights
} // end of processing channels } // end of processing channels
} // end of processing batchs } // end of processing batchs
...@@ -3458,6 +3486,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout, ...@@ -3458,6 +3486,12 @@ void conv_depthwise_3x3s1p0_bias(float *dout,
unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in); unsigned int size_pad_right = (unsigned int)(6 + (tile_w << 2) - w_in);
const int remian_idx[4] = {0, 1, 2, 3}; const int remian_idx[4] = {0, 1, 2, 3};
if (remain == 0 && size_pad_right == 6) { // w_in == w_out and w_out % 4 == 0
tile_w -= 1;
remain = 4;
size_pad_right = 2;
}
uint32x4_t vmask_rp1 = uint32x4_t vmask_rp1 =
vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right)); vcgeq_u32(vld1q_u32(right_pad_idx), vdupq_n_u32(size_pad_right));
uint32x4_t vmask_rp2 = uint32x4_t vmask_rp2 =
...@@ -4016,22 +4050,21 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout, ...@@ -4016,22 +4050,21 @@ void conv_depthwise_3x3s1p0_bias_s(float *dout,
doutr0 = dout_channel + j * w_out; doutr0 = dout_channel + j * w_out;
doutr1 = doutr0 + w_out; doutr1 = doutr0 + w_out;
if (j + 3 >= h_in) { if (j + 4 > h_in) {
switch (j + 3 - h_in) { switch (j + 4 - h_in) {
case 3: case 3:
dr1 = zero_ptr; dr1 = zero_ptr;
case 2: case 2:
dr2 = zero_ptr; dr2 = zero_ptr;
case 1: case 1:
dr3 = zero_ptr; dr3 = zero_ptr;
doutr1 = trash_buf;
case 0:
dr3 = zero_ptr;
doutr1 = trash_buf;
default: default:
break; break;
} }
} }
if (j + 2 > h_out) {
doutr1 = trash_buf;
}
unsigned int *vmask_ptr = vmask; unsigned int *vmask_ptr = vmask;
act_switch_3x3s1p0_s(dr0, act_switch_3x3s1p0_s(dr0,
dr1, dr1,
......
...@@ -1202,15 +1202,17 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1202,15 +1202,17 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
int out_pad_idx[4] = {0, 1, 2, 3}; int out_pad_idx[4] = {0, 1, 2, 3};
int size_pad_bottom = h_out * 2 - h_in; int size_pad_bottom = h_out * 2 - h_in;
int cnt_col = (w_out >> 2) - 2; int tile_w = w_out >> 2;
int size_right_remain = w_in - (7 + cnt_col * 8); int cnt_remain = w_out % 4;
if (size_right_remain >= 9) { unsigned int size_right_remain = (unsigned int)(7 + (tile_w << 3) - w_in);
cnt_col++; size_right_remain = 8 - size_right_remain;
size_right_remain -= 8;
}
int cnt_remain = (size_right_remain == 8) ? 4 : (w_out % 4);
int size_right_pad = w_out * 2 - w_in; if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
int cnt_col = tile_w - 1;
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6 vld1q_s32(right_pad_idx)); // 0 2 4 6
...@@ -1276,7 +1278,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1276,7 +1278,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
float* doutr1_ptr = nullptr; float* doutr1_ptr = nullptr;
#ifdef __aarch64__ #ifdef __aarch64__
for (int i = 0; i < h_in; i += 4) { for (int i = 0; i < h_out; i += 2) {
din0_ptr = dr0; din0_ptr = dr0;
din1_ptr = dr1; din1_ptr = dr1;
din2_ptr = dr2; din2_ptr = dr2;
...@@ -1303,8 +1305,8 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1303,8 +1305,8 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
dr4 = dr3 + w_in; dr4 = dr3 + w_in;
//! process bottom pad //! process bottom pad
if (i + 4 > h_in) { if (i * 2 + 4 > h_in) {
switch (i + 4 - h_in) { switch (i * 2 + 4 - h_in) {
case 4: case 4:
din1_ptr = zero_ptr; din1_ptr = zero_ptr;
case 3: case 3:
...@@ -1318,7 +1320,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1318,7 +1320,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
} }
} }
//! process output pad //! process output pad
if (i / 2 + 2 > h_out) { if (i + 2 > h_out) {
doutr1_ptr = write_ptr; doutr1_ptr = write_ptr;
} }
int cnt = cnt_col; int cnt = cnt_col;
...@@ -1343,7 +1345,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout, ...@@ -1343,7 +1345,7 @@ void conv_depthwise_3x3s2p1_bias(float* dout,
doutr0 = doutr0 + 2 * w_out; doutr0 = doutr0 + 2 * w_out;
} }
#else #else
for (int i = 0; i < h_in; i += 2) { for (int i = 0; i < h_out; i++) {
din0_ptr = dr0; din0_ptr = dr0;
din1_ptr = dr1; din1_ptr = dr1;
din2_ptr = dr2; din2_ptr = dr2;
...@@ -1641,7 +1643,8 @@ void act_switch_3x3s2p0(const float* din0_ptr, ...@@ -1641,7 +1643,8 @@ void act_switch_3x3s2p0(const float* din0_ptr,
"ld1 {v20.4s}, [%[inptr3]] \n" "ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n" "ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_RELU6 "ld1 {v22.4s}, [%[six_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_RELU6
"cmp %w[remain], #1 \n" "cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2 "blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_RELU6 RIGHT_RESULT_S2_RELU6
...@@ -1700,7 +1703,8 @@ void act_switch_3x3s2p0(const float* din0_ptr, ...@@ -1700,7 +1703,8 @@ void act_switch_3x3s2p0(const float* din0_ptr,
"ld1 {v20.4s}, [%[inptr3]] \n" "ld1 {v20.4s}, [%[inptr3]] \n"
"ld1 {v21.4s}, [%[inptr4]] \n" "ld1 {v21.4s}, [%[inptr4]] \n"
"ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8} "ext v10.16b, v0.16b, v15.16b, #4 \n" // v10 = {2,4,6,8}
MID_COMPUTE_S2 MID_RESULT_S2_LEAKY_RELU "ld1 {v22.4s}, [%[scale_ptr]] \n" MID_COMPUTE_S2
MID_RESULT_S2_LEAKY_RELU
"cmp %w[remain], #1 \n" "cmp %w[remain], #1 \n"
"blt 4f \n" RIGHT_COMPUTE_S2 "blt 4f \n" RIGHT_COMPUTE_S2
RIGHT_RESULT_S2_LEAKY_RELU RIGHT_RESULT_S2_LEAKY_RELU
...@@ -1718,7 +1722,7 @@ void act_switch_3x3s2p0(const float* din0_ptr, ...@@ -1718,7 +1722,7 @@ void act_switch_3x3s2p0(const float* din0_ptr,
[w1] "w"(wr1), [w1] "w"(wr1),
[w2] "w"(wr2), [w2] "w"(wr2),
[remain] "r"(cnt_remain), [remain] "r"(cnt_remain),
[six_ptr] "r"(vscale), [scale_ptr] "r"(vscale),
[mask1] "w"(vmask_rp1), [mask1] "w"(vmask_rp1),
[mask2] "w"(vmask_rp2), [mask2] "w"(vmask_rp2),
[wmask] "w"(wmask), [wmask] "w"(wmask),
...@@ -1834,7 +1838,14 @@ void conv_depthwise_3x3s2p0_bias(float* dout, ...@@ -1834,7 +1838,14 @@ void conv_depthwise_3x3s2p0_bias(float* dout,
int tile_w = w_out >> 2; int tile_w = w_out >> 2;
int cnt_remain = w_out % 4; int cnt_remain = w_out % 4;
unsigned int size_right_remain = (unsigned int)(w_in - (tile_w << 3)); unsigned int size_right_remain = (unsigned int)(8 + (tile_w << 3) - w_in);
size_right_remain = 8 - size_right_remain;
if (cnt_remain == 0 && size_right_remain == 0) {
cnt_remain = 4;
tile_w -= 1;
size_right_remain = 8;
}
uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain), uint32x4_t vmask_rp1 = vcgtq_s32(vdupq_n_s32(size_right_remain),
vld1q_s32(right_pad_idx)); // 0 2 4 6 vld1q_s32(right_pad_idx)); // 0 2 4 6
......
因为 它太大了无法显示 source diff 。你可以改为 查看blob
...@@ -150,11 +150,26 @@ void conv_depthwise_5x5s2_fp32(const float* din, ...@@ -150,11 +150,26 @@ void conv_depthwise_5x5s2_fp32(const float* din,
int win, int win,
const float* weights, const float* weights,
const float* bias, const float* bias,
int pad, const operators::ConvParam& param,
bool flag_bias, const operators::ActivationParam act_param,
bool flag_relu,
ARMContext* ctx); ARMContext* ctx);
void conv_depthwise_5x5s2p2_fp32(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
template <typename Dtype> template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout, void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din, const int8_t* din,
......
...@@ -589,10 +589,9 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -589,10 +589,9 @@ void conv_depthwise_3x3_fp32(const void* din,
int stride = param.strides[1]; int stride = param.strides[1];
int pad = pad_w; int pad = pad_w;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
bool pads_equal = bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
if (stride == 1) { if (stride == 1) {
if (pads_equal && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1] if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
...@@ -624,9 +623,8 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -624,9 +623,8 @@ void conv_depthwise_3x3_fp32(const void* din,
act_param, act_param,
ctx); ctx);
} }
} else if (stride == 2) { } else if (stride == 2) {
if (pad_h == pad_w && (pad < 2)) { // support pad = [0, 1] if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
...@@ -678,12 +676,13 @@ void conv_depthwise_5x5_fp32(const void* din, ...@@ -678,12 +676,13 @@ void conv_depthwise_5x5_fp32(const void* din,
ARMContext* ctx, ARMContext* ctx,
const float* scale) { const float* scale) {
auto paddings = *param.paddings; auto paddings = *param.paddings;
auto act_param = param.activation_param;
int pad = paddings[0]; int pad = paddings[0];
int stride = param.strides[1]; int stride = param.strides[1];
bool flag_relu = param.fuse_relu; bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
ctx->ExtendWorkspace((w_in + w_out) * sizeof(float)); ctx->ExtendWorkspace((w_in + w_out) * sizeof(float));
if (pad == 2 && stride == 2) { if (stride == 2) {
conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din), conv_depthwise_5x5s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
...@@ -695,9 +694,8 @@ void conv_depthwise_5x5_fp32(const void* din, ...@@ -695,9 +694,8 @@ void conv_depthwise_5x5_fp32(const void* din,
w_in, w_in,
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
bias, bias,
pad, param,
flag_bias, act_param,
flag_relu,
ctx); ctx);
} else if (stride == 1) { } else if (stride == 1) {
conv_depthwise_5x5s1_fp32(reinterpret_cast<const float*>(din), conv_depthwise_5x5s1_fp32(reinterpret_cast<const float*>(din),
......
...@@ -898,6 +898,119 @@ void pooling_global_avg(const float* din, ...@@ -898,6 +898,119 @@ void pooling_global_avg(const float* din,
} }
} }
void pooling1x1s2p0_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
int win_ext = w_unroll_size * 8;
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
memset(zero_ptr, 0, win * sizeof(float));
auto write_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), wout * sizeof(float)));
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
for (int h = 0; h < hout; h += 4) {
const float* din0_ptr = data_in_channel + h * 2 * win;
const float* din1_ptr = din0_ptr + 2 * win;
const float* din2_ptr = din1_ptr + 2 * win;
const float* din3_ptr = din2_ptr + 2 * win;
float* doutr0 = data_out_channel + h * wout;
float* doutr1 = doutr0 + wout;
float* doutr2 = doutr1 + wout;
float* doutr3 = doutr2 + wout;
if (h + 4 > hout) {
switch (h + 4 - hout) {
case 3:
doutr1 = write_ptr;
case 2:
doutr2 = write_ptr;
case 1:
doutr3 = write_ptr;
default:
break;
}
}
if (h * 2 + 4 >= hin) {
switch (h * 2 + 4 - hin) {
case 4:
din0_ptr = zero_ptr;
case 3:
case 2:
din1_ptr = zero_ptr;
case 1:
case 0:
din2_ptr = zero_ptr;
din3_ptr = zero_ptr;
default:
break;
}
}
for (int i = 0; i < w_unroll_size; i++) {
float32x4x2_t din0 = vld2q_f32(din0_ptr);
float32x4x2_t din1 = vld2q_f32(din1_ptr);
float32x4x2_t din2 = vld2q_f32(din2_ptr);
float32x4x2_t din3 = vld2q_f32(din3_ptr);
din0_ptr += 8;
din1_ptr += 8;
din2_ptr += 8;
din3_ptr += 8;
vst1q_f32(doutr0, din0.val[0]);
vst1q_f32(doutr1, din1.val[0]);
vst1q_f32(doutr2, din2.val[0]);
vst1q_f32(doutr3, din3.val[0]);
doutr0 += 4;
doutr1 += 4;
doutr2 += 4;
doutr3 += 4;
}
int j = win_ext;
for (int i = 0; i < w_unroll_remian; i++) {
if (j >= win) {
*doutr0++ = 0.f;
*doutr1++ = 0.f;
*doutr2++ = 0.f;
*doutr3++ = 0.f;
} else {
*doutr0++ = *din0_ptr;
*doutr1++ = *din1_ptr;
*doutr2++ = *din2_ptr;
*doutr3++ = *din3_ptr;
din0_ptr += 2;
din1_ptr += 2;
din2_ptr += 2;
din3_ptr += 2;
}
j += 2;
}
}
}
}
TargetFree(TARGET(kARM), zero_ptr);
TargetFree(TARGET(kARM), write_ptr);
}
void pooling2x2s2_max(const float* din, void pooling2x2s2_max(const float* din,
float* dout, float* dout,
int num, int num,
......
...@@ -64,6 +64,16 @@ void pooling_global_avg(const float* din, ...@@ -64,6 +64,16 @@ void pooling_global_avg(const float* din,
int hin, int hin,
int win); int win);
void pooling1x1s2p0_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win);
void pooling2x2s2_max(const float* din, void pooling2x2s2_max(const float* din,
float* dout, float* dout,
int num, int num,
......
...@@ -79,6 +79,9 @@ cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -79,6 +79,9 @@ cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetAttr("act_type", act_type_); op_desc.SetAttr("act_type", act_type_);
if (act_type_ == "relu") { if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true); op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "relu6") {
float alpha = act_op_desc.GetAttr<float>("threshold");
op_desc.SetAttr("fuse_brelu_threshold", alpha);
} else if (act_type_ == "leaky_relu") { } else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha"); float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("leaky_relu_alpha", alpha); op_desc.SetAttr("leaky_relu_alpha", alpha);
......
...@@ -56,13 +56,12 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -56,13 +56,12 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh); bool kps_equal = (param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2)); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (stride == 1 || stride == 2));
bool flag_dw_5x5 = pads_all_equal && ((kw == 5 && stride == 1) || bool flag_dw_5x5 = (paddings[0] == paddings[2]) &&
(kw == 5 && stride == 2 && pad == 2)); ((kw == 5 && stride == 1) || (kw == 5 && stride == 2));
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl /// select conv impl
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
no_dilation && flag_dw) {
/// dw conv impl /// dw conv impl
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking dw conv"; // VLOG(3) << "invoking dw conv";
......
...@@ -28,16 +28,13 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -28,16 +28,13 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims(); auto w_dims = param.filter->dims();
auto kw = w_dims[3]; auto kw = w_dims[3];
auto paddings = *param.paddings;
// select dw conv kernel // select dw conv kernel
if (kw == 3) { if (kw == 3) {
// VLOG(5) << "invoke 3x3 dw conv fp32"; // VLOG(5) << "invoke 3x3 dw conv fp32";
auto paddings = *param.paddings; bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
bool pads_equal = if (pads_less && paddings[0] == paddings[2] &&
((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
if (pads_equal && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) { (paddings[0] == 0 || paddings[0] == 1)) {
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
flag_trans_weights_ = false; flag_trans_weights_ = false;
} else { } else {
// trans weights // trans weights
...@@ -50,11 +47,25 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -50,11 +47,25 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto w_data_in = param.filter->data<float>(); auto w_data_in = param.filter->data<float>();
lite::arm::math::conv_trans_weights_numc( lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw); w_data_in, w_data, oc, 1, cblock, kh * kw);
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
flag_trans_weights_ = true; flag_trans_weights_ = true;
} }
impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
} else if (kw == 5) { } else if (kw == 5) {
// VLOG(5) << "invoke 5x5 dw conv fp32"; // VLOG(5) << "invoke 5x5 dw conv fp32";
if (param.strides[0] == 2) { // conv5x5s2_dw
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
auto cround = ROUNDUP(oc, cblock);
weights_.Resize({cround, 1, kh, kw});
auto w_data = weights_.mutable_data<float>();
auto w_data_in = param.filter->data<float>();
lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw);
flag_trans_weights_ = true;
} else {
flag_trans_weights_ = false;
}
impl_ = lite::arm::math::conv_depthwise_5x5_fp32; impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
} else { } else {
LOG(FATAL) << "this type dw conv not impl"; LOG(FATAL) << "this type dw conv not impl";
......
...@@ -85,7 +85,22 @@ void PoolCompute::Run() { ...@@ -85,7 +85,22 @@ void PoolCompute::Run() {
return; return;
} }
} else { } else {
if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { if (ksize[0] == 1 && strides[0] == 2 && paddings[0] == 0 && kps_equal) {
auto& ctx = this->ctx_->template As<ARMContext>();
if (pooling_type == "max") {
lite::arm::math::pooling1x1s2p0_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
return;
}
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2_max(din,
dout, dout,
......
...@@ -85,6 +85,10 @@ class ConvOpLite : public OpLite { ...@@ -85,6 +85,10 @@ class ConvOpLite : public OpLite {
if (act_type == "relu") { if (act_type == "relu") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu; param_.activation_param.active_type = lite_api::ActivationType::kRelu;
param_.fuse_relu = true; param_.fuse_relu = true;
} else if (act_type == "relu6") {
param_.activation_param.active_type = lite_api::ActivationType::kRelu6;
param_.activation_param.Relu_clipped_coef =
op_desc.GetAttr<float>("fuse_brelu_threshold"); // 6.f
} else if (act_type == "leaky_relu") { } else if (act_type == "leaky_relu") {
param_.activation_param.active_type = param_.activation_param.active_type =
lite_api::ActivationType::kLeakyRelu; lite_api::ActivationType::kLeakyRelu;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册