提交 4b9df8fb 编写于 作者: Y yiicy 提交者: Xiaoyang LI

improve dw conv performance

*  imporve prepack_input func speed in int8 3x3s1 dw conv

* fix code style

* fix code style

* improve 3x3s1 dw fp32 conv speed a little

* arm add 5x5s1 int8 dw conv, test=develop
上级 33c335a5
......@@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const int win_round = wout_round + 2;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! llc_size = threads * win_round * hout_c_block * hin_r_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(win_round * threads + hout_c_block * wout_round * threads * 4);
int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(win_round * threads * hout_c_block +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
......@@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
......
......@@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
const int win_round = wout_round * 2 /*stride*/ + 1;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! llc_size = threads * win_round * hin_r_block * hout_c_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(2 * win_round * threads + hout_c_block * wout_round * threads * 4);
int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(2 * win_round * threads * hout_c_block +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
......@@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
......
......@@ -389,10 +389,9 @@ inline void prepack_input_nxwc4_dw(const float* din,
}
}
inline void prepack_input_nxw_c8_int8(const int8_t* din,
inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
int8_t* dout,
int cs,
int ce,
int hs,
int he,
int ws,
......@@ -402,95 +401,61 @@ inline void prepack_input_nxw_c8_int8(const int8_t* din,
int height) {
int n = he - hs;
if (n <= 0) {
LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0";
return;
LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero";
}
int size_w = we - ws;
int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we;
int size_w = we - ws;
int size_channel_in = width * height;
int size_out_row = size_w * 8;
int valid_w = w1 - w0;
size_t valid_w_byte = valid_w * sizeof(int8_t);
auto ptr_c = static_cast<int8_t*>(TargetMalloc(TARGET(kARM), 8 * size_w));
int8_t* ptr_r[8];
int8_t* ptr_c_ori[8] = {ptr_c,
ptr_c + size_w,
ptr_c + 2 * size_w,
ptr_c + 3 * size_w,
ptr_c + 4 * size_w,
ptr_c + 5 * size_w,
ptr_c + 6 * size_w,
ptr_c + 7 * size_w};
int pad_l = ws < 0 ? -ws : 0;
int pad_r = we > width ? we - width : 0;
int size_c = width * height;
int valid_cnt = valid_w >> 3;
int remain = valid_w & 7;
int8_t zero_ptr[size_w * 2]; // NOLINT
memset(zero_ptr, 0, size_w * 2);
int loop = size_w / 8;
int remain = size_w - loop * 8;
for (int c = cs; c < ce; c += 8) {
auto din_c = din + c * size_channel_in;
for (int j = 0; j < 8; ++j) {
ptr_r[j] = ptr_c_ori[j];
}
//! valid channel
if (c + 8 > channel) {
switch (c + 8 - channel) {
for (int h = hs; h < he; ++h) {
const int8_t* ptr_c0 = din + h * width + cs * size_c;
const int8_t* ptr_c1 = ptr_c0 + size_c;
const int8_t* ptr_c2 = ptr_c1 + size_c;
const int8_t* ptr_c3 = ptr_c2 + size_c;
const int8_t* ptr_c4 = ptr_c3 + size_c;
const int8_t* ptr_c5 = ptr_c4 + size_c;
const int8_t* ptr_c6 = ptr_c5 + size_c;
const int8_t* ptr_c7 = ptr_c6 + size_c;
if (h < 0 || h >= height) {
memset(dout, 0, 8 * size_w * sizeof(int8_t));
dout += size_w * 8;
continue;
} else if (cs + 8 > channel) {
switch (cs + 8 - channel) {
case 7:
ptr_r[1] = zero_ptr;
ptr_c1 = zero_ptr;
case 6:
ptr_r[2] = zero_ptr;
ptr_c2 = zero_ptr;
case 5:
ptr_r[3] = zero_ptr;
ptr_c3 = zero_ptr;
case 4:
ptr_r[4] = zero_ptr;
ptr_c4 = zero_ptr;
case 3:
ptr_r[5] = zero_ptr;
ptr_c5 = zero_ptr;
case 2:
ptr_r[6] = zero_ptr;
ptr_c6 = zero_ptr;
case 1:
ptr_r[7] = zero_ptr;
ptr_c7 = zero_ptr;
default:
break;
}
}
//! valid height
int j = 0;
for (int i = hs; i < he; i++) {
auto din_r = din_c + i * width;
for (int k = 0; k < 8; ++k) {
if (ptr_r[k] != zero_ptr) {
if (i < 0 || i >= height) {
ptr_r[k] = zero_ptr + size_w;
} else {
ptr_r[k] = ptr_c_ori[k];
auto ptr = ptr_r[k];
for (int w = ws; w < w0; ++w) {
*(ptr++) = 0;
if (pad_l) {
memset(dout, 0, pad_l * 8 * sizeof(int8_t));
dout += pad_l * 8;
}
memcpy(ptr, din_r + k * size_channel_in, valid_w_byte);
ptr += valid_w;
for (int w = w1; w < we; ++w) {
*(ptr++) = 0;
}
}
}
}
int cnt = loop;
int8_t* inr0 = ptr_r[0];
int8_t* inr1 = ptr_r[1];
int8_t* inr2 = ptr_r[2];
int8_t* inr3 = ptr_r[3];
int8_t* inr4 = ptr_r[4];
int8_t* inr5 = ptr_r[5];
int8_t* inr6 = ptr_r[6];
int8_t* inr7 = ptr_r[7];
auto ptr_out = dout + j * size_out_row;
if (cnt > 0) {
if (valid_cnt) {
int cnt = valid_cnt;
#ifdef __aarch64__
asm volatile(
/* main loop */
......@@ -534,15 +499,15 @@ inline void prepack_input_nxw_c8_int8(const int8_t* din,
"stp d14, d15, [%[ptr_out]], #16\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
[r0] "+r"(ptr_c0),
[r1] "+r"(ptr_c1),
[r2] "+r"(ptr_c2),
[r3] "+r"(ptr_c3),
[r4] "+r"(ptr_c4),
[r5] "+r"(ptr_c5),
[r6] "+r"(ptr_c6),
[r7] "+r"(ptr_c7),
[ptr_out] "+r"(dout)
:
: "cc",
"memory",
......@@ -591,35 +556,35 @@ inline void prepack_input_nxw_c8_int8(const int8_t* din,
"vst1.32 {d4-d7}, [%[ptr_out]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
[r0] "+r"(ptr_c0),
[r1] "+r"(ptr_c1),
[r2] "+r"(ptr_c2),
[r3] "+r"(ptr_c3),
[r4] "+r"(ptr_c4),
[r5] "+r"(ptr_c5),
[r6] "+r"(ptr_c6),
[r7] "+r"(ptr_c7),
[ptr_out] "+r"(dout)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif // aarch64
#endif // __aarch64__
}
for (int k = 0; k < remain; ++k) {
ptr_out[0] = *(inr0++);
ptr_out[1] = *(inr1++);
ptr_out[2] = *(inr2++);
ptr_out[3] = *(inr3++);
ptr_out[4] = *(inr4++);
ptr_out[5] = *(inr5++);
ptr_out[6] = *(inr6++);
ptr_out[7] = *(inr7++);
ptr_out += 8;
for (int i = 0; i < remain; ++i) {
dout[0] = *(ptr_c0++);
dout[1] = *(ptr_c1++);
dout[2] = *(ptr_c2++);
dout[3] = *(ptr_c3++);
dout[4] = *(ptr_c4++);
dout[5] = *(ptr_c5++);
dout[6] = *(ptr_c6++);
dout[7] = *(ptr_c7++);
dout += 8;
}
j++;
if (pad_r) {
memset(dout, 0, pad_r * 8 * sizeof(int8_t));
dout += pad_r * 8;
}
}
TargetFree(TARGET(kARM), ptr_c);
}
/*wirte result in outputs
......
......@@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din,
bool flag_relu,
ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din,
}
}
void conv_depthwise_5x5_int8_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
void conv_depthwise_5x5_int8_int8(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -97,7 +97,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......@@ -136,7 +136,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......
......@@ -31,7 +31,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// select dw conv kernel
if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv fp32";
/// trans weights
// trans weights
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
......@@ -75,6 +75,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
......@@ -83,6 +84,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
......@@ -123,6 +134,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
......@@ -131,6 +143,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
......
......@@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
}
#endif /// 3x3dw
#if 0 /// 5x5dw
#if 1 /// 5x5dw
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& stride : {1}) {
for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
......@@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto &h : {1, 3, 15, 19, 28, 32, 75}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册