提交 4b9df8fb 编写于 作者: Y yiicy 提交者: Xiaoyang LI

improve dw conv performance

*  imporve prepack_input func speed in int8 3x3s1 dw conv

* fix code style

* fix code style

* improve 3x3s1 dw fp32 conv speed a little

* arm add 5x5s1 int8 dw conv, test=develop
上级 33c335a5
......@@ -56,13 +56,14 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
const int win_round = wout_round + 2;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! llc_size = threads * win_round * hout_c_block * hin_r_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(win_round * threads + hout_c_block * wout_round * threads * 4);
int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(win_round * threads * hout_c_block +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
......@@ -115,17 +116,9 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
......
......@@ -56,13 +56,14 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
const int win_round = wout_round * 2 /*stride*/ + 1;
//! get h block
//! llc_size = threads * win_round * hin_r_block * sizeof(int8_t) + wout_round
//! * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! llc_size = threads * win_round * hin_r_block * hout_c_block *
//! sizeof(int8_t)
//! + wout_round * hout_c_block * hout_r_block * threads * sizeof(int32_t)
//! win_round = wout_round + 2
//! hin_r_block = hout_r_block + 2
int hout_r_block =
(llc_size - 2 * win_round * threads) /
(2 * win_round * threads + hout_c_block * wout_round * threads * 4);
int hout_r_block = (llc_size - 2 * win_round * threads * hout_c_block) /
(2 * win_round * threads * hout_c_block +
hout_c_block * wout_round * threads * 4);
hout_r_block = hout_r_block > hout ? hout : hout_r_block;
hout_r_block =
((hout_r_block + hout_r_kernel - 1) / hout_r_kernel) * hout_r_kernel;
......@@ -115,17 +116,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
int32_t* pre_out = reinterpret_cast<int32_t*>(tmp_din + pre_in_size);
auto pre_din = tmp_din;
#endif
prepack_input_nxw_c8_int8(din_batch,
pre_din,
c,
c + hout_c_block,
hs,
he,
ws,
we,
chin,
win,
hin);
prepack_input_nxwc8_int8_dw(
din_batch, pre_din, c, hs, he, ws, we, chin, win, hin);
const int8_t* block_inr0 = pre_din;
const int8_t* block_inr1 = block_inr0 + in_len;
const int8_t* block_inr2 = block_inr1 + in_len;
......
......@@ -389,237 +389,202 @@ inline void prepack_input_nxwc4_dw(const float* din,
}
}
inline void prepack_input_nxw_c8_int8(const int8_t* din,
int8_t* dout,
int cs,
int ce,
int hs,
int he,
int ws,
int we,
int channel,
int width,
int height) {
inline void prepack_input_nxwc8_int8_dw(const int8_t* din,
int8_t* dout,
int cs,
int hs,
int he,
int ws,
int we,
int channel,
int width,
int height) {
int n = he - hs;
if (n <= 0) {
LOG(FATAL) << "prepack_input_nxw_c8 input height must > 0";
return;
LOG(FATAL) << "prepack_dw_input_int8, valid height must > zero";
}
int size_w = we - ws;
int w0 = ws < 0 ? 0 : ws;
int w1 = we > width ? width : we;
int size_w = we - ws;
int size_channel_in = width * height;
int size_out_row = size_w * 8;
int valid_w = w1 - w0;
size_t valid_w_byte = valid_w * sizeof(int8_t);
auto ptr_c = static_cast<int8_t*>(TargetMalloc(TARGET(kARM), 8 * size_w));
int8_t* ptr_r[8];
int8_t* ptr_c_ori[8] = {ptr_c,
ptr_c + size_w,
ptr_c + 2 * size_w,
ptr_c + 3 * size_w,
ptr_c + 4 * size_w,
ptr_c + 5 * size_w,
ptr_c + 6 * size_w,
ptr_c + 7 * size_w};
int pad_l = ws < 0 ? -ws : 0;
int pad_r = we > width ? we - width : 0;
int size_c = width * height;
int valid_cnt = valid_w >> 3;
int remain = valid_w & 7;
int8_t zero_ptr[size_w * 2]; // NOLINT
memset(zero_ptr, 0, size_w * 2);
int loop = size_w / 8;
int remain = size_w - loop * 8;
for (int c = cs; c < ce; c += 8) {
auto din_c = din + c * size_channel_in;
for (int j = 0; j < 8; ++j) {
ptr_r[j] = ptr_c_ori[j];
}
//! valid channel
if (c + 8 > channel) {
switch (c + 8 - channel) {
for (int h = hs; h < he; ++h) {
const int8_t* ptr_c0 = din + h * width + cs * size_c;
const int8_t* ptr_c1 = ptr_c0 + size_c;
const int8_t* ptr_c2 = ptr_c1 + size_c;
const int8_t* ptr_c3 = ptr_c2 + size_c;
const int8_t* ptr_c4 = ptr_c3 + size_c;
const int8_t* ptr_c5 = ptr_c4 + size_c;
const int8_t* ptr_c6 = ptr_c5 + size_c;
const int8_t* ptr_c7 = ptr_c6 + size_c;
if (h < 0 || h >= height) {
memset(dout, 0, 8 * size_w * sizeof(int8_t));
dout += size_w * 8;
continue;
} else if (cs + 8 > channel) {
switch (cs + 8 - channel) {
case 7:
ptr_r[1] = zero_ptr;
ptr_c1 = zero_ptr;
case 6:
ptr_r[2] = zero_ptr;
ptr_c2 = zero_ptr;
case 5:
ptr_r[3] = zero_ptr;
ptr_c3 = zero_ptr;
case 4:
ptr_r[4] = zero_ptr;
ptr_c4 = zero_ptr;
case 3:
ptr_r[5] = zero_ptr;
ptr_c5 = zero_ptr;
case 2:
ptr_r[6] = zero_ptr;
ptr_c6 = zero_ptr;
case 1:
ptr_r[7] = zero_ptr;
ptr_c7 = zero_ptr;
default:
break;
}
}
//! valid height
int j = 0;
for (int i = hs; i < he; i++) {
auto din_r = din_c + i * width;
for (int k = 0; k < 8; ++k) {
if (ptr_r[k] != zero_ptr) {
if (i < 0 || i >= height) {
ptr_r[k] = zero_ptr + size_w;
} else {
ptr_r[k] = ptr_c_ori[k];
auto ptr = ptr_r[k];
for (int w = ws; w < w0; ++w) {
*(ptr++) = 0;
}
memcpy(ptr, din_r + k * size_channel_in, valid_w_byte);
ptr += valid_w;
for (int w = w1; w < we; ++w) {
*(ptr++) = 0;
}
}
}
}
int cnt = loop;
int8_t* inr0 = ptr_r[0];
int8_t* inr1 = ptr_r[1];
int8_t* inr2 = ptr_r[2];
int8_t* inr3 = ptr_r[3];
int8_t* inr4 = ptr_r[4];
int8_t* inr5 = ptr_r[5];
int8_t* inr6 = ptr_r[6];
int8_t* inr7 = ptr_r[7];
auto ptr_out = dout + j * size_out_row;
if (cnt > 0) {
if (pad_l) {
memset(dout, 0, pad_l * 8 * sizeof(int8_t));
dout += pad_l * 8;
}
if (valid_cnt) {
int cnt = valid_cnt;
#ifdef __aarch64__
asm volatile(
/* main loop */
"1:\n"
"ldr d0, [%[r0]], #8\n"
"ldr d1, [%[r1]], #8\n"
"ldr d2, [%[r2]], #8\n"
"ldr d3, [%[r3]], #8\n"
"ldr d4, [%[r4]], #8\n"
"ldr d5, [%[r5]], #8\n"
"ldr d6, [%[r6]], #8\n"
"ldr d7, [%[r7]], #8\n"
"trn1 v8.8b, v0.8b, v1.8b\n"
"trn2 v9.8b, v0.8b, v1.8b\n"
"trn1 v10.8b, v2.8b, v3.8b\n"
"trn2 v11.8b, v2.8b, v3.8b\n"
"trn1 v12.8b, v4.8b, v5.8b\n"
"trn2 v13.8b, v4.8b, v5.8b\n"
"trn1 v14.8b, v6.8b, v7.8b\n"
"trn2 v15.8b, v6.8b, v7.8b\n"
"trn1 v0.4h, v8.4h, v10.4h\n"
"trn2 v1.4h, v8.4h, v10.4h\n"
"trn1 v2.4h, v9.4h, v11.4h\n"
"trn2 v3.4h, v9.4h, v11.4h\n"
"trn1 v4.4h, v12.4h, v14.4h\n"
"trn2 v5.4h, v12.4h, v14.4h\n"
"trn1 v6.4h, v13.4h, v15.4h\n"
"trn2 v7.4h, v13.4h, v15.4h\n"
"trn1 v8.2s, v0.2s, v4.2s\n"
"trn1 v9.2s, v2.2s, v6.2s\n"
"trn1 v10.2s, v1.2s, v5.2s\n"
"trn1 v11.2s, v3.2s, v7.2s\n"
"stp d8, d9, [%[ptr_out]], #16\n"
"trn2 v12.2s, v0.2s, v4.2s\n"
"trn2 v13.2s, v2.2s, v6.2s\n"
"stp d10, d11, [%[ptr_out]], #16\n"
"trn2 v14.2s, v1.2s, v5.2s\n"
"trn2 v15.2s, v3.2s, v7.2s\n"
"subs %w[cnt], %w[cnt], #1\n"
"stp d12, d13, [%[ptr_out]], #16\n"
"stp d14, d15, [%[ptr_out]], #16\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
asm volatile(
/* main loop */
"1:\n"
"ldr d0, [%[r0]], #8\n"
"ldr d1, [%[r1]], #8\n"
"ldr d2, [%[r2]], #8\n"
"ldr d3, [%[r3]], #8\n"
"ldr d4, [%[r4]], #8\n"
"ldr d5, [%[r5]], #8\n"
"ldr d6, [%[r6]], #8\n"
"ldr d7, [%[r7]], #8\n"
"trn1 v8.8b, v0.8b, v1.8b\n"
"trn2 v9.8b, v0.8b, v1.8b\n"
"trn1 v10.8b, v2.8b, v3.8b\n"
"trn2 v11.8b, v2.8b, v3.8b\n"
"trn1 v12.8b, v4.8b, v5.8b\n"
"trn2 v13.8b, v4.8b, v5.8b\n"
"trn1 v14.8b, v6.8b, v7.8b\n"
"trn2 v15.8b, v6.8b, v7.8b\n"
"trn1 v0.4h, v8.4h, v10.4h\n"
"trn2 v1.4h, v8.4h, v10.4h\n"
"trn1 v2.4h, v9.4h, v11.4h\n"
"trn2 v3.4h, v9.4h, v11.4h\n"
"trn1 v4.4h, v12.4h, v14.4h\n"
"trn2 v5.4h, v12.4h, v14.4h\n"
"trn1 v6.4h, v13.4h, v15.4h\n"
"trn2 v7.4h, v13.4h, v15.4h\n"
"trn1 v8.2s, v0.2s, v4.2s\n"
"trn1 v9.2s, v2.2s, v6.2s\n"
"trn1 v10.2s, v1.2s, v5.2s\n"
"trn1 v11.2s, v3.2s, v7.2s\n"
"stp d8, d9, [%[ptr_out]], #16\n"
"trn2 v12.2s, v0.2s, v4.2s\n"
"trn2 v13.2s, v2.2s, v6.2s\n"
"stp d10, d11, [%[ptr_out]], #16\n"
"trn2 v14.2s, v1.2s, v5.2s\n"
"trn2 v15.2s, v3.2s, v7.2s\n"
"subs %w[cnt], %w[cnt], #1\n"
"stp d12, d13, [%[ptr_out]], #16\n"
"stp d14, d15, [%[ptr_out]], #16\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(ptr_c0),
[r1] "+r"(ptr_c1),
[r2] "+r"(ptr_c2),
[r3] "+r"(ptr_c3),
[r4] "+r"(ptr_c4),
[r5] "+r"(ptr_c5),
[r6] "+r"(ptr_c6),
[r7] "+r"(ptr_c7),
[ptr_out] "+r"(dout)
:
: "cc",
"memory",
"v0",
"v1",
"v2",
"v3",
"v4",
"v5",
"v6",
"v7",
"v8",
"v9",
"v10",
"v11",
"v12",
"v13",
"v14",
"v15");
#else
asm volatile(
/* main loop */
"1:\n"
"vld1.32 {d0}, [%[r0]]!\n"
"vld1.32 {d1}, [%[r1]]!\n"
"vld1.32 {d2}, [%[r2]]!\n"
"vld1.32 {d3}, [%[r3]]!\n"
"vld1.32 {d4}, [%[r4]]!\n"
"vld1.32 {d5}, [%[r5]]!\n"
"vld1.32 {d6}, [%[r6]]!\n"
"vld1.32 {d7}, [%[r7]]!\n"
"vtrn.8 d0, d1\n"
"vtrn.8 d2, d3\n"
"vtrn.8 d4, d5\n"
"vtrn.8 d6, d7\n"
"vtrn.16 d0, d2\n"
"vtrn.16 d1, d3\n"
"vtrn.16 d4, d6\n"
"vtrn.16 d5, d7\n"
"vtrn.32 d0, d4\n"
"vtrn.32 d2, d6\n"
"vtrn.32 d1, d5\n"
"vtrn.32 d3, d7\n"
"subs %[cnt], #1\n"
"vst1.32 {d0-d3}, [%[ptr_out]]!\n"
"vst1.32 {d4-d7}, [%[ptr_out]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(inr0),
[r1] "+r"(inr1),
[r2] "+r"(inr2),
[r3] "+r"(inr3),
[r4] "+r"(inr4),
[r5] "+r"(inr5),
[r6] "+r"(inr6),
[r7] "+r"(inr7),
[ptr_out] "+r"(ptr_out)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif // aarch64
}
for (int k = 0; k < remain; ++k) {
ptr_out[0] = *(inr0++);
ptr_out[1] = *(inr1++);
ptr_out[2] = *(inr2++);
ptr_out[3] = *(inr3++);
ptr_out[4] = *(inr4++);
ptr_out[5] = *(inr5++);
ptr_out[6] = *(inr6++);
ptr_out[7] = *(inr7++);
ptr_out += 8;
}
j++;
asm volatile(
/* main loop */
"1:\n"
"vld1.32 {d0}, [%[r0]]!\n"
"vld1.32 {d1}, [%[r1]]!\n"
"vld1.32 {d2}, [%[r2]]!\n"
"vld1.32 {d3}, [%[r3]]!\n"
"vld1.32 {d4}, [%[r4]]!\n"
"vld1.32 {d5}, [%[r5]]!\n"
"vld1.32 {d6}, [%[r6]]!\n"
"vld1.32 {d7}, [%[r7]]!\n"
"vtrn.8 d0, d1\n"
"vtrn.8 d2, d3\n"
"vtrn.8 d4, d5\n"
"vtrn.8 d6, d7\n"
"vtrn.16 d0, d2\n"
"vtrn.16 d1, d3\n"
"vtrn.16 d4, d6\n"
"vtrn.16 d5, d7\n"
"vtrn.32 d0, d4\n"
"vtrn.32 d2, d6\n"
"vtrn.32 d1, d5\n"
"vtrn.32 d3, d7\n"
"subs %[cnt], #1\n"
"vst1.32 {d0-d3}, [%[ptr_out]]!\n"
"vst1.32 {d4-d7}, [%[ptr_out]]!\n"
"bne 1b\n"
: [cnt] "+r"(cnt),
[r0] "+r"(ptr_c0),
[r1] "+r"(ptr_c1),
[r2] "+r"(ptr_c2),
[r3] "+r"(ptr_c3),
[r4] "+r"(ptr_c4),
[r5] "+r"(ptr_c5),
[r6] "+r"(ptr_c6),
[r7] "+r"(ptr_c7),
[ptr_out] "+r"(dout)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
#endif // __aarch64__
}
for (int i = 0; i < remain; ++i) {
dout[0] = *(ptr_c0++);
dout[1] = *(ptr_c1++);
dout[2] = *(ptr_c2++);
dout[3] = *(ptr_c3++);
dout[4] = *(ptr_c4++);
dout[5] = *(ptr_c5++);
dout[6] = *(ptr_c6++);
dout[7] = *(ptr_c7++);
dout += 8;
}
if (pad_r) {
memset(dout, 0, pad_r * 8 * sizeof(int8_t));
dout += pad_r * 8;
}
}
TargetFree(TARGET(kARM), ptr_c);
}
/*wirte result in outputs
......
......@@ -153,6 +153,24 @@ void conv_depthwise_5x5s2_fp32(const float* din,
bool flag_relu,
ARMContext* ctx);
template <typename Dtype>
void conv_depthwise_5x5s1_int8(Dtype* dout,
const int8_t* din,
const int8_t* weights,
const float* scale,
const float* bias,
bool flag_bias,
bool flag_relu,
int num,
int chin,
int hin,
int win,
int hout,
int wout,
int padw,
int padh,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -805,6 +805,88 @@ void conv_depthwise_3x3_int8_int8(const void* din,
}
}
void conv_depthwise_5x5_int8_fp32(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
void conv_depthwise_5x5_int8_int8(const void* din,
void* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const void* weights,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx,
const float* scale) {
int pad_h = param.paddings[0];
int pad_w = param.paddings[1];
int stride = param.strides[1];
bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1) {
conv_depthwise_5x5s1_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -97,7 +97,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......@@ -136,7 +136,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3) && (sw == 1 || sw == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1 && ph == 2);
bool flag_dw_5x5 = (kw == 5 && sw == 1);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
......
......@@ -31,7 +31,7 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// select dw conv kernel
if (kw == 3) {
VLOG(5) << "invoke 3x3 dw conv fp32";
/// trans weights
// trans weights
constexpr int cblock = 4;
auto oc = w_dims[0];
auto kh = w_dims[2];
......@@ -75,6 +75,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
......@@ -83,6 +84,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
......@@ -123,6 +134,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
/// select dw conv kernel
if (kw == 3) {
// trans weights
VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
......@@ -131,6 +143,16 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
flag_trans_weights_ = true;
} else if (kw == 5) {
// trans weights
VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
int cround = ROUNDUP(w_dims[0], 8);
weights_.Resize({cround / 8, 1, kh * kw, 8});
auto wptr = param.filter->data<int8_t>();
auto wptr_new = weights_.mutable_data<int8_t>();
lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 25);
flag_trans_weights_ = true;
} else {
LOG(FATAL) << "this type dw conv not impl";
}
......
......@@ -481,10 +481,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
}
#endif /// 3x3dw
#if 0 /// 5x5dw
#if 1 /// 5x5dw
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) {
for (auto& stride : {1}) {
for (auto& pad : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
......@@ -492,7 +492,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
for (auto& batch : {1, 2}) {
for (auto &h : {1, 3, 15, 19, 28, 32, 75}) {
for (auto& h : {1, 3, 15, 19, 28, 32, 75}) {
dims.push_back(DDim({batch, c, h, h}));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册