diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index 0955b09d92f64066000b03c4487f359880f1c2a5..3e6cbff0660be8f2542d059a39115bed52122ff1 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -21,6 +21,17 @@ namespace paddle { namespace lite { namespace arm { namespace math { + +int AdaptStartIndex(int ph, int input_size, int output_size) { + return static_cast( + floor(static_cast(ph * input_size) / output_size)); +} + +int AdaptEndIndex(int ph, int input_size, int output_size) { + return static_cast( + ceil(static_cast((ph + 1) * input_size) / output_size)); +} + void pooling_basic(const float* din, float* dout, int num, @@ -88,15 +99,27 @@ void pooling_basic(const float* din, #pragma omp parallel for for (int ind_c = 0; ind_c < chin; ++ind_c) { for (int ind_h = 0; ind_h < hout; ++ind_h) { - int sh = ind_h * stride_h; - int eh = sh + kernel_h; - sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; - eh = (eh - pad_h) > hin ? hin : eh - pad_h; + int sh, eh; + if (adaptive) { + sh = AdaptStartIndex(ind_h, hin, hout); + eh = AdaptEndIndex(ind_h, hin, hout); + } else { + sh = ind_h * stride_h; + eh = sh + kernel_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > hin ? hin : eh - pad_h; + } for (int ind_w = 0; ind_w < wout; ++ind_w) { - int sw = ind_w * stride_w; - int ew = sw + kernel_w; - sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; - ew = (ew - pad_w) > win ? win : ew - pad_w; + int sw, ew; + if (adaptive) { + sw = AdaptStartIndex(ind_w, win, wout); + ew = AdaptEndIndex(ind_w, win, wout); + } else { + sw = ind_w * stride_w; + ew = sw + kernel_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > win ? 
win : ew - pad_w; + } float result = static_cast(0); int dst_ind = (ind_n * chout + ind_c) * size_channel_out + ind_h * wout + ind_w; @@ -183,6 +206,20 @@ void pooling_basic(const float* din, "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ +#define P2x2S2P1_MAX \ + "ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \ + "ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \ + "sub %[dr0], %[dr0], #4\n" /* sub */ \ + "sub %[dr1], %[dr1], #4\n" /* sub */ \ + "fmax v4.4s, v0.4s, v6.4s\n" /* max */ \ + "fmax v5.4s, v2.4s, v8.4s\n" /* max */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "fmax v6.4s, v4.4s, v5.4s\n" /* max reduce */ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* bne s3_max_loop_mid */ + #define P2x2S2P0_MAX \ "1: \n" \ "fmax v4.4s, v0.4s, v1.4s\n" /* max */ \ @@ -194,6 +231,21 @@ void pooling_basic(const float* din, "st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ "bne 1b\n" /* bne s3_max_loop_mid */ +#define P2x2S2P1_AVG \ + "ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \ + "ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \ + "sub %[dr0], %[dr0], #4\n" /* sub */ \ + "sub %[dr1], %[dr1], #4\n" /* sub */ \ + "fadd v4.4s, v0.4s, v6.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "fadd v5.4s, v2.4s, v8.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "fadd v6.4s, v4.4s, v5.4s\n" /* add reduce */ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "fmul v4.4s, v6.4s, %[vcoef_left].4s\n" /* mul coef */ \ + "st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* bne 
s3_max_loop_mid */ + #define P2x2S2P0_AVG \ "1: \n" /* load bias to q2, q3*/ \ "fadd v4.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ @@ -205,6 +257,7 @@ void pooling_basic(const float* din, "fmul v4.4s, v6.4s, %[vcoef].4s\n" /* mul coef */ \ "st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ "bne 1b\n" /* bne s3_max_loop_mid */ + #define P3x3S1_INIT \ "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/ \ "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/ \ @@ -495,16 +548,45 @@ void pooling_basic(const float* din, "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" +#define P2x2S2P1_MAX \ + "vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \ + "vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \ + "sub %[dr0], #4 @sub \n" \ + "sub %[dr1], #4 @sub \n" \ + "vmax.f32 q8, q0, q4 @ max \n" \ + "vmax.f32 q9, q2, q5 @ max \n" \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \ + "vmax.f32 q5, q9, q8 @ max reduce\n" \ + "subs %[cnt_num], #1 @ subs cnt_num \n" \ + "vst1.f32 {d10-d11}, [%[dr_out]]! @ store 4 out \n" \ + "ble 2f @ bne \n" + #define P2x2S2P0_MAX \ "1: @ main loop\n" \ "vmax.f32 q4, q0, q1 @ max \n" \ "vmax.f32 q5, q2, q3 @ max \n" \ "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \ - "vmax.f32 q6, q4, q5 @ max reduce\n" \ + "vmax.f32 q8, q4, q5 @ max reduce\n" \ "subs %[cnt_num], #1 @ subs cnt_num \n" \ - "vst1.f32 {d12-d13}, [%[dr_out]]! @ store 4 out \n" \ - "bne 1b @ bne " + "vst1.f32 {d16-d17}, [%[dr_out]]! @ store 4 out \n" \ + "bne 1b @ bne \n" + +#define P2x2S2P1_AVG \ + "vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \ + "vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \ + "sub %[dr0], #4 @sub \n" \ + "sub %[dr1], #4 @sub \n" \ + "vadd.f32 q9, q0, q4 @ max \n" \ + "vadd.f32 q8, q2, q5 @ max \n" \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! 
@ load \n" \ + "vadd.f32 q5, q9, q8 @ max reduce\n" \ + "subs %[cnt_num], #1 @ subs cnt_num \n" \ + "vmul.f32 q4, q5, %q[vcoef_left] @ mul coef \n" \ + "vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \ + "ble 2f @ bne\n" #define P2x2S2P0_AVG \ "1: @ main loop\n" \ @@ -512,9 +594,9 @@ void pooling_basic(const float* din, "vadd.f32 q5, q2, q3 @ add 0, 2, 4, 6 \n" \ "vld2.f32 {d0-d3}, [%[dr0]]! @ load d0-d3 \n" \ "vld2.f32 {d4-d7}, [%[dr1]]! @ load d4-d7 \n" \ - "vadd.f32 q6, q4, q5 @ add reduce \n" \ + "vadd.f32 q8, q4, q5 @ add reduce \n" \ "subs %[cnt_num], #1 @ subs \n" \ - "vmul.f32 q4, q6, %q[vcoef] @ mul coef \n" \ + "vmul.f32 q4, q8, %q[vcoef] @ mul coef \n" \ "vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \ "bne 1b @ bne \n" @@ -1014,17 +1096,17 @@ void pooling1x1s2p0_max(const float* din, TargetFree(TARGET(kARM), write_ptr); } -void pooling2x2s2_max(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - int pad_bottom, - int pad_right) { +void pooling2x2s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right) { int size_channel_out = wout * hout; int size_channel_in = win * hin; auto data_out = static_cast(dout); @@ -1072,7 +1154,7 @@ void pooling2x2s2_max(const float* din, [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8"); #endif dr0 -= 8; dr1 -= 8; @@ -1098,18 +1180,18 @@ void pooling2x2s2_max(const float* din, } } -void pooling2x2s2_avg(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - bool exclusive, - int pad_bottom, - int pad_right) { +void pooling2x2s2p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int 
pad_bottom, + int pad_right) { int size_channel_out = wout * hout; int size_channel_in = win * hin; auto data_out = static_cast(dout); @@ -1135,12 +1217,14 @@ void pooling2x2s2_avg(const float* din, const float* data_in_channel = data_in_batch + c * size_channel_in; const float* r0 = data_in_channel; const float* r1 = r0 + win; + vcoef = vdupq_n_f32(0.25f); for (int h = 0; h < hout; h++) { float* dr_out = data_out_channel; auto dr0 = r0; auto dr1 = r1; if (h * S + K - P > hin) { dr1 = zero_ptr; + vcoef = vdupq_n_f32(0.5f); } int cnt_num = w_unroll_size; if (w_unroll_size > 0) { @@ -1161,7 +1245,7 @@ void pooling2x2s2_avg(const float* din, [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : [vcoef] "w"(vcoef) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8"); #endif dr0 -= 8; dr1 -= 8; @@ -1171,8 +1255,14 @@ void pooling2x2s2_avg(const float* din, int wstart = 0; for (int j = 0; j < w_unroll_remian; ++j) { int wend = std::min(wstart + K, rem); - float coef = 0.5f / (wend - wstart); + float coef = 0.25f; float tmp = 0.f; + if (wend - wstart == 1 && pad_right == 0) { + coef *= 2; + } + if (h * S + K - P > hin && pad_bottom == 0) { + coef *= 2; + } for (int i = wstart; i < wend; i++) { tmp += dr0[i] + dr1[i]; } @@ -1189,6 +1279,235 @@ void pooling2x2s2_avg(const float* din, TargetFree(TARGET(kARM), zero_ptr); } +void pooling2x2s2p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 2; + const int P = 1; + const int S = 2; + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + float32x4_t vzero = vdupq_n_f32(std::numeric_limits::lowest()); + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + 
w_unroll_remian = wout - w_unroll_size * 4; + } + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + if (h == 0) { + dr0 = r0; + dr1 = r0; + r0 = r1; + r1 = r0 + win; + } else { + r0 = r1 + win; + r1 = r0 + win; + } + if (h * S + K - P > hin) { + dr1 = dr0; + if (h * S + K - P > hin + 1) { + memset(dr_out, 0, wout * sizeof(float)); + continue; + } + } + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vzero] "w"(vzero) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8"); +#else + asm volatile( + P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9"); +#endif + dr0 -= 8; + dr1 -= 8; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? wstart : 0; + float tmp = wend == st ? 
0.f : dr0[0]; + for (int i = 0; i < wend - st; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + } + *(dr_out++) = tmp; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + wstart += S; + } + data_out_channel += wout; + } + } + } +} + +void pooling2x2s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 2; + const int P = 1; + const int S = 2; + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + float32x4_t vzero = vdupq_n_f32(0.f); + memset(zero_ptr, 0, win * sizeof(float)); + + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + float coef_h = 0.5f; + if (h == 0) { + dr0 = zero_ptr; + dr1 = r0; + r0 = r1; + r1 = r0 + win; + if (exclusive) { + coef_h = 1.f; + } + } else { + r0 = r1 + win; + r1 = r0 + win; + } + if (h * S + K - P > hin) { + dr1 = zero_ptr; + if (exclusive) { + coef_h = 1.f; + } + if (h * S + K - P > hin + 1) { + memset(dr_out, 0, wout * sizeof(float)); + continue; + } + } + float coef_left_most = exclusive ? 
coef_h : coef_h / 2; + float32x4_t vcoef = vdupq_n_f32(coef_h / 2); + float coef_left[4] = { + coef_left_most, coef_h / 2, coef_h / 2, coef_h / 2}; + float32x4_t vcoef_left = vld1q_f32(coef_left); + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), + [vzero] "w"(vzero), + [vcoef_left] "w"(vcoef_left) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8"); +#else + asm volatile( + P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), + [vzero] "w"(vzero), + [vcoef_left] "w"(vcoef_left) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9"); +#endif + dr0 -= 8; + dr1 -= 8; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? 
wstart : 0; + float tmp = 0.f; + float coef = coef_h / 2; + if (exclusive && wend - st == 1) { + coef = coef_h; + } + for (int i = 0; i < wend - st; i++) { + tmp += dr0[i] + dr1[i]; + } + *(dr_out++) = tmp * coef; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + wstart += S; + } + data_out_channel += wout; + } + } + } + TargetFree(TARGET(kARM), zero_ptr); +} + void pooling3x3s1p1_max(const float* din, float* dout, int num, @@ -2217,6 +2536,9 @@ void pooling3x3s2p0_max(const float* din, w_unroll_remian = wout - w_unroll_size * 4; } + int remain = w_unroll_remian - 1; + int right = wout * 2 + 1 - win; // if need right pad + for (int n = 0; n < num; ++n) { float* data_out_batch = data_out + n * chout * size_channel_out; const float* data_in_batch = data_in + n * chin * size_channel_in; @@ -2243,6 +2565,7 @@ void pooling3x3s2p0_max(const float* din, } } int cnt_num = w_unroll_size; + int cnt_remain = remain; if (w_unroll_size > 0) { #ifdef __aarch64__ asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX @@ -2266,46 +2589,80 @@ void pooling3x3s2p0_max(const float* din, "v9", "v10", "v11"); -#else - asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr2] "+r"(dr2), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -#endif dr0 -= 8; dr1 -= 8; dr2 -= 8; - } - // deal with right pad - int rem = win - (w_unroll_size * 4) * S; - int wstart = 0; - for (int j = 0; j < w_unroll_remian; ++j) { - int wend = std::min(wstart + K, rem); - float tmp = dr0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { - tmp = std::max(tmp, dr0[i]); - tmp = std::max(tmp, dr1[i]); - tmp = std::max(tmp, dr2[i]); + int rem = win - (w_unroll_size * 4) * S; + int wstart = 0; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, rem); + float tmp = dr0[wstart]; // std::numeric_limits::min(); + for 
(int i = wstart; i < wend; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + tmp = std::max(tmp, dr2[i]); + } + *(dr_out++) = tmp; + wstart += S; } - *(dr_out++) = tmp; - wstart += S; +#else + asm volatile( + P3x3S2P0_INIT P3x3S2P0_MAX + "cmp %[remain], #0 @cmp cnt_num\n" + "sub %[dr0], #32 @sub - 8\n" + "sub %[dr1], #32 @sub - 8\n" + "sub %[dr2], #32 @sub - 8\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load \n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load \n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load \n" + "vmov.f32 s3,s2 @mov \n" + "vmov.f32 s7,s6 @mov \n" + "vmov.f32 s11,s10 @mov \n" + "vmax.f32 q0, q0, q1 @max \n" + "sub %[dr0], #8 @add w \n" + "sub %[dr1], #8 @add w \n" + "sub %[dr2], #8 @add w \n" + "vmax.f32 q0, q0, q2 @max \n" + "vpmax.f32 d0, d0, d1 @pmax \n" + "vpmax.f32 d0, d0, d0 @pmax \n" + "subs %[remain], #1 @subs \n" + "vst1.f32 d0[0], [%[dr_out]]! @vst \n" + "bne 2b @bne \n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [remain] "+r"(cnt_remain), + [cnt_num] "+r"(cnt_num) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); + if (right) { + int wstart = (w_unroll_size * 4 + remain) * S; + int wend = std::min(wstart + K, win); + float tmp = dr0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(dr0[i], dr1[i])); + tmp = std::max(tmp, dr2[i]); + } + *(dr_out++) = tmp; + } +#endif } r0 = r2; @@ -2344,7 +2701,9 @@ void pooling3x3s2p0_avg(const float* din, w_unroll_size -= 1; w_unroll_remian = wout - w_unroll_size * 4; } - + // do overflow process + w_unroll_size -= 1; + w_unroll_remian += 4; auto zero_ptr = static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); memset(zero_ptr, 0, win * sizeof(float)); diff --git a/lite/backends/arm/math/pooling.h b/lite/backends/arm/math/pooling.h index 
7bbffa8e2f4594da4be589569efc0ef18b8dd0da..572919e3f083f736d8f49b3bae0dd2820fac35c4 100644 --- a/lite/backends/arm/math/pooling.h +++ b/lite/backends/arm/math/pooling.h @@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din, int pad_bottom, int pad_right); -void pooling2x2s2_max(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - int pad_bottom, - int pad_right); - -void pooling2x2s2_avg(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - bool exclusive, - int pad_bottom, - int pad_right); +void pooling2x2s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right); + +void pooling2x2s2p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right); + +void pooling2x2s2p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right); + +void pooling2x2s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right); void pooling3x3s1p1_max(const float* din, float* dout, diff --git a/lite/kernels/arm/pool_compute.cc b/lite/kernels/arm/pool_compute.cc index ff6100c4e2c68d7eee0d5d0eeabbb64a1ca699e2..7beac7e0f8e86fff069650fa35dbba168a39090c 100644 --- a/lite/kernels/arm/pool_compute.cc +++ b/lite/kernels/arm/pool_compute.cc @@ -58,6 +58,7 @@ void PoolCompute::Run() { bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && (ksize[1] == in_dims[3]) && kps_equal && pads_equal; global_pooling = param.global_pooling || global_pooling; + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[2 * i] = 
0; @@ -107,35 +108,65 @@ void PoolCompute::Run() { } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { if (pooling_type == "max") { - lite::arm::math::pooling2x2s2_max(din, - dout, - out_dims[0], - out_dims[1], - out_dims[2], - out_dims[3], - in_dims[1], - in_dims[2], - in_dims[3], - paddings[1], - paddings[3]); + lite::arm::math::pooling2x2s2p0_max(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + paddings[1], + paddings[3]); return; } else if (pooling_type == "avg") { - lite::arm::math::pooling2x2s2_avg(din, - dout, - out_dims[0], - out_dims[1], - out_dims[2], - out_dims[3], - in_dims[1], - in_dims[2], - in_dims[3], - exclusive, - paddings[1], - paddings[3]); + lite::arm::math::pooling2x2s2p0_avg(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + exclusive, + paddings[1], + paddings[3]); return; } - } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && + } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 && kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling2x2s2p1_max(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + paddings[1], + paddings[3]); + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling2x2s2p1_avg(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + exclusive, + paddings[1], + paddings[3]); + return; + } + } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s1p1_max(din, dout, @@ -165,7 +196,7 @@ void PoolCompute::Run() { return; } } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && - kps_equal) { + pads_equal && kps_equal) { if (pooling_type == "max") { 
lite::arm::math::pooling3x3s1p0_max(din, dout, @@ -195,7 +226,7 @@ void PoolCompute::Run() { return; } } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && - kps_equal) { + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s2p0_max(din, dout, @@ -225,7 +256,7 @@ void PoolCompute::Run() { return; } } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && - kps_equal) { + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s2p1_max(din, dout, @@ -276,7 +307,6 @@ void PoolCompute::Run() { use_quantizer, pooling_type); } - } // namespace arm } // namespace kernels } // namespace lite