[Unverified] Commit 688166b5 authored by: Y yiicy, committed by: GitHub

[ARM] add int8 5x5s2 dw conv impl, test=develop (#2813)

* [ARM] add 5x5s2 depthwise conv armv8 impl, test=develop

* [ARM] add int8 5x5s2 dw conv armv7 impl, test=develop

* [ARM] add int8 5x5s2 dw conv impl, test=develop

* [ARM] close int8 conv ut, test=develop
Parent commit: 0b6210b7
...@@ -78,6 +78,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) ...@@ -78,6 +78,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv3x3s2_depthwise_int8.cc conv3x3s2_depthwise_int8.cc
conv5x5s1_depthwise_int8.cc conv5x5s1_depthwise_int8.cc
conv5x5s1_depthwise_fp32.cc conv5x5s1_depthwise_fp32.cc
conv5x5s2_depthwise_int8.cc
conv5x5s2_depthwise_fp32.cc conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc conv3x3_winograd_fp32_c4.cc
conv_winograd_3x3.cc conv_winograd_3x3.cc
......
[This diff has been collapsed — the new file conv5x5s2_depthwise_int8.cc is not shown here.]
...@@ -189,6 +189,24 @@ void conv_depthwise_5x5s1_int8(Dtype* dout, ...@@ -189,6 +189,24 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
int padh, int padh,
ARMContext* ctx); ARMContext* ctx);
// Forward declaration of the int8 depthwise convolution kernel for a 5x5
// filter at stride 2, added by this commit as a sibling of the existing
// conv_depthwise_5x5s1_int8 declared just above.
// Dtype is the output element type; the dispatch call sites in this diff
// instantiate it with float (fp32 output) and int8_t (int8 output) via
// reinterpret_cast on dout.
// NOTE(review): parameter meanings below are inferred from the names and
// from the matching 5x5s1 declaration — confirm against the implementation
// in conv5x5s2_depthwise_int8.cc (collapsed in this diff view).
template <typename Dtype>
void conv_depthwise_5x5s2_int8(Dtype* dout,        // output tensor
const int8_t* din,     // int8 input tensor
const int8_t* weights, // 5x5 depthwise filter weights (one filter per channel)
const float* scale,    // dequantization scale(s) — presumably per-channel
const float* bias,     // bias values, read when flag_bias is set
bool flag_bias,        // whether to add bias
bool flag_relu,        // whether to fuse a ReLU after the conv
int num,               // batch size
int chin,              // input channel count (== groups for depthwise)
int hin,               // input height
int win,               // input width
int hout,              // output height
int wout,              // output width
int padw,              // padding along width
int padh,              // padding along height
ARMContext* ctx);      // ARM runtime context (threads / workspace)
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
...@@ -880,6 +880,23 @@ void conv_depthwise_5x5_int8_fp32(const void* din, ...@@ -880,6 +880,23 @@ void conv_depthwise_5x5_int8_fp32(const void* din,
pad_w, pad_w,
pad_h, pad_h,
ctx); ctx);
} else if (stride == 2) {
conv_depthwise_5x5s2_int8(reinterpret_cast<float*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else { } else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
} }
...@@ -922,6 +939,23 @@ void conv_depthwise_5x5_int8_int8(const void* din, ...@@ -922,6 +939,23 @@ void conv_depthwise_5x5_int8_int8(const void* din,
pad_w, pad_w,
pad_h, pad_h,
ctx); ctx);
} else if (stride == 2) {
conv_depthwise_5x5s2_int8(reinterpret_cast<int8_t*>(dout),
reinterpret_cast<const int8_t*>(din),
reinterpret_cast<const int8_t*>(weights),
scale,
bias,
flag_bias,
flag_relu,
num,
ch_in,
h_in,
w_in,
h_out,
w_out,
pad_w,
pad_h,
ctx);
} else { } else {
LOG(FATAL) << "unsupport this type 5x5 dw conv int8"; LOG(FATAL) << "unsupport this type 5x5 dw conv int8";
} }
......
...@@ -107,7 +107,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -107,7 +107,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1); bool flag_dw_5x5 = pads_all_equal && (kw == 5 && (sw == 1 || sw == 2));
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
...@@ -152,7 +152,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -152,7 +152,7 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh); bool kps_equal = (pw == ph) && (sh == sw) && (kw == kh);
bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1); bool no_dilation = (dilations[0] == 1) && (dilations[1] == 1);
bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2)); bool flag_dw_3x3 = (kw == 3 && kh == 3 && (sw == 1 || sw == 2));
bool flag_dw_5x5 = pads_all_equal && (kw == 5 && kh == 5 && sw == 1); bool flag_dw_5x5 = pads_all_equal && (kw == 5 && (sw == 1 || sw == 2));
bool flag_dw = flag_dw_3x3 || flag_dw_5x5; bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
if (param.groups == ic && ic == oc && kps_equal && pads_equal && if (param.groups == ic && ic == oc && kps_equal && pads_equal &&
......
...@@ -457,7 +457,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -457,7 +457,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
const std::vector<int>& power_mode) {} const std::vector<int>& power_mode) {}
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
#if 1 /// 3x3dw #if 0 /// 3x3dw
TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
...@@ -494,7 +494,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) { ...@@ -494,7 +494,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
#if 1 /// 5x5dw #if 1 /// 5x5dw
TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& stride : {1}) { for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2, 3, 4}) { for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) { for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) { for (auto& flag_relu : {false, true}) {
...@@ -525,7 +525,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) { ...@@ -525,7 +525,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
} }
#endif /// 5x5dw #endif /// 5x5dw
#if 1 /// conv1x1s1 #if 0 /// conv1x1s1
TEST(TestConv1x1s1Int8, test_conv1x1s1) { TEST(TestConv1x1s1Int8, test_conv1x1s1) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 32}) { for (auto& cin : {1, 3, 8, 32}) {
...@@ -562,7 +562,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) { ...@@ -562,7 +562,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
} }
#endif /// conv1x1s1 #endif /// conv1x1s1
#if 1 /// conv3x3s1 #if 0 /// conv3x3s1
TEST(TestConv3x3s1Int8, test_conv_3x3s1) { TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 8, 33}) { for (auto& cin : {1, 3, 8, 33}) {
...@@ -602,7 +602,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -602,7 +602,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
} }
#endif /// conv3x3s1 #endif /// conv3x3s1
#if 1 /// conv3x3s2 #if 0 /// conv3x3s2
TEST(TestConv3x3s2Int8, test_conv_3x3s2) { TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 3, 31}) { for (auto& cin : {1, 3, 31}) {
...@@ -642,7 +642,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) { ...@@ -642,7 +642,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
} }
#endif /// conv3x3s2 #endif /// conv3x3s2
#if 1 /// random param conv #if 0 /// random param conv
TEST(TestConvRandInt8, test_conv_rand) { TEST(TestConvRandInt8, test_conv_rand) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
for (auto& cin : {1, 17}) { for (auto& cin : {1, 17}) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册