提交 4b9df8fb 编写于 作者: Y yiicy 提交者: Xiaoyang LI

improve dw conv performance

*  imporve prepack_input func speed in int8 3x3s1 dw conv

* fix code style

* fix code style

* improve 3x3s1 dw fp32 conv speed a little

* arm add 5x5s1 int8 dw conv, test=develop
上级 33c335a5
...@@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, ...@@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const int win_round = wout_round + 2; const int win_round = wout_round + 2;
//! get h block //! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round //! llc_size = threads * win_round * hout_c_block * hin_r_block *
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t) //! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2 //! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2 //! hin_r_block = hout_r_block + 2
int hout_r_block = int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(llc_size - 2 * win_round * threads) / (win_round * threads * hout_c_block +
(win_round * threads + hout_c_block * wout_round * threads * 4); hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block; hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block = hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
...@@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout, ...@@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size); int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din; auto pre_din = tmp_din;
#endif #endif
prepack_input_nxw_c8_int8(din_batch, prepack_input_nxwc8_int8_dw(
pre_din, din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
const int8_t* block_inr0 = pre_din; const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len; const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len; const int8_t* block_inr2 = block_inr1 + in_len;
......
...@@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, ...@@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
const int win_round = wout_round * 2 /*stride*/ + 1; const int win_round = wout_round * 2 /*stride*/ + 1;
//! get h block //! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round //! llc_size = threads * win_round * hin_r_block * hout_c_block *
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t) //! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2 //! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2 //! hin_r_block = hout_r_block + 2
int hout_r_block = int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(llc_size - 2 * win_round * threads) / (2 * win_round * threads * hout_c_block +
(2 * win_round * threads + hout_c_block * wout_round * threads * 4); hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block; hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block = hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel; ((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
...@@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout, ...@@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size); int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din; auto pre_din = tmp_din;
#endif #endif
prepack_input_nxw_c8_int8(din_batch, prepack_input_nxwc8_int8_dw(
pre_din, din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
const int8_t* block_inr0 = pre_din; const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len; const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len; const int8_t* block_inr2 = block_inr1 + in_len;
......
...@@ -389,237 +389,202 @@ inline void prepack_input_nxwc4_dw(const float* din, ...@@ -389,237 +389,202 @@ inline void prepack_input_nxwc4_dw(const float* din,
} }
} }
inline void prepack_input_nxw_c8_int8(const int8_t* din, inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
int8_t* dout, int8_t* dout,
int cs, int cs,
int ce, int hs,
int hs, int he,
int he, int ws,
int ws, int we,
int we, int channel,
int channel, int width,
int width, int height) {
int height) {
int n = he - hs; int n = he - hs;
if (n <= 0) { if (n <= 0) {
LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0"; LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero";
return;
} }
int size_w = we - ws;
int w0 = ws < 0 ? 0 : ws; int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we; int w1 = we > width ? width : we;
int size_w = we - ws;
int size_channel_in = width * height;
int size_out_row = size_w * 8;
int valid_w = w1 - w0; int valid_w = w1 - w0;
size_t valid_w_byte = valid_w * sizeof(int8_t); int pad_l = ws < 0 ? -ws : 0;
int pad_r = we > width ? we - width : 0;
auto ptr_c = static_cast<int8_t*>(TargetMalloc(TARGET(kARM), 8 * size_w)); int size_c = width * height;
int8_t* ptr_r[8];
int8_t* ptr_c_ori[8] = {ptr_c, int valid_cnt = valid_w >> 3;
ptr_c + size_w, int remain = valid_w & 7;
ptr_c + 2 * size_w,
ptr_c + 3 * size_w,
ptr_c + 4 * size_w,
ptr_c + 5 * size_w,
ptr_c + 6 * size_w,
ptr_c + 7 * size_w};
int8_t zero_ptr[size_w * 2]; // NOLINT int8_t zero_ptr[size_w * 2]; // NOLINT
memset(zero_ptr, 0, size_w * 2); memset(zero_ptr, 0, size_w * 2);
int loop = size_w / 8; for (int h = hs; h < he; ++h) {
int remain = size_w - loop * 8; const int8_t* ptr_c0 = din + h * width + cs * size_c;
const int8_t* ptr_c1 = ptr_c0 + size_c;
for (int c = cs; c < ce; c += 8) { const int8_t* ptr_c2 = ptr_c1 + size_c;
auto din_c = din + c * size_channel_in; const int8_t* ptr_c3 = ptr_c2 + size_c;
for (int j = 0; j < 8; ++j) { const int8_t* ptr_c4 = ptr_c3 + size_c;
ptr_r[j] = ptr_c_ori[j]; const int8_t* ptr_c5 = ptr_c4 + size_c;
} const int8_t* ptr_c6 = ptr_c5 + size_c;
//! valid channel const int8_t* ptr_c7 = ptr_c6 + size_c;
if (c + 8 > channel) { if (h < 0 || h >= height) {
switch (c + 8 - channel) { memset(dout, 0, 8 * size_w * sizeof(int8_t));
dout += size_w * 8;
continue;
} else if (cs + 8 > channel) {
switch (cs + 8 - channel) {
case 7: case 7:
ptr_r[1] = zero_ptr; ptr_c1 = zero_ptr;
case 6: case 6:
ptr_r[2] = zero_ptr; ptr_c2 = zero_ptr;
case 5: case 5:
ptr_r[3] = zero_ptr; ptr_c3 = zero_ptr;
case 4: case 4:
ptr_r[4] = zero_ptr; ptr_c4 = zero_ptr;
case 3: case 3:
ptr_r[5] = zero_ptr; ptr_c5 = zero_ptr;
case 2: case 2:
ptr_r[6] = zero_ptr; ptr_c6 = zero_ptr;
case 1: case 1:
ptr_r[7] = zero_ptr; ptr_c7 = zero_ptr;
default: default:
break; break;
} }
} }
//! valid height if (pad_l) {
int j = 0; memset(dout, 0, pad_l * 8 * sizeof(int8_t));
for (int i = hs; i < he; i++) { dout += pad_l * 8;
auto din_r = din_c + i * width; }
for (int k = 0; k < 8; ++k) { if (valid_cnt) {
if (ptr_r[k] != zero_ptr) { int cnt = valid_cnt;
if (i < 0 || i >= height) {
ptr_r[k] = zero_ptr + size_w;
} else {
ptr_r[k] = ptr_c_ori[k];
auto ptr = ptr_r[k];
for (int w = ws; w < w0; ++w) {
*(ptr++) = 0;
}
memcpy(ptr, din_r + k * size_channel_in, valid_w_byte);
ptr += valid_w;
for (int w = w1; w < we; ++w) {
*(ptr++) = 0;
}
}
}
}
int cnt = loop;
int8_t* inr0 = ptr_r[0];
int8_t* inr1 = ptr_r[1];
int8_t* inr2 = ptr_r[2];
int8_t* inr3 = ptr_r[3];
int8_t* inr4 = ptr_r[4];
int8_t* inr5 = ptr_r[5];
int8_t* inr6 = ptr_r[6];
int8_t* inr7 = ptr_r[7];
auto ptr_out = dout + j * size_out_row;
if (cnt > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile( asm volatile(
/* main loop */ /* main loop */
"1:\n" "1:\n"
"ldr d0, [%[r0]], #8\n" "ldr d0, [%[r0]], #8\n"
"ldr d1, [%[r1]], #8\n" "ldr d1, [%[r1]], #8\n"
"ldr d2, [%[r2]], #8\n" "ldr d2, [%[r2]], #8\n"
"ldr d3, [%[r3]], #8\n" "ldr d3, [%[r3]], #8\n"
"ldr d4, [%[r4]], #8\n" "ldr d4, [%[r4]], #8\n"
"ldr d5, [%[r5]], #8\n" "ldr d5, [%[r5]], #8\n"
"ldr d6, [%[r6]], #8\n" "ldr d6, [%[r6]], #8\n"
"ldr d7, [%[r7]], #8\n" "ldr d7, [%[r7]], #8\n"
"trn1 v8.8b, v0.8b, v1.8b\n" "trn1 v8.8b, v0.8b, v1.8b\n"
"trn2 v9.8b, v0.8b, v1.8b\n" "trn2 v9.8b, v0.8b, v1.8b\n"
"trn1 v10.8b, v2.8b, v3.8b\n" "trn1 v10.8b, v2.8b, v3.8b\n"
"trn2 v11.8b, v2.8b, v3.8b\n" "trn2 v11.8b, v2.8b, v3.8b\n"
"trn1 v12.8b, v4.8b, v5.8b\n" "trn1 v12.8b, v4.8b, v5.8b\n"
"trn2 v13.8b, v4.8b, v5.8b\n" "trn2 v13.8b, v4.8b, v5.8b\n"
"trn1 v14.8b, v6.8b, v7.8b\n" "trn1 v14.8b, v6.8b, v7.8b\n"
"trn2 v15.8b, v6.8b, v7.8b\n" "trn2 v15.8b, v6.8b, v7.8b\n"
"trn1 v0.4h, v8.4h, v10.4h\n" "trn1 v0.4h, v8.4h, v10.4h\n"
"trn2 v1.4h, v8.4h, v10.4h\n" "trn2 v1.4h, v8.4h, v10.4h\n"
"trn1 v2.4h, v9.4h, v11.4h\n" "trn1 v2.4h, v9.4h, v11.4h\n"
"trn2 v3.4h, v9.4h, v11.4h\n" "trn2 v3.4h, v9.4h, v11.4h\n"
"trn1 v4.4h, v12.4h, v14.4h\n" "trn1 v4.4h, v12.4h, v14.4h\n"
"trn2 v5.4h, v12.4h, v14.4h\n" "trn2 v5.4h, v12.4h, v14.4h\n"
"trn1 v6.4h, v13.4h, v15.4h\n" "trn1 v6.4h, v13.4h, v15.4h\n"
"trn2 v7.4h, v13.4h, v15.4h\n" "trn2 v7.4h, v13.4h, v15.4h\n"
"trn1 v8.2s, v0.2s, v4.2s\n" "trn1 v8.2s, v0.2s, v4.2s\n"
"trn1 v9.2s, v2.2s, v6.2s\n" "trn1 v9.2s, v2.2s, v6.2s\n"
"trn1 v10.2s, v1.2s, v5.2s\n" "trn1 v10.2s, v1.2s, v5.2s\n"
"trn1 v11.2s, v3.2s, v7.2s\n" "trn1 v11.2s, v3.2s, v7.2s\n"
"stp d8, d9, [%[ptr_out]], #16\n" "stp d8, d9, [%[ptr_out]], #16\n"
"trn2 v12.2s, v0.2s, v4.2s\n" "trn2 v12.2s, v0.2s, v4.2s\n"
"trn2 v13.2s, v2.2s, v6.2s\n" "trn2 v13.2s, v2.2s, v6.2s\n"
"stp d10, d11, [%[ptr_out]], #16\n" "stp d10, d11, [%[ptr_out]], #16\n"
"trn2 v14.2s, v1.2s, v5.2s\n" "trn2 v14.2s, v1.2s, v5.2s\n"
"trn2 v15.2s, v3.2s, v7.2s\n" "trn2 v15.2s, v3.2s, v7.2s\n"
"subs %w[cnt], %w[cnt], #1\n" "subs %w[cnt], %w[cnt], #1\n"
"stp d12, d13, [%[ptr_out]], #16\n" "stp d12, d13, [%[ptr_out]], #16\n"
"stp d14, d15, [%[ptr_out]], #16\n" "stp d14, d15, [%[ptr_out]], #16\n"
"bne 1b\n" "bne 1b\n"
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[r0] "+r"(inr0), [r0] "+r"(ptr_c0),
[r1] "+r"(inr1), [r1] "+r"(ptr_c1),
[r2] "+r"(inr2), [r2] "+r"(ptr_c2),
[r3] "+r"(inr3), [r3] "+r"(ptr_c3),
[r4] "+r"(inr4), [r4] "+r"(ptr_c4),
[r5] "+r"(inr5), [r5] "+r"(ptr_c5),
[r6] "+r"(inr6), [r6] "+r"(ptr_c6),
[r7] "+r"(inr7), [r7] "+r"(ptr_c7),
[ptr_out] "+r"(ptr_out) [ptr_out] "+r"(dout)
: :
: "cc", : "cc",
"memory", "memory",
"v0", "v0",
"v1", "v1",
"v2", "v2",
"v3", "v3",
"v4", "v4",
"v5", "v5",
"v6", "v6",
"v7", "v7",
"v8", "v8",
"v9", "v9",
"v10", "v10",
"v11", "v11",
"v12", "v12",
"v13", "v13",
"v14", "v14",
"v15"); "v15");
#else #else
asm volatile( asm volatile(
/* main loop */ /* main loop */
"1:\n" "1:\n"
"vld1.32 {d0}, [%[r0]]!\n" "vld1.32 {d0}, [%[r0]]!\n"
"vld1.32 {d1}, [%[r1]]!\n" "vld1.32 {d1}, [%[r1]]!\n"
"vld1.32 {d2}, [%[r2]]!\n" "vld1.32 {d2}, [%[r2]]!\n"
"vld1.32 {d3}, [%[r3]]!\n" "vld1.32 {d3}, [%[r3]]!\n"
"vld1.32 {d4}, [%[r4]]!\n" "vld1.32 {d4}, [%[r4]]!\n"
"vld1.32 {d5}, [%[r5]]!\n" "vld1.32 {d5}, [%[r5]]!\n"
"vld1.32 {d6}, [%[r6]]!\n" "vld1.32 {d6}, [%[r6]]!\n"
"vld1.32 {d7}, [%[r7]]!\n" "vld1.32 {d7}, [%[r7]]!\n"
"vtrn.8 d0, d1\n" "vtrn.8 d0, d1\n"
"vtrn.8 d2, d3\n" "vtrn.8 d2, d3\n"
"vtrn.8 d4, d5\n" "vtrn.8 d4, d5\n"
"vtrn.8 d6, d7\n" "vtrn.8 d6, d7\n"
"vtrn.16 d0, d2\n" "vtrn.16 d0, d2\n"
"vtrn.16 d1, d3\n" "vtrn.16 d1, d3\n"
"vtrn.16 d4, d6\n" "vtrn.16 d4, d6\n"
"vtrn.16 d5, d7\n" "vtrn.16 d5, d7\n"
"vtrn.32 d0, d4\n" "vtrn.32 d0, d4\n"
"vtrn.32 d2, d6\n" "vtrn.32 d2, d6\n"
"vtrn.32 d1, d5\n" "vtrn.32 d1, d5\n"
"vtrn.32 d3, d7\n" "vtrn.32 d3, d7\n"
"subs %[cnt], #1\n" "subs %[cnt], #1\n"
"vst1.32 {d0-d3}, [%[ptr_out]]!\n" "vst1.32 {d0-d3}, [%[ptr_out]]!\n"
"vst1.32 {d4-d7}, [%[ptr_out]]!\n" "vst1.32 {d4-d7}, [%[ptr_out]]!\n"
"bne 1b\n" "bne 1b\n"
: [cnt] "+r"(cnt), : [cnt] "+r"(cnt),
[r0] "+r"(inr0), [r0] "+r"(ptr_c0),
[r1] "+r"(inr1), [r1] "+r"(ptr_c1),
[r2] "+r"(inr2), [r2] "+r"(ptr_c2),
[r3] "+r"(inr3), [r3] "+r"(ptr_c3),
[r4] "+r"(inr4), [r4] "+r"(ptr_c4),
[r5] "+r"(inr5), [r5] "+r"(ptr_c5),
[r6] "+r"(inr6), [r6] "+r"(ptr_c6),
[r7] "+r"(inr7), [r7] "+r"(ptr_c7),
[ptr_out] "+r"(ptr_out) [ptr_out] "+r"(dout)
: :
: "cc", "memory", "q0", "q1", "q2", "q3"); : "cc", "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
#endif // aarch64 }
} for (int i = 0; i < remain; ++i) {
for (int k = 0; k < remain; ++k) { dout[0] = *(ptr_c0++);
ptr_out[0] = *(inr0++); dout[1] = *(ptr_c1++);
ptr_out[1] = *(inr1++); dout[2] = *(ptr_c2++);
ptr_out[2] = *(inr2++); dout[3] = *(ptr_c3++);
ptr_out[3] = *(inr3++); dout[4] = *(ptr_c4++);
ptr_out[4] = *(inr4++); dout[5] = *(ptr_c5++);
ptr_out[5] = *(inr5++); dout[6] = *(ptr_c6++);
ptr_out[6] = *(inr6++); dout[7] = *(ptr_c7++);
ptr_out[7] = *(inr7++); dout += 8;
ptr_out += 8; }
} if (pad_r) {
j++; memset(dout, 0, pad_r * 8 * sizeof(int8_t));
dout += pad_r * 8;
} }
} }
TargetFree(TARGET(kARM), ptr_c);
} }
/*wirte result in outputs /*wirte result in outputs
......
...@@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din, ...@@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din,
bool flag_relu, bool flag_relu,
ARMContext* ctx); ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din, ...@@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din,
} }
} }
void conv_depthwise_5x5_int8_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
void conv_depthwise_5x5_int8_int8(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -97,7 +97,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -97,7 +97,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2); bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
...@@ -136,7 +136,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -136,7 +136,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2); bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2); bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......
...@@ -31,7 +31,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -31,7 +31,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// select dw conv kernel // select dw conv kernel
if (kw == 3) { if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv fp32"; VLOG(5) << "invoke 3x3 dw conv fp32";
/// trans weights // trans weights
constexpr int cblock = 4; constexpr int cblock = 4;
auto oc = w_dims[0]; auto oc = w_dims[0];
auto kh = w_dims[2]; auto kh = w_dims[2];
...@@ -75,6 +75,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -75,6 +75,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
} }
/// select dw conv kernel /// select dw conv kernel
if (kw == 3) { if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out"; VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32; impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8); int cround = ROUNDUP(w_dims[0], 8);
...@@ -83,6 +84,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -83,6 +84,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>(); auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9); lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true; flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else { } else {
LOG(FATAL) << "this type dw conv not impl"; LOG(FATAL) << "this type dw conv not impl";
} }
...@@ -123,6 +134,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -123,6 +134,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
} }
/// select dw conv kernel /// select dw conv kernel
if (kw == 3) { if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out"; VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8; impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8); int cround = ROUNDUP(w_dims[0], 8);
...@@ -131,6 +143,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -131,6 +143,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>(); auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9); lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true; flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else { } else {
LOG(FATAL) << "this type dw conv not impl"; LOG(FATAL) << "this type dw conv not impl";
} }
......
...@@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
} }
#endif /// 3x3dw #endif /// 3x3dw
#if 0 /// 5x5dw #if 1 /// 5x5dw
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) { for (auto& stride : {1}) {
for (auto& pad : {0, 1, 2}) { for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_relu : {false, true}) {
...@@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
std::vector<DDim> dims; std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5}); DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) { for (auto& batch : {1, 2}) {
for (auto &h : {1, 3, 15, 19, 28, 32, 75}) { for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h})); dims.push_back(DDim({batch, c, h, h}));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册