Unverified commit 7eb4a391, authored by HappyAngel, committed by GitHub

improve pooling speed in gaze model. (#3881)

* improve pooling speed in gaze. test=develop

* fix format test=develop
Parent 6d787479
......@@ -21,6 +21,17 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
int AdaptStartIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
floor(static_cast<double>(ph * input_size) / output_size));
}
int AdaptEndIndex(int ph, int input_size, int output_size) {
return static_cast<int>(
ceil(static_cast<double>((ph + 1) * input_size) / output_size));
}
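For reference, a minimal sketch of how these helpers split an input extent (illustrative values only, not part of the change): output bin ph covers [floor(ph * in / out), ceil((ph + 1) * in / out)), so neighbouring bins may overlap by one element.
// Example: an extent of 5 split into 3 adaptive bins -> [0, 2), [1, 4), [3, 5).
for (int ph = 0; ph < 3; ++ph) {
  int start = AdaptStartIndex(ph, 5, 3);  // 0, 1, 3
  int end = AdaptEndIndex(ph, 5, 3);      // 2, 4, 5
  // pool over input[start, end)
}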
void pooling_basic(const float* din,
float* dout,
int num,
......@@ -88,15 +99,27 @@ void pooling_basic(const float* din,
#pragma omp parallel for
for (int ind_c = 0; ind_c < chin; ++ind_c) {
for (int ind_h = 0; ind_h < hout; ++ind_h) {
int sh, eh;
if (adaptive) {
sh = AdaptStartIndex(ind_h, hin, hout);
eh = AdaptEndIndex(ind_h, hin, hout);
} else {
sh = ind_h * stride_h;
eh = sh + kernel_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > hin ? hin : eh - pad_h;
}
for (int ind_w = 0; ind_w < wout; ++ind_w) {
int sw, ew;
if (adaptive) {
sw = AdaptStartIndex(ind_w, win, wout);
ew = AdaptEndIndex(ind_w, win, wout);
} else {
sw = ind_w * stride_w;
ew = sw + kernel_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > win ? win : ew - pad_w;
}
float result = static_cast<float>(0);
int dst_ind = (ind_n * chout + ind_c) * size_channel_out +
ind_h * wout + ind_w;
......@@ -183,6 +206,20 @@ void pooling_basic(const float* din,
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/
#define P2x2S2P1_MAX \
"ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \
"ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \
"sub %[dr0], %[dr0], #4\n" /* sub */ \
"sub %[dr1], %[dr1], #4\n" /* sub */ \
"fmax v4.4s, v0.4s, v6.4s\n" /* max */ \
"fmax v5.4s, v2.4s, v8.4s\n" /* max */ \
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \
"fmax v6.4s, v4.4s, v5.4s\n" /* max reduce */ \
"subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \
"st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"ble 2f\n" /* bne s3_max_loop_mid */
#define P2x2S2P0_MAX \
"1: \n" \
"fmax v4.4s, v0.4s, v1.4s\n" /* max */ \
......@@ -194,6 +231,21 @@ void pooling_basic(const float* din,
"st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"bne 1b\n" /* bne s3_max_loop_mid */
#define P2x2S2P1_AVG \
"ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \
"ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \
"sub %[dr0], %[dr0], #4\n" /* sub */ \
"sub %[dr1], %[dr1], #4\n" /* sub */ \
"fadd v4.4s, v0.4s, v6.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
"fadd v5.4s, v2.4s, v8.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \
"fadd v6.4s, v4.4s, v5.4s\n" /* add reduce */ \
"subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \
"fmul v4.4s, v6.4s, %[vcoef_left].4s\n" /* mul coef */ \
"st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"ble 2f\n" /* bne s3_max_loop_mid */
#define P2x2S2P0_AVG \
"1: \n" /* load bias to q2, q3*/ \
"fadd v4.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
......@@ -205,6 +257,7 @@ void pooling_basic(const float* din,
"fmul v4.4s, v6.4s, %[vcoef].4s\n" /* mul coef */ \
"st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"bne 1b\n" /* bne s3_max_loop_mid */
#define P3x3S1_INIT \
"ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/ \
"ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/ \
......@@ -495,16 +548,45 @@ void pooling_basic(const float* din,
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n"
#define P2x2S2P1_MAX \
"vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \
"vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \
"sub %[dr0], #4 @sub \n" \
"sub %[dr1], #4 @sub \n" \
"vmax.f32 q8, q0, q4 @ max \n" \
"vmax.f32 q9, q2, q5 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vmax.f32 q5, q9, q8 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vst1.f32 {d10-d11}, [%[dr_out]]! @ store 4 out \n" \
"ble 2f @ bne \n"
#define P2x2S2P0_MAX \
"1: @ main loop\n" \
"vmax.f32 q4, q0, q1 @ max \n" \
"vmax.f32 q5, q2, q3 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vmax.f32 q8, q4, q5 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vst1.f32 {d16-d17}, [%[dr_out]]! @ store 4 out \n" \
"bne 1b @ bne \n"
#define P2x2S2P1_AVG \
"vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \
"vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \
"sub %[dr0], #4 @sub \n" \
"sub %[dr1], #4 @sub \n" \
"vadd.f32 q9, q0, q4 @ max \n" \
"vadd.f32 q8, q2, q5 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vadd.f32 q5, q9, q8 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vmul.f32 q4, q5, %q[vcoef_left] @ mul coef \n" \
"vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \
"ble 2f @ bne\n"
#define P2x2S2P0_AVG \
"1: @ main loop\n" \
......@@ -512,9 +594,9 @@ void pooling_basic(const float* din,
"vadd.f32 q5, q2, q3 @ add 0, 2, 4, 6 \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load d0-d3 \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load d4-d7 \n" \
"vadd.f32 q8, q4, q5 @ add reduce \n" \
"subs %[cnt_num], #1 @ subs \n" \
"vmul.f32 q4, q8, %q[vcoef] @ mul coef \n" \
"vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \
"bne 1b @ bne \n"
......@@ -1014,7 +1096,7 @@ void pooling1x1s2p0_max(const float* din,
TargetFree(TARGET(kARM), write_ptr);
}
void pooling2x2s2p0_max(const float* din,
float* dout,
int num,
int chout,
......@@ -1072,7 +1154,7 @@ void pooling2x2s2_max(const float* din,
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8");
#endif
dr0 -= 8;
dr1 -= 8;
......@@ -1098,7 +1180,7 @@ void pooling2x2s2_max(const float* din,
}
}
void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
......@@ -1135,12 +1217,14 @@ void pooling2x2s2_avg(const float* din,
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
vcoef = vdupq_n_f32(0.25f);
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
if (h * S + K - P > hin) {
dr1 = zero_ptr;
vcoef = vdupq_n_f32(0.5f);
}
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
......@@ -1161,7 +1245,7 @@ void pooling2x2s2_avg(const float* din,
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8");
#endif
dr0 -= 8;
dr1 -= 8;
......@@ -1171,8 +1255,14 @@ void pooling2x2s2_avg(const float* din,
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float coef = 0.25f;
float tmp = 0.f;
if (wend - wstart == 1 && pad_right == 0) {
coef *= 2;
}
if (h * S + K - P > hin && pad_bottom == 0) {
coef *= 2;
}
for (int i = wstart; i < wend; i++) {
tmp += dr0[i] + dr1[i];
}
......@@ -1189,6 +1279,235 @@ void pooling2x2s2_avg(const float* din,
TargetFree(TARGET(kARM), zero_ptr);
}
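A minimal sketch of the tail-coefficient rule used in the remainder loop above (the helper name is illustrative, not part of the kernel): a full 2x2 window divides by 4, and the divisor halves when only one valid column (odd width, pad_right == 0) or one valid row (pad_bottom == 0) remains, which is what the two coef *= 2 adjustments implement.
// Illustrative only: exclusive-average coefficient for a 2x2/s2 tail window
// covering `valid_cols` columns and `valid_rows` rows of real input.
static inline float tail_coef_2x2(int valid_cols, int valid_rows) {
  return 1.f / static_cast<float>(valid_cols * valid_rows);  // 0.25f, 0.5f or 1.f
}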
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
const int K = 2;
const int P = 1;
const int S = 2;
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
float32x4_t vzero = vdupq_n_f32(std::numeric_limits<float>::lowest());
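// note: despite its name, vzero holds std::numeric_limits<float>::lowest(),
// so a padded lane can never win the max below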
if (w_unroll_remian == 0) {
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
if (h == 0) {
dr0 = r0;
dr1 = r0;
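// (illustration: with pad = 1 the first output row covers a padding row plus
// input row 0; for max pooling the padding can never win, so both row
// pointers are simply aimed at row 0)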
r0 = r1;
r1 = r0 + win;
} else {
r0 = r1 + win;
r1 = r0 + win;
}
if (h * S + K - P > hin) {
dr1 = dr0;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0, wout * sizeof(float));
continue;
}
}
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(
P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vzero] "w"(vzero)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8");
#else
asm volatile(
P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vzero] "w"(vzero)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9");
#endif
dr0 -= 8;
dr1 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = wend == st ? 0.f : dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
}
*(dr_out++) = tmp;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
wstart += S;
}
data_out_channel += wout;
}
}
}
}
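To make the P2x2S2P1_MAX block used above easier to follow, here is a rough NEON-intrinsics sketch of one left-padded column block (the function name and signature are illustrative, not part of the kernel): ld2/vld2 splits each row into even and odd lanes, ext shifts the odd lanes right by one and injects the pad value, and two max steps reduce each 2x2 window.
#include <arm_neon.h>
// Illustrative only: first (left-padded) 2x2/s2 max block over two rows,
// producing output columns 0..3. vpad should hold -FLT_MAX, as vzero does above.
static inline float32x4_t pool2x2s2p1_max_block(const float* r0,
                                                const float* r1,
                                                float32x4_t vpad) {
  float32x4x2_t a = vld2q_f32(r0);  // a.val[0] = {x0, x2, x4, x6}, a.val[1] = {x1, x3, x5, x7}
  float32x4x2_t b = vld2q_f32(r1);
  float32x4_t a_sh = vextq_f32(vpad, a.val[1], 3);  // {pad, x1, x3, x5}
  float32x4_t b_sh = vextq_f32(vpad, b.val[1], 3);
  float32x4_t m0 = vmaxq_f32(a.val[0], a_sh);       // per-window max, row 0
  float32x4_t m1 = vmaxq_f32(b.val[0], b_sh);       // per-window max, row 1
  return vmaxq_f32(m0, m1);
}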
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
const int K = 2;
const int P = 1;
const int S = 2;
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
float32x4_t vzero = vdupq_n_f32(0.f);
memset(zero_ptr, 0, win * sizeof(float));
if (w_unroll_remian == 0) {
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
float coef_h = 0.5f;
if (h == 0) {
dr0 = zero_ptr;
dr1 = r0;
r0 = r1;
r1 = r0 + win;
if (exclusive) {
coef_h = 1.f;
}
} else {
r0 = r1 + win;
r1 = r0 + win;
}
if (h * S + K - P > hin) {
dr1 = zero_ptr;
if (exclusive) {
coef_h = 1.f;
}
if (h * S + K - P > hin + 1) {
memset(dr_out, 0, wout * sizeof(float));
continue;
}
}
float coef_left_most = exclusive ? coef_h : coef_h / 2;
float32x4_t vcoef = vdupq_n_f32(coef_h / 2);
float coef_left[4] = {
coef_left_most, coef_h / 2, coef_h / 2, coef_h / 2};
float32x4_t vcoef_left = vld1q_f32(coef_left);
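// Illustration (not part of this change): with pad = 1 the first output column
// averages only one valid input column per row, so exclusive pooling doubles
// the left-most coefficient. E.g. with both rows valid: coef_h = 0.5f,
// vcoef_left = {0.5f, 0.25f, 0.25f, 0.25f}, giving
//   out[0] = 0.5f  * (r0[0] + r1[0])
//   out[1] = 0.25f * (r0[1] + r0[2] + r1[1] + r1[2])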
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(
P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef),
[vzero] "w"(vzero),
[vcoef_left] "w"(vcoef_left)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8");
#else
asm volatile(
P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef),
[vzero] "w"(vzero),
[vcoef_left] "w"(vcoef_left)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9");
#endif
dr0 -= 8;
dr1 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = 0.f;
float coef = coef_h / 2;
if (exclusive && wend - st == 1) {
coef = coef_h;
}
for (int i = 0; i < wend - st; i++) {
tmp += dr0[i] + dr1[i];
}
*(dr_out++) = tmp * coef;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
wstart += S;
}
data_out_channel += wout;
}
}
}
TargetFree(TARGET(kARM), zero_ptr);
}
void pooling3x3s1p1_max(const float* din,
float* dout,
int num,
......@@ -2217,6 +2536,9 @@ void pooling3x3s2p0_max(const float* din,
w_unroll_remian = wout - w_unroll_size * 4;
}
int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win; // if need right pad
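// Illustration (not part of this change): with K = 3, S = 2, P = 0 the last
// window ends at (wout - 1) * 2 + 3 = 2 * wout + 1, so right > 0 exactly when
// that window runs past win and needs right padding, e.g. win = 6, wout = 3
// (ceil mode): right = 1, and the final window covers columns 4, 5 plus one
// padded column.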
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
......@@ -2243,6 +2565,7 @@ void pooling3x3s2p0_max(const float* din,
}
}
int cnt_num = w_unroll_size;
int cnt_remain = remain;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
......@@ -2266,12 +2589,53 @@ void pooling3x3s2p0_max(const float* din,
"v9",
"v10",
"v11");
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
wstart += S;
}
#else
asm volatile(
P3x3S2P0_INIT P3x3S2P0_MAX
"cmp %[remain], #0 @cmp cnt_num\n"
"sub %[dr0], #32 @sub - 8\n"
"sub %[dr1], #32 @sub - 8\n"
"sub %[dr2], #32 @sub - 8\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load \n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load \n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load \n"
"vmov.f32 s3,s2 @mov \n"
"vmov.f32 s7,s6 @mov \n"
"vmov.f32 s11,s10 @mov \n"
"vmax.f32 q0, q0, q1 @max \n"
"sub %[dr0], #8 @sub \n"
"sub %[dr1], #8 @sub \n"
"sub %[dr2], #8 @sub \n"
"vmax.f32 q0, q0, q2 @max \n"
"vpmax.f32 d0, d0, d1 @pmax \n"
"vpmax.f32 d0, d0, d0 @pmax \n"
"subs %[remain], #1 @subs \n"
"vst1.f32 d0[0], [%[dr_out]]! @vst \n"
"bne 2b @bne \n"
"4: @exit\n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr2] "+r"(dr2),
[dr_out] "+r"(dr_out),
[remain] "+r"(cnt_remain),
[cnt_num] "+r"(cnt_num)
:
: "cc",
......@@ -2288,24 +2652,17 @@ void pooling3x3s2p0_max(const float* din,
"q9",
"q10",
"q11");
if (right) {
int wstart = (w_unroll_size * 4 + remain) * S;
int wend = std::min(wstart + K, win);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
}
#endif
}
r0 = r2;
......@@ -2344,7 +2701,9 @@ void pooling3x3s2p0_avg(const float* din,
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
// do overflow process
w_unroll_size -= 1;
w_unroll_remian += 4;
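// (the back-off above appears to guard against over-read: the unrolled 3x3/s2
// body consumes more input columns than the four outputs it produces, so the
// last vector block could read past the end of the row; handing those four
// outputs to the scalar remainder loop keeps every NEON load inside the row)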
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
memset(zero_ptr, 0, win * sizeof(float));
......
......@@ -76,7 +76,7 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom,
int pad_right);
void pooling2x2s2p0_max(const float* din,
float* dout,
int num,
int chout,
......@@ -88,7 +88,32 @@ void pooling2x2s2_max(const float* din,
int pad_bottom,
int pad_right);
void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
......
......@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling;
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0;
......@@ -107,7 +108,7 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p0_max(din,
dout,
out_dims[0],
out_dims[1],
......@@ -120,7 +121,7 @@ void PoolCompute::Run() {
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p0_avg(din,
dout,
out_dims[0],
out_dims[1],
......@@ -134,8 +135,38 @@ void PoolCompute::Run() {
paddings[3]);
return;
}
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din,
dout,
......@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din,
dout,
......@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din,
dout,
......@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din,
dout,
......@@ -276,7 +307,6 @@ void PoolCompute::Run() {
use_quantizer,
pooling_type);
}
} // namespace arm
} // namespace kernels
} // namespace lite
......