未验证 提交 12a17969 编写于 作者: H HappyAngel 提交者: GitHub

【arm】fix pooling no-equal padding problem (#2956)

上级 294375f9
...@@ -92,7 +92,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -92,7 +92,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const operators::ActivationParam act_param, const operators::ActivationParam act_param,
ARMContext* ctx) { ARMContext* ctx) {
if (pad == 0) { if (pad == 0) {
if (w_in > 7) { if (w_in > 8) {
conv_depthwise_3x3s2p0_bias(dout, conv_depthwise_3x3s2p0_bias(dout,
din, din,
weights, weights,
...@@ -476,7 +476,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -476,7 +476,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\ \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \ "st1 {v16.4s}, [%[outptr0]], #16 \n" \
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \ "fmul v12.4s, v17.4s, v22.4s \n" \
\ \
"ld1 {v20.4s}, [%[inptr3]] \n" \ "ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \ "ld1 {v21.4s}, [%[inptr4]] \n" \
...@@ -552,6 +552,7 @@ void conv_depthwise_3x3s2_fp32(const float* din, ...@@ -552,6 +552,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"ld1 {v20.4s}, [%[inptr3]] \n" \ "ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \ "ld1 {v21.4s}, [%[inptr4]] \n" \
\ \
"fadd v17.4s, v17.4s, v14.4s \n" \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \ "bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \ "ext v10.16b, v0.16b, v15.16b, #4 \n" \
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \ "fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
......
...@@ -113,9 +113,9 @@ namespace math { ...@@ -113,9 +113,9 @@ namespace math {
"fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \ "fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \ "fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \ "bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \ "bif v20.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \ "bif v21.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v19.16b, v8.16b, v7.16b \n" /* choose*/ "bif v22.16b, v8.16b, v7.16b \n" /* choose*/
#define STORE /* save result */ \ #define STORE /* save result */ \
"str q19, [%[outc0]], #16\n" \ "str q19, [%[outc0]], #16\n" \
"str q20, [%[outc1]], #16\n" \ "str q20, [%[outc1]], #16\n" \
......
...@@ -67,7 +67,6 @@ void pooling_basic(const float* din, ...@@ -67,7 +67,6 @@ void pooling_basic(const float* din,
} }
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
// Pooling_average_include_padding // Pooling_average_include_padding
// Pooling_average_exclude_padding
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out; float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in; const float* din_batch = din + n * chin * size_channel_in;
...@@ -906,7 +905,9 @@ void pooling1x1s2p0_max(const float* din, ...@@ -906,7 +905,9 @@ void pooling1x1s2p0_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win) { int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1021,7 +1022,9 @@ void pooling2x2s2_max(const float* din, ...@@ -1021,7 +1022,9 @@ void pooling2x2s2_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win) { int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1104,7 +1107,9 @@ void pooling2x2s2_avg(const float* din, ...@@ -1104,7 +1107,9 @@ void pooling2x2s2_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive) { bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1117,6 +1122,9 @@ void pooling2x2s2_avg(const float* din, ...@@ -1117,6 +1122,9 @@ void pooling2x2s2_avg(const float* din,
int w_unroll_size = wout / 4; int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4; int w_unroll_remian = wout - w_unroll_size * 4;
float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4 float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
memset(zero_ptr, 0, win * sizeof(float));
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
...@@ -1132,7 +1140,7 @@ void pooling2x2s2_avg(const float* din, ...@@ -1132,7 +1140,7 @@ void pooling2x2s2_avg(const float* din,
auto dr0 = r0; auto dr0 = r0;
auto dr1 = r1; auto dr1 = r1;
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = r0; dr1 = zero_ptr;
} }
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
...@@ -1178,6 +1186,7 @@ void pooling2x2s2_avg(const float* din, ...@@ -1178,6 +1186,7 @@ void pooling2x2s2_avg(const float* din,
} }
} }
} }
TargetFree(TARGET(kARM), zero_ptr);
} }
void pooling3x3s1p1_max(const float* din, void pooling3x3s1p1_max(const float* din,
...@@ -1188,7 +1197,9 @@ void pooling3x3s1p1_max(const float* din, ...@@ -1188,7 +1197,9 @@ void pooling3x3s1p1_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win) { int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1331,7 +1342,9 @@ void pooling3x3s1p1_avg(const float* din, ...@@ -1331,7 +1342,9 @@ void pooling3x3s1p1_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive) { bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1389,7 +1402,13 @@ void pooling3x3s1p1_avg(const float* din, ...@@ -1389,7 +1402,13 @@ void pooling3x3s1p1_avg(const float* din,
if (exclusive) { if (exclusive) {
coef_h = 1.f; coef_h = 1.f;
} else { } else {
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom == 1) {
coef_h = 0.5f; coef_h = 0.5f;
} else {
coef_h = 1.f;
}
} }
break; break;
case 1: case 1:
...@@ -1401,7 +1420,11 @@ void pooling3x3s1p1_avg(const float* din, ...@@ -1401,7 +1420,11 @@ void pooling3x3s1p1_avg(const float* din,
coef_h = 0.5f; coef_h = 0.5f;
} }
} else { } else {
if (pad_bottom >= 1) {
coef_h = 1.f / 3; coef_h = 1.f / 3;
} else {
coef_h = 0.5f;
}
} }
default: default:
break; break;
...@@ -1477,8 +1500,12 @@ void pooling3x3s1p1_avg(const float* din, ...@@ -1477,8 +1500,12 @@ void pooling3x3s1p1_avg(const float* din,
int st = wstart > 0 ? wstart : 0; int st = wstart > 0 ? wstart : 0;
if (wstart + K > win) { if (wstart + K > win) {
wend = win; wend = win;
if (!exclusive && wstart + K - win == 2) { if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2; coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
} }
} }
if (exclusive) { if (exclusive) {
...@@ -1509,7 +1536,9 @@ void pooling3x3s1p0_max(const float* din, ...@@ -1509,7 +1536,9 @@ void pooling3x3s1p0_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win) { int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1646,7 +1675,9 @@ void pooling3x3s1p0_avg(const float* din, ...@@ -1646,7 +1675,9 @@ void pooling3x3s1p0_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive) { bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1692,7 +1723,13 @@ void pooling3x3s1p0_avg(const float* din, ...@@ -1692,7 +1723,13 @@ void pooling3x3s1p0_avg(const float* din,
if (exclusive) { if (exclusive) {
coef_h = 1.f; coef_h = 1.f;
} else { } else {
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom = 1) {
coef_h = 0.5f; coef_h = 0.5f;
} else {
coef_h = 1.f;
}
} }
break; break;
case 1: case 1:
...@@ -1704,7 +1741,11 @@ void pooling3x3s1p0_avg(const float* din, ...@@ -1704,7 +1741,11 @@ void pooling3x3s1p0_avg(const float* din,
coef_h = 0.5f; coef_h = 0.5f;
} }
} else { } else {
if (pad_bottom >= 1) {
coef_h = 1.f / 3; coef_h = 1.f / 3;
} else {
coef_h = 0.5f;
}
} }
default: default:
break; break;
...@@ -1776,8 +1817,12 @@ void pooling3x3s1p0_avg(const float* din, ...@@ -1776,8 +1817,12 @@ void pooling3x3s1p0_avg(const float* din,
int st = wstart > 0 ? wstart : 0; int st = wstart > 0 ? wstart : 0;
if (wstart + K > win) { if (wstart + K > win) {
wend = win; wend = win;
if (!exclusive && wstart + K - win == 2) { if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2; coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
} }
} }
if (exclusive) { if (exclusive) {
...@@ -1811,7 +1856,9 @@ void pooling3x3s2p1_max(const float* din, ...@@ -1811,7 +1856,9 @@ void pooling3x3s2p1_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win) { int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1955,7 +2002,9 @@ void pooling3x3s2p1_avg(const float* din, ...@@ -1955,7 +2002,9 @@ void pooling3x3s2p1_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive) { bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -2015,7 +2064,13 @@ void pooling3x3s2p1_avg(const float* din, ...@@ -2015,7 +2064,13 @@ void pooling3x3s2p1_avg(const float* din,
if (exclusive) { if (exclusive) {
coef_h = 1.f; coef_h = 1.f;
} else { } else {
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom == 1) {
coef_h = 0.5f; coef_h = 0.5f;
} else {
coef_h = 1.f;
}
} }
break; break;
case 1: case 1:
...@@ -2026,9 +2081,13 @@ void pooling3x3s2p1_avg(const float* din, ...@@ -2026,9 +2081,13 @@ void pooling3x3s2p1_avg(const float* din,
} else { } else {
coef_h = 0.5f; coef_h = 0.5f;
} }
} else {
if (pad_bottom == 0) {
coef_h = 1.f / 2;
} else { } else {
coef_h = 1.f / 3; coef_h = 1.f / 3;
} }
}
default: default:
break; break;
} }
...@@ -2102,8 +2161,12 @@ void pooling3x3s2p1_avg(const float* din, ...@@ -2102,8 +2161,12 @@ void pooling3x3s2p1_avg(const float* din,
float coef = coef_h / 3.f; float coef = coef_h / 3.f;
if (wstart + K > win) { if (wstart + K > win) {
wend = win; wend = win;
if (!exclusive && wstart + K - win == 2) { if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2; coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
} }
} }
int st = wstart > 0 ? wstart : 0; int st = wstart > 0 ? wstart : 0;
...@@ -2135,7 +2198,9 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2135,7 +2198,9 @@ void pooling3x3s2p0_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win) { int win,
int pad_bottom,
int pad_right) {
const int K = 3; const int K = 3;
const int P = 0; const int P = 0;
const int S = 2; const int S = 2;
...@@ -2261,7 +2326,9 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2261,7 +2326,9 @@ void pooling3x3s2p0_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive) { bool exclusive,
int pad_bottom,
int pad_right) {
const int K = 3; const int K = 3;
const int P = 0; const int P = 0;
const int S = 2; const int S = 2;
...@@ -2303,11 +2370,33 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2303,11 +2370,33 @@ void pooling3x3s2p0_avg(const float* din,
case 2: case 2:
dr1 = zero_ptr; dr1 = zero_ptr;
dr2 = zero_ptr; dr2 = zero_ptr;
if (exclusive) {
coef_h = 1.f; coef_h = 1.f;
} else {
if (pad_bottom >= 2) {
coef_h = 1.f / 3;
} else if (pad_bottom == 1) {
coef_h = 0.5f;
} else {
coef_h = 1.0f;
}
}
break; break;
case 1: case 1:
dr2 = zero_ptr; dr2 = zero_ptr;
if (exclusive) {
if (fabsf(coef_h - 0.5f) < 1e-6f) {
coef_h = 1.f;
} else {
coef_h = 0.5f;
}
} else {
if (pad_bottom >= 1) {
coef_h = 1.0f / 3;
} else {
coef_h = 0.5f; coef_h = 0.5f;
}
}
break; break;
default: default:
break; break;
...@@ -2366,22 +2455,34 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2366,22 +2455,34 @@ void pooling3x3s2p0_avg(const float* din,
dr2 -= 8; dr2 -= 8;
} }
// deal with right pad // deal with right pad
int rem = win - (w_unroll_size * 4) * S; int wstart = w_unroll_size * 4 * S - P;
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) { for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem); int wend = wstart + K; // std::min(wstart + K, win);
float coef = coef_h / (wend - wstart); float coef = coef_h / 3.f;
if (wstart + K > win) {
wend = win;
if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
}
}
int st = wstart > 0 ? wstart : 0;
if (exclusive) {
coef = coef_h / (wend - st);
}
float tmp = 0.f; float tmp = 0.f;
for (int i = wstart; i < wend; i++) { for (int i = 0; i < wend - st; i++) {
tmp += dr0[i]; tmp += dr0[i] + dr1[i] + dr2[i];
tmp += dr1[i];
tmp += dr2[i];
} }
tmp *= coef; *(dr_out++) = tmp * coef;
*(dr_out++) = tmp; dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
dr2 += S - (st - wstart);
wstart += S; wstart += S;
} }
r0 = r2; r0 = r2;
r1 = r0 + win; r1 = r0 + win;
r2 = r1 + win; r2 = r1 + win;
......
...@@ -72,7 +72,9 @@ void pooling1x1s2p0_max(const float* din, ...@@ -72,7 +72,9 @@ void pooling1x1s2p0_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win); int win,
int pad_bottom,
int pad_right);
void pooling2x2s2_max(const float* din, void pooling2x2s2_max(const float* din,
float* dout, float* dout,
...@@ -82,7 +84,9 @@ void pooling2x2s2_max(const float* din, ...@@ -82,7 +84,9 @@ void pooling2x2s2_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win); int win,
int pad_bottom,
int pad_right);
void pooling2x2s2_avg(const float* din, void pooling2x2s2_avg(const float* din,
float* dout, float* dout,
...@@ -93,7 +97,9 @@ void pooling2x2s2_avg(const float* din, ...@@ -93,7 +97,9 @@ void pooling2x2s2_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive); bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_max(const float* din, void pooling3x3s1p1_max(const float* din,
float* dout, float* dout,
...@@ -103,7 +109,9 @@ void pooling3x3s1p1_max(const float* din, ...@@ -103,7 +109,9 @@ void pooling3x3s1p1_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win); int win,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_avg(const float* din, void pooling3x3s1p1_avg(const float* din,
float* dout, float* dout,
...@@ -114,7 +122,9 @@ void pooling3x3s1p1_avg(const float* din, ...@@ -114,7 +122,9 @@ void pooling3x3s1p1_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive); bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s2p1_max(const float* din, void pooling3x3s2p1_max(const float* din,
float* dout, float* dout,
...@@ -124,7 +134,9 @@ void pooling3x3s2p1_max(const float* din, ...@@ -124,7 +134,9 @@ void pooling3x3s2p1_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win); int win,
int pad_bottom,
int pad_right);
void pooling3x3s1p0_max(const float* din, void pooling3x3s1p0_max(const float* din,
float* dout, float* dout,
...@@ -134,7 +146,9 @@ void pooling3x3s1p0_max(const float* din, ...@@ -134,7 +146,9 @@ void pooling3x3s1p0_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win); int win,
int pad_bottom,
int pad_right);
void pooling3x3s1p0_avg(const float* din, void pooling3x3s1p0_avg(const float* din,
float* dout, float* dout,
...@@ -145,7 +159,9 @@ void pooling3x3s1p0_avg(const float* din, ...@@ -145,7 +159,9 @@ void pooling3x3s1p0_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive); bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s2p1_avg(const float* din, void pooling3x3s2p1_avg(const float* din,
float* dout, float* dout,
...@@ -156,7 +172,9 @@ void pooling3x3s2p1_avg(const float* din, ...@@ -156,7 +172,9 @@ void pooling3x3s2p1_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive); bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s2p0_max(const float* din, void pooling3x3s2p0_max(const float* din,
float* dout, float* dout,
...@@ -166,7 +184,9 @@ void pooling3x3s2p0_max(const float* din, ...@@ -166,7 +184,9 @@ void pooling3x3s2p0_max(const float* din,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win); int win,
int pad_bottom,
int pad_right);
void pooling3x3s2p0_avg(const float* din, void pooling3x3s2p0_avg(const float* din,
float* dout, float* dout,
...@@ -177,7 +197,9 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -177,7 +197,9 @@ void pooling3x3s2p0_avg(const float* din,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive); bool exclusive,
int pad_bottom,
int pad_right);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
...@@ -24,16 +24,27 @@ namespace mir { ...@@ -24,16 +24,27 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"}; std::vector<std::string> act_types{"relu"};
bool has_int8 = false;
bool has_arm_float = false;
bool has_cuda = false;
for (auto& place : graph->valid_places()) { for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kCUDA)) { if (place.precision == PRECISION(kInt8)) {
act_types.push_back("leaky_relu"); has_int8 = true;
break;
} }
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) { if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm_float = true;
}
if (place.target == TARGET(kCUDA)) {
has_cuda = true;
}
}
if (!has_int8 && has_arm_float) {
act_types.push_back("relu6"); act_types.push_back("relu6");
act_types.push_back("leaky_relu"); act_types.push_back("leaky_relu");
break;
} }
if (!has_int8 && has_cuda) {
act_types.push_back("leaky_relu");
} }
for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) { for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
for (auto act_type : act_types) { for (auto act_type : act_types) {
......
...@@ -53,14 +53,6 @@ void ConvActivationFuser::BuildPattern() { ...@@ -53,14 +53,6 @@ void ConvActivationFuser::BuildPattern() {
void ConvActivationFuser::InsertNewNode(SSAGraph* graph, void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) { const key2nodes_t& matched) {
// not fuse quantized conv2d + relu6 for now
auto conv2d_op_desc = matched.at("conv2d")->stmt()->op_info();
bool is_conv2d_quantized = conv2d_op_desc->HasAttr("enable_int8") &&
conv2d_op_desc->GetAttr<bool>("enable_int8");
if (act_type_ == "relu6" && is_conv2d_quantized) {
return;
}
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_); auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op(); auto conv_old = matched.at("conv2d")->stmt()->op();
......
...@@ -47,13 +47,16 @@ void PoolCompute::Run() { ...@@ -47,13 +47,16 @@ void PoolCompute::Run() {
bool use_quantizer = param.use_quantizer; bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format; std::string& data_format = param.data_format;
bool pads_equal = (paddings[0] == paddings[1]) && bool pads_less =
(paddings[2] == paddings[3]) && (paddings[0] == paddings[2]) && (paddings[1] < 2) && (paddings[3] < 2);
(paddings[0] == paddings[2]);
bool pads_equal = (paddings[0] == paddings[2]) &&
(paddings[0] == paddings[1]) &&
(paddings[2] == paddings[3]);
bool kps_equal = bool kps_equal =
(ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_equal; (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less;
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
...@@ -96,7 +99,9 @@ void PoolCompute::Run() { ...@@ -96,7 +99,9 @@ void PoolCompute::Run() {
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3]); in_dims[3],
paddings[1],
paddings[3]);
return; return;
} }
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
...@@ -110,7 +115,9 @@ void PoolCompute::Run() { ...@@ -110,7 +115,9 @@ void PoolCompute::Run() {
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3]); in_dims[3],
paddings[1],
paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, lite::arm::math::pooling2x2s2_avg(din,
...@@ -122,7 +129,9 @@ void PoolCompute::Run() { ...@@ -122,7 +129,9 @@ void PoolCompute::Run() {
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive); exclusive,
paddings[1],
paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
...@@ -136,7 +145,9 @@ void PoolCompute::Run() { ...@@ -136,7 +145,9 @@ void PoolCompute::Run() {
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3]); in_dims[3],
paddings[1],
paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p1_avg(din, lite::arm::math::pooling3x3s1p1_avg(din,
...@@ -148,7 +159,9 @@ void PoolCompute::Run() { ...@@ -148,7 +159,9 @@ void PoolCompute::Run() {
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive); exclusive,
paddings[1],
paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
...@@ -162,7 +175,9 @@ void PoolCompute::Run() { ...@@ -162,7 +175,9 @@ void PoolCompute::Run() {
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3]); in_dims[3],
paddings[1],
paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p0_avg(din, lite::arm::math::pooling3x3s1p0_avg(din,
...@@ -174,7 +189,9 @@ void PoolCompute::Run() { ...@@ -174,7 +189,9 @@ void PoolCompute::Run() {
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive); exclusive,
paddings[1],
paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
...@@ -188,7 +205,9 @@ void PoolCompute::Run() { ...@@ -188,7 +205,9 @@ void PoolCompute::Run() {
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3]); in_dims[3],
paddings[1],
paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p0_avg(din, lite::arm::math::pooling3x3s2p0_avg(din,
...@@ -200,7 +219,9 @@ void PoolCompute::Run() { ...@@ -200,7 +219,9 @@ void PoolCompute::Run() {
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive); exclusive,
paddings[1],
paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
...@@ -214,7 +235,9 @@ void PoolCompute::Run() { ...@@ -214,7 +235,9 @@ void PoolCompute::Run() {
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3]); in_dims[3],
paddings[1],
paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p1_avg(din, lite::arm::math::pooling3x3s2p1_avg(din,
...@@ -226,11 +249,14 @@ void PoolCompute::Run() { ...@@ -226,11 +249,14 @@ void PoolCompute::Run() {
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive); exclusive,
paddings[1],
paddings[3]);
return; return;
} }
} }
} }
lite::arm::math::pooling_basic(din, lite::arm::math::pooling_basic(din,
dout, dout,
out_dims[0], out_dims[0],
......
...@@ -232,7 +232,7 @@ TEST(pool_arm, compute) { ...@@ -232,7 +232,7 @@ TEST(pool_arm, compute) {
lite::Tensor x; lite::Tensor x;
lite::Tensor output; lite::Tensor output;
lite::Tensor output_ref; lite::Tensor output_ref;
#if 0
// speedup for ci // speedup for ci
for (auto pooling_type : {"max", "avg"}) { for (auto pooling_type : {"max", "avg"}) {
for (auto ceil_mode : {true, false}) { for (auto ceil_mode : {true, false}) {
...@@ -337,6 +337,7 @@ TEST(pool_arm, compute) { ...@@ -337,6 +337,7 @@ TEST(pool_arm, compute) {
} }
} }
} }
#endif
} }
TEST(pool_arm, retrive_op) { TEST(pool_arm, retrive_op) {
......
...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode, ...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests"); DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result"); DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size"); DEFINE_int32(batch, 1, "batch size");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册