From 1433cd740394b4541896c0e51bbffa765b83b80b Mon Sep 17 00:00:00 2001
From: chenjiaoAngel
Date: Wed, 6 Nov 2019 17:25:13 +0800
Subject: [PATCH] fix pooling bug and improve speed

---
 lite/backends/arm/math/pooling.cc    | 4515 +++++++++++---------------
 lite/backends/arm/math/pooling.h     |   21 +
 lite/kernels/arm/pool_compute.cc     |   28 +
 lite/tests/math/CMakeLists.txt       |    1 +
 lite/tests/math/pool_compute_test.cc |  454 +++
 5 files changed, 2314 insertions(+), 2705 deletions(-)
 create mode 100644 lite/tests/math/pool_compute_test.cc

diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc
index 38078580c2..d023ed35fd 100644
--- a/lite/backends/arm/math/pooling.cc
+++ b/lite/backends/arm/math/pooling.cc
@@ -21,7 +21,6 @@ namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
-
 void pooling_basic(const float* din,
                    float* dout,
                    int num,
@@ -41,6 +40,7 @@ void pooling_basic(const float* din,
                    bool use_quantizer,
                    const std::string& pooling_type) {
   // no need to pad input tensor, border is zero pad inside this function
+  memset(dout, 0, num * chout * hout * wout * sizeof(float));
   int kernel_h = ksize[0];
   int kernel_w = ksize[1];
   int stride_h = strides[0];
@@ -85,113 +85,705 @@ void pooling_basic(const float* din,
       LOG(FATAL) << "unsupported pooling type: " << pooling_type;
     }
   } else {
-    if (pooling_type == "max") {
-      // Pooling_max
-      for (int n = 0; n < num; ++n) {
-        float* dout_ch = dout + n * chout * size_channel_out;
-        const float* din_batch = din + n * chin * size_channel_in;
-#pragma omp parallel for
-        for (int c = 0; c < chout; c++) {
-          float* dout_row = dout_ch + c * size_channel_out;
-          const float* din_ch = din_batch + c * size_channel_in;
-          for (int i = 0; i < hout; i++) {
-            for (int j = 0; j < wout; j++) {
-              int hstart = i * stride_h - pad_h;
-              int wstart = j * stride_w - pad_w;
-              int hend = std::min(hstart + kernel_h, hin + pad_h);
-              int wend = std::min(wstart + kernel_w, win + pad_w);
-              hstart = std::max(hstart, 0);
-              wstart = std::max(wstart, 0);
-              hend = std::min(hend, hin);
-              wend = std::min(wend, win);
-              int pool_size = (hend - hstart) * (wend - wstart);
-              if (pool_size == 0) continue;
-              float tmp1 = din_ch[hstart * win + wstart];
-              for (int h = hstart; h < hend; ++h) {
-                for (int w = wstart; w < wend; ++w) {
-                  float tmp2 = din_ch[h * win + w];
-                  tmp1 = tmp1 > tmp2 ? tmp1 : tmp2;
+    for (int ind_n = 0; ind_n < num; ++ind_n) {
+      for (int ind_c = 0; ind_c < chin; ++ind_c) {
+        for (int ind_h = 0; ind_h < hout; ++ind_h) {
+          int sh = ind_h * stride_h;
+          int eh = sh + kernel_h;
+          sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
+          eh = (eh - pad_h) > hin ? hin : eh - pad_h;
+          for (int ind_w = 0; ind_w < wout; ++ind_w) {
+            int sw = ind_w * stride_w;
+            int ew = sw + kernel_w;
+            sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
+            ew = (ew - pad_w) > win ? win : ew - pad_w;
+            float result = static_cast<float>(0);
+            int dst_ind = (ind_n * chout + ind_c) * size_channel_out +
+                          ind_h * wout + ind_w;
+            for (int kh = sh; kh < eh; ++kh) {
+              for (int kw = sw; kw < ew; ++kw) {
+                int src_ind =
+                    (ind_n * chin + ind_c) * size_channel_in + kh * win + kw;
+                if (kh == sh && kw == sw) {
+                  result = din[src_ind];
+                } else {
+                  if (pooling_type == "max") {
+                    result = result >= din[src_ind] ?
result : din[src_ind]; + } else if (pooling_type == "avg") { + result += din[src_ind]; + } } } - dout_row[j] = tmp1; } - dout_row += wout; - } - } - } - } else if (pooling_type == "avg") { - if (exclusive) { - // Pooling_average_exclude_padding - for (int n = 0; n < num; ++n) { - float* dout_ch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; -#pragma omp parallel for - for (int c = 0; c < chout; c++) { - float* dout_row = dout_ch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - for (int i = 0; i < hout; i++) { - for (int j = 0; j < wout; j++) { - int hstart = i * stride_h - pad_h; - int wstart = j * stride_w - pad_w; - int hend = std::min(hstart + kernel_h, hin + pad_h); - int wend = std::min(wstart + kernel_w, win + pad_w); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = std::min(hend, hin); - wend = std::min(wend, win); - int pool_size = (hend - hstart) * (wend - wstart); - if (pool_size == 0) continue; - float sum = 0.f; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - sum += din_ch[h * win + w]; + if (pooling_type == "avg") { + if (exclusive) { + int div = (ew - sw) * (eh - sh); + div = div > 0 ? div : 1; + result /= div; + } else { + int bh = kernel_h; + int bw = kernel_w; + if (ew == win) { + bw = sw + kernel_w >= win + pad_w ? win + pad_w + : sw + kernel_w; + bw -= sw; + if (sw - pad_w < 0 && sw + kernel_w > win + pad_w) { + bw += pad_w; } } - dout_row[j] = sum / pool_size; - } - dout_row += wout; - } - } - } - } else { // Pooling_average_include_padding - for (int n = 0; n < num; ++n) { - float* dout_ch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; -#pragma omp parallel for - for (int c = 0; c < chout; c++) { - float* dout_row = dout_ch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - for (int i = 0; i < hout; i++) { - for (int j = 0; j < wout; j++) { - int hstart = i * stride_h - pad_h; - int wstart = j * stride_w - pad_w; - int hend = std::min(hstart + kernel_h, hin + pad_h); - int wend = std::min(wstart + kernel_w, win + pad_w); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = std::min(hend, hin); - wend = std::min(wend, win); - int pool_size = (hend - hstart) * (wend - wstart); - if (pool_size == 0) continue; - float sum = 0.f; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - sum += din_ch[h * win + w]; + if (eh == hin) { + bh = sh + kernel_h >= hin + pad_h ? 
hin + pad_h
+                                                      : sh + kernel_h;
+                  bh -= sh;
+                  if (sh - pad_h < 0 && sh + kernel_h > hin + pad_h) {
+                    bh += pad_h;
+                  }
+                }
-                dout_row[j] = sum / (kernel_w * kernel_h);
+                result /= bh * bw;
               }
-              dout_row += wout;
             }
+            dout[dst_ind] = result;
           }
         }
       }
-    } else {
-      LOG(FATAL) << "unsupported pooling type: " << pooling_type;
     }
   }
 }
+
+#ifdef __aarch64__
+#define GLOBAL_INIT                                           \
+  "ld1 {v0.4s-v1.4s}, [%[data_in_channel]], #32 \n"           \
+  "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n"
+#define GLOBAL_MAX                                            \
+  "1:                                           \n"           \
+  "fmax v4.4s, v0.4s, v2.4s                     \n"           \
+  "fmax v5.4s, v1.4s, v3.4s                     \n"           \
+  "ld1 {v0.4s-v1.4s}, [%[data_in_channel]], #32 \n"           \
+  "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n"           \
+  "fmax v6.4s, v4.4s, v5.4s                     \n"           \
+  "subs %w[cnt], %w[cnt], #1                    \n"           \
+  "fmax %[vmax].4s, %[vmax].4s, v6.4s           \n"           \
+  "bne 1b                                       \n"
+#define GLOBAL_AVG                                            \
+  "1:                                           \n"           \
+  "fadd %[vsum].4s, %[vsum].4s, v0.4s           \n"           \
+  "fadd v4.4s, v1.4s, v2.4s                     \n"           \
+  "ld1 {v0.4s-v1.4s}, [%[data_in_channel]], #32 \n"           \
+  "fadd %[vsum].4s, %[vsum].4s, v3.4s           \n"           \
+  "subs %w[cnt], %w[cnt], #1                    \n"           \
+  "fadd %[vsum].4s, %[vsum].4s, v4.4s           \n"           \
+  "ld1 {v2.4s-v3.4s}, [%[data_in_channel]], #32 \n"           \
+  "bne 1b                                       \n"
+
+#define P2x2S2_INIT                                                   \
+  "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/     \
+  "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/
+
+#define P2x2S2P0_MAX                                                  \
+  "1:                                 \n"                             \
+  "fmax v4.4s, v0.4s, v1.4s\n"          /* max */                     \
+  "fmax v5.4s, v2.4s, v3.4s\n"          /* max */                     \
+  "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/     \
+  "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/     \
+  "fmax v6.4s, v4.4s, v5.4s\n"          /* max reduce */              \
+  "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/         \
+  "st1 {v6.4s}, [%[dr_out]], #16\n"     /* store 4 out, dr_out */     \
+  "bne 1b\n"                            /* bne s3_max_loop_mid */
+
+#define P2x2S2P0_AVG                                                        \
+  "1:                                 \n" /* load bias to q2, q3*/          \
+  "fadd v4.4s, v0.4s, v1.4s\n"          /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
+  "fadd v5.4s, v2.4s, v3.4s\n"          /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
+  "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/           \
+  "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/           \
+  "fadd v6.4s, v4.4s, v5.4s\n"          /* add reduce */                    \
+  "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/               \
+  "fmul v4.4s, v6.4s, %[vcoef].4s\n"    /* mul coef */                      \
+  "st1 {v4.4s}, [%[dr_out]], #16\n"     /* store 4 out, dr_out */           \
+  "bne 1b\n"                            /* bne s3_max_loop_mid */
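A note on the 2x2 stride-2 kernels above: `ld2`/`vld2` is a de-interleaving load, so a single instruction splits eight consecutive inputs into the even lanes {x0, x2, x4, x6} and the odd lanes {x1, x3, x5, x7}; one `fmax` (or `fadd`) then combines each 2-wide horizontal window. A minimal standalone sketch of the same idea with intrinsics (the function name is illustrative, not part of the patch):

#include <arm_neon.h>

// Sketch: 4 outputs of one 2x2s2 max-pooling row pair, mirroring
// P2x2S2_INIT / P2x2S2P0_MAX. vld2q_f32 de-interleaves 8 floats into
// even/odd lanes, so the max of the two halves is the horizontal 2-max.
static inline void pool2x2s2_max_row_sketch(const float* r0, const float* r1,
                                            float* out) {
  float32x4x2_t v0 = vld2q_f32(r0);                  // row 0: even/odd lanes
  float32x4x2_t v1 = vld2q_f32(r1);                  // row 1: even/odd lanes
  float32x4_t m0 = vmaxq_f32(v0.val[0], v0.val[1]);  // horizontal max, row 0
  float32x4_t m1 = vmaxq_f32(v1.val[0], v1.val[1]);  // horizontal max, row 1
  vst1q_f32(out, vmaxq_f32(m0, m1));                 // vertical max, 4 outputs
}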
+#define P3x3S1_INIT                                   \
+  "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/    \
+  "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/    \
+  "ldr q2, [%[dr2]], #16\n" /* load q2, dr2, 0-3*/    \
+  "ldr d3, [%[dr0]]\n"      /* load d3, dr0, 4-5*/    \
+  "ldr d4, [%[dr1]]\n"      /* load d4, dr1, 4-5*/    \
+  "ldr d5, [%[dr2]]\n"      /* load d5, dr2, 4-5*/
+
+#define P3x3S1P1_MAX                                                   \
+  "ext v6.16b, v0.16b, v3.16b, #4\n"        /* ext 1, 2, 3, 4, r0 */   \
+  "ext v7.16b, v1.16b, v4.16b, #4\n"        /* ext 1, 2, 3, 4, r1 */   \
+  "ext v8.16b, v2.16b, v5.16b, #4\n"        /* ext 1, 2, 3, 4, r2 */   \
+  "ext v9.16b, %[vmin].16b, v0.16b, #12\n"  /* ext -1, 0, 1, 2 */      \
+  "ext v10.16b, %[vmin].16b, v1.16b, #12\n" /* ext -1, 0, 1, 2 */      \
+  "ext v11.16b, %[vmin].16b, v2.16b, #12\n" /* ext -1, 0, 1, 2 */      \
+  "fmax v3.4s, v0.4s, v1.4s\n"                                         \
+  "fmax v4.4s, v2.4s, v6.4s\n"                                         \
+  "fmax v5.4s, v7.4s, v8.4s\n"                                         \
+                                                                       \
+  "fmax v6.4s, v9.4s, v10.4s\n"                                        \
+  "fmax v7.4s, v11.4s, v3.4s\n"                                        \
+  "fmax v8.4s, v4.4s, v5.4s\n"                                         \
+  "subs %[dr0], %[dr0], #4\n"                                          \
+  "subs %[dr1], %[dr1], #4\n"                                          \
+  "subs %[dr2], %[dr2], #4\n"                                          \
+                                                                       \
+  "fmax v9.4s, v6.4s, v7.4s\n"                                         \
+  "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/                     \
+  "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/                     \
+  "ldr q2, [%[dr2]], #16\n" /* load q2, dr2, 0-3*/                     \
+  "fmax v7.4s, v8.4s, v9.4s\n"                                         \
+  "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/          \
+  "ldr d3, [%[dr0]]\n" /* load d3, dr0, 4-5*/                          \
+  "ldr d4, [%[dr1]]\n" /* load d4, dr1, 4-5*/                          \
+  "ldr d5, [%[dr2]]\n" /* load d5, dr2, 4-5*/                          \
+  "st1 {v7.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */          \
+  "ble 2f\n" /* jump to end */
+
+#define P3x3S1P0_MAX                                                   \
+  "1:                                \n"                               \
+  "ext v6.16b, v0.16b, v3.16b, #4\n"  /* ext 1, 2, 3, 4, r0 */         \
+  "ext v7.16b, v1.16b, v4.16b, #4\n"  /* ext 1, 2, 3, 4, r1 */         \
+  "ext v8.16b, v2.16b, v5.16b, #4\n"  /* ext 1, 2, 3, 4, r2 */         \
+  "ext v9.16b, v0.16b, v3.16b, #8\n"  /* ext 2, 3, 4, 5, r0 */         \
+  "ext v10.16b, v1.16b, v4.16b, #8\n" /* ext 2, 3, 4, 5, r1 */         \
+  "ext v11.16b, v2.16b, v5.16b, #8\n" /* ext 2, 3, 4, 5, r2 */         \
+  "fmax v3.4s, v0.4s, v1.4s\n"                                         \
+  "fmax v4.4s, v2.4s, v6.4s\n"                                         \
+  "fmax v5.4s, v7.4s, v8.4s\n"                                         \
+  "fmax v6.4s, v9.4s, v10.4s\n"                                        \
+                                                                       \
+  "fmax v7.4s, v11.4s, v3.4s\n"                                        \
+  "fmax v8.4s, v4.4s, v5.4s\n"                                         \
+  "fmax v9.4s, v6.4s, v7.4s\n"                                         \
+  "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/                     \
+  "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/                     \
+  "ldr q2, [%[dr2]], #16\n" /* load q2, dr2, 0-3*/                     \
+                                                                       \
+  "fmax v7.4s, v8.4s, v9.4s\n"                                         \
+  "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/          \
+  "ldr d3, [%[dr0]]\n" /* load d3, dr0, 4-5*/                          \
+  "ldr d4, [%[dr1]]\n" /* load d4, dr1, 4-5*/                          \
+  "ldr d5, [%[dr2]]\n" /* load d5, dr2, 4-5*/                          \
+  "st1 {v7.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */          \
+  "bne 1b\n" /* bne s3_max_loop_mid */
+
+#define P3x3S1P1_AVG                                                   \
+  "ext v6.16b, v0.16b, v3.16b, #4\n"    /* ext 1, 2, 3, 4, r0 */       \
+  "ext v7.16b, v1.16b, v4.16b, #4\n"    /* ext 1, 2, 3, 4, r1 */       \
+  "ext v8.16b, v2.16b, v5.16b, #4\n"    /* ext 1, 2, 3, 4, r2 */       \
+  "ext v9.16b, v31.16b, v0.16b, #12\n"  /* ext -1, 0, 1, 2, r0 */      \
+  "ext v10.16b, v31.16b, v1.16b, #12\n" /* ext -1, 0, 1, 2, r1 */      \
+  "ext v11.16b, v31.16b, v2.16b, #12\n" /* ext -1, 0, 1, 2, r2 */      \
+                                                                       \
+  "fadd v3.4s, v0.4s, v1.4s\n"                                         \
+  "fadd v4.4s, v2.4s, v6.4s\n"                                         \
+  "fadd v5.4s, v7.4s, v8.4s\n"                                         \
+  "fadd v6.4s, v9.4s, v10.4s\n"                                        \
+  "fadd v7.4s, v11.4s, v3.4s\n"                                        \
+                                                                       \
+  "subs %[dr0], %[dr0], #4\n"                                          \
+  "subs %[dr1], %[dr1], #4\n"                                          \
+  "subs %[dr2], %[dr2], #4\n"                                          \
+                                                                       \
+  "fadd v8.4s, v4.4s, v5.4s\n"                                         \
+  "fadd v9.4s, v6.4s, v7.4s\n"                                         \
+  "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/                     \
+  "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/                     \
+  "ldr q2, [%[dr2]], #16\n" /* load q2, dr2, 0-3*/                     \
+                                                                       \
+  "fadd v10.4s, v8.4s, v9.4s\n"                                        \
+  "ldr d3, [%[dr0]]\n" /* load d3, dr0, 4-5*/                          \
+  "ldr d4, [%[dr1]]\n" /* load d4, dr1, 4-5*/                          \
+                                                                       \
+  "fmul v11.4s, v10.4s, %[vcoef_left].4s\n"                            \
+  "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/          \
+  "ldr d5, [%[dr2]]\n" /* load d5, dr2, 4-5*/                          \
+                                                                       \
+  "st1 {v11.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */         \
+  "ble 2f\n" /* jump to end */
+#define P3x3S1P0_AVG                                                   \
+  "1:                                \n"                               \
+  "ext v6.16b, v0.16b, v3.16b, #4\n"  /* ext 1, 2, 3, 4, r0 */         \
+  "ext v7.16b, v1.16b, v4.16b, #4\n"  /* ext 1, 2, 3, 4, r1 */         \
+  "ext v8.16b, v2.16b, v5.16b, #4\n"  /* ext 1, 2, 3, 4, r2 */         \
+  "ext v9.16b, v0.16b, v3.16b, #8\n"  /* ext 2, 3, 4, 5, r0 */         \
+  "ext v10.16b, v1.16b, v4.16b, #8\n" /* ext 2, 3, 4, 5, r1 */         \
+  "ext v11.16b, v2.16b, v5.16b, #8\n" /* ext 2, 3, 4, 5, r2 */         \
+                                                                       \
+  "fadd v3.4s, v0.4s, v1.4s\n"                                         \
+  "fadd v4.4s, v2.4s, v6.4s\n"                                         \
+  "fadd v5.4s, v7.4s, v8.4s\n"                                         \
+  "fadd v6.4s, v9.4s, v10.4s\n"                                        \
+  "fadd v7.4s, v11.4s, v3.4s\n"                                        \
+                                                                       \
+  "fadd v8.4s, v4.4s,
v5.4s\n" \ + "fadd v9.4s, v6.4s, v7.4s\n" \ + \ + "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/ \ + "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/ \ + "ldr q2, [%[dr2]], #16\n" /* load q2, dr2, 0-3*/ \ + "fadd v10.4s, v8.4s, v9.4s\n" \ + \ + "ldr d3, [%[dr0]]\n" /* load q3, dr0, 4-5*/ \ + "ldr d4, [%[dr1]]\n" /* load q4, dr1, 4-5*/ \ + "fmul v11.4s, v10.4s, %[vcoef].4s\n" \ + \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "ldr d5, [%[dr2]]\n" /* load q3, dr0, 4-5*/ \ + "st1 {v11.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "bne 1b\n" /* bne s3_max_loop_mid */ + +#define P3x3S2_INIT \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "ld2 {v4.4s, v5.4s}, [%[dr2]], #32\n" /* load q4-q5, dr2, 0-7*/ + +#define P3x3S2P0_INIT \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "ld2 {v4.4s, v5.4s}, [%[dr2]], #32\n" /* load q4-q5, dr2, 0-7*/ \ + "ld1 {v6.2s}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "ld1 {v7.2s}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "ld1 {v8.2s}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ + +#define P3x3S2P1_MAX \ + "fmax v6.4s, v0.4s, v1.4s\n" \ + "fmax v7.4s, v2.4s, v3.4s\n" \ + "fmax v8.4s, v4.4s, v5.4s\n" \ + "ext v0.16b, %[vmin].16b, v1.16b, #12\n" /* ext 0, 1, 3, 5 */ \ + "ext v2.16b, %[vmin].16b, v3.16b, #12\n" /* ext 0, 1, 3, 5 */ \ + "ext v4.16b, %[vmin].16b, v5.16b, #12\n" /* ext 0, 1, 3, 5 */ \ + "fmax v1.4s, v6.4s, v0.4s\n" \ + "fmax v3.4s, v7.4s, v2.4s\n" \ + "fmax v11.4s, v8.4s, v4.4s\n" \ + \ + "subs %[dr0], %[dr0], #4\n" \ + "subs %[dr1], %[dr1], #4\n" \ + "subs %[dr2], %[dr2], #4\n" \ + \ + "fmax v9.4s, v1.4s, v3.4s\n" /* reduce */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "ld2 {v4.4s, v5.4s}, [%[dr2]], #32\n" /* load q4-q5, dr2, 0-7*/ \ + \ + "fmax v10.4s, v9.4s, v11.4s\n" /* reduce */ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "ld1 {v6.2s}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "ld1 {v7.2s}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "ld1 {v8.2s}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "st1 {v10.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* jump to end */ + +#define P3x3S2P0_MAX \ + "1: \n" /* load bias to q2, q3*/ \ + "fmax v9.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "fmax v10.4s, v2.4s, v3.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "fmax v11.4s, v4.4s, v5.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "ext v1.16b, v0.16b, v6.16b, #4\n" /* ext 2, 4, 6, 8, r0 */ \ + "ext v3.16b, v2.16b, v7.16b, #4\n" /* ext 2, 4, 6, 8, r1 */ \ + "ext v5.16b, v4.16b, v8.16b, #4\n" /* ext 2, 4, 6, 8, r2 */ \ + \ + "fmax v6.4s, v9.4s, v1.4s\n" /* max */ \ + "fmax v7.4s, v10.4s, v3.4s\n" /* max */ \ + "fmax v8.4s, v11.4s, v5.4s\n" /* max */ \ + \ + "fmax v9.4s, v6.4s, v7.4s\n" /* max reduce */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "ld2 {v4.4s, v5.4s}, [%[dr2]], #32\n" /* load q4-q5, dr2, 0-7*/ \ + \ + "fmax v10.4s, v8.4s, v9.4s\n" /* max reduce */ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "ld1 {v6.2s}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "ld1 {v7.2s}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "ld1 {v8.2s}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "st1 {v10.4s}, [%[dr_out]], #16\n" 
/* store 4 out, dr_out */ \ + "bne 1b\n" /* bne s3_max_loop_mid */ + +#define P3x3S2P1_AVG \ + "fadd v6.4s, v0.4s, v1.4s\n" \ + "fadd v7.4s, v2.4s, v3.4s\n" \ + "fadd v8.4s, v4.4s, v5.4s\n" \ + "ext v0.16b, v31.16b, v1.16b, #12\n" /* ext 0, 1, 3, 5, r0 */ \ + "ext v2.16b, v31.16b, v3.16b, #12\n" /* ext 0, 1, 3, 5, r1 */ \ + "ext v4.16b, v31.16b, v5.16b, #12\n" /* ext 0, 1, 3, 5, r2 */ \ + \ + "fadd v1.4s, v6.4s, v0.4s\n" \ + "fadd v3.4s, v7.4s, v2.4s\n" \ + "fadd v5.4s, v8.4s, v4.4s\n" \ + \ + "fadd v9.4s, v1.4s, v3.4s\n" /* reduce */ \ + "subs %[dr0], %[dr0], #4\n" \ + "subs %[dr1], %[dr1], #4\n" \ + "subs %[dr2], %[dr2], #4\n" \ + \ + "fadd v10.4s, v9.4s, v5.4s\n" /* reduce */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "ld2 {v4.4s, v5.4s}, [%[dr2]], #32\n" /* load q4-q5, dr2, 0-7*/ \ + \ + "fmul v11.4s, v10.4s, %[vcoef_left].4s\n" \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "ld1 {v6.2s}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "ld1 {v7.2s}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "ld1 {v8.2s}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "st1 {v11.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* jump to end */ + +#define P3x3S2P0_AVG \ + "1: \n" /* load bias to q2, q3*/ \ + "fadd v9.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "fadd v10.4s, v2.4s, v3.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "fadd v11.4s, v4.4s, v5.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "ext v1.16b, v0.16b, v6.16b, #4\n" /* ext 2, 4, 6, 8, r0 */ \ + "ext v3.16b, v2.16b, v7.16b, #4\n" /* ext 2, 4, 6, 8, r1 */ \ + "ext v5.16b, v4.16b, v8.16b, #4\n" /* ext 2, 4, 6, 8, r2 */ \ + \ + "fadd v9.4s, v9.4s, v1.4s\n" /* max */ \ + "fadd v10.4s, v10.4s, v3.4s\n" /* max */ \ + "fadd v11.4s, v11.4s, v5.4s\n" /* max */ \ + \ + "fadd v9.4s, v9.4s, v10.4s\n" /* max reduce */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + \ + "fadd v9.4s, v9.4s, v11.4s\n" /* max reduce */ \ + "ld2 {v4.4s, v5.4s}, [%[dr2]], #32\n" /* load q4-q5, dr2, 0-7*/ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + \ + "fmul v10.4s, v9.4s, %[vcoef].4s\n" \ + "ld1 {v6.2s}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "ld1 {v7.2s}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "ld1 {v8.2s}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "st1 {v10.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "bne 1b\n" /* bne s3_max_loop_mid */ + +#else +#define GLOBAL_INIT \ + "vld1.f32 {d0-d3}, [%[data_in_channel]]! @ load data \n" \ + "vld1.f32 {d4-d7}, [%[data_in_channel]]! @ load data \n" +#define GLOBAL_MAX \ + "1: @ main loop\n" \ + "vmax.f32 q4, q0, q1 @ max \n" \ + "vmax.f32 q5, q2, q3 @ max vmax \n" \ + "vld1.f32 {d0-d3}, [%[data_in_channel]]! @ load data \n" \ + "vld1.f32 {d4-d7}, [%[data_in_channel]]! @ load data \n" \ + "vmax.f32 q6, q4, q5 @ max vmax \n" \ + "subs %[cnt], #1 @ subs num, 1\n" \ + "vmax.f32 %q[vmax], %q[vmax], q6 @ max vmax \n" \ + "bne 1b @ bne num\n" +#define GLOBAL_AVG \ + "1: @main loop\n" \ + "vadd.f32 %q[vsum], %q[vsum], q0 @add vmax \n" \ + "vadd.f32 q4, q2, q1 @add vmax \n" \ + "vld1.f32 {d0-d3}, [%[data_in_channel]]! @load q1 \n" \ + "vadd.f32 %q[vsum], %q[vsum], q3 @add vmax \n" \ + "subs %[cnt], #1 @subs num, 1\n" \ + "vadd.f32 %q[vsum], %q[vsum], q4 @add vmax \n" \ + "vld1.f32 {d4-d7}, [%[data_in_channel]]! 
@load q1 \n" \ + "bne 1b @bne num\n" + +#define P2x2S2_INIT \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" + +#define P2x2S2P0_MAX \ + "1: @ main loop\n" \ + "vmax.f32 q4, q0, q1 @ max \n" \ + "vmax.f32 q5, q2, q3 @ max \n" \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \ + "vmax.f32 q6, q4, q5 @ max reduce\n" \ + "subs %[cnt_num], #1 @ subs cnt_num \n" \ + "vst1.f32 {d12-d13}, [%[dr_out]]! @ store 4 out \n" \ + "bne 1b @ bne " + +#define P2x2S2P0_AVG \ + "1: @ main loop\n" \ + "vadd.f32 q4, q0, q1 @ add 0, 2, 4, 6 \n" \ + "vadd.f32 q5, q2, q3 @ add 0, 2, 4, 6 \n" \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load d0-d3 \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! @ load d4-d7 \n" \ + "vadd.f32 q6, q4, q5 @ add reduce \n" \ + "subs %[cnt_num], #1 @ subs \n" \ + "vmul.f32 q4, q6, %q[vcoef] @ mul coef \n" \ + "vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \ + "bne 1b @ bne \n" + +#define P3x3S1_INIT \ + "vld1.32 {d0-d2}, [%[dr0]]!\n" /* load q0, dr0, 0-5*/ \ + "vld1.32 {d4-d6}, [%[dr1]]!\n" /* load q2, dr0, 0-5*/ \ + "vld1.32 {d8-d10}, [%[dr2]]!\n" /* load q4, dr0, 0-5*/ +#define P3x3S1P0_INIT \ + "vld1.32 {d0-d1}, [%[dr0]]!\n" /* load q0, dr0, 0-5*/ \ + "vld1.32 {d4-d5}, [%[dr1]]!\n" /* load q2, dr0, 0-5*/ \ + "vld1.32 {d8-d9}, [%[dr2]]!\n" /* load q4, dr0, 0-5*/ \ + "vld1.32 {d2}, [%[dr0]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d6}, [%[dr1]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d10}, [%[dr2]]\n" /* load q1, dr0, 4-5*/ +#define P3x3S1P1_MAX \ + "vext.32 q6, q0, q1, #1\n" /* ext 1, 2, 3, 4, r0 */ \ + "vext.32 q7, q2, q3, #1\n" /* ext 1, 2, 3, 4, r1 */ \ + "vext.32 q8, q4, q5, #1\n" /* ext 1, 2, 3, 4, r2 */ \ + "vext.32 q9, %q[vmin], q0, #3\n" /* ext -1, 0, 1, 2, r0 */ \ + "vext.32 q10, %q[vmin], q2, #3\n" /* ext -1, 0, 1, 2, r1 */ \ + "vext.32 q11, %q[vmin], q4, #3\n" /* ext -1, 0, 1, 2, r2 */ \ + \ + "vmax.f32 q1, q0, q2\n" \ + "vmax.f32 q3, q4, q6\n" \ + "vmax.f32 q5, q7, q8\n" \ + "vmax.f32 q6, q9, q10\n" \ + "vmax.f32 q7, q11, q1\n" \ + \ + "subs %[dr0], %[dr0], #12\n" \ + "subs %[dr1], %[dr1], #12\n" \ + "subs %[dr2], %[dr2], #12\n" \ + \ + "vmax.f32 q8, q3, q5\n" \ + "vmax.f32 q9, q6, q7\n" \ + "vld1.32 {d0-d1}, [%[dr0]]!\n" /* load q0, dr0, 0-3*/ \ + "vld1.32 {d4-d5}, [%[dr1]]!\n" /* load q0, dr0, 0-3*/ \ + "vld1.32 {d8-d9}, [%[dr2]]!\n" /* load q0, dr0, 0-3*/ \ + "vmax.f32 q6, q8, q9\n" \ + \ + "subs %[cnt_num], %[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.32 {d2}, [%[dr0]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d6}, [%[dr1]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d10}, [%[dr2]]\n" /* load q1, dr0, 4-5*/ \ + \ + "vst1.32 {d12-d13}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* jump to end */ + +#define P3x3S1P0_MAX \ + "1: \n" /* */ \ + "vext.32 q6, q0, q1, #1\n" /* ext 1, 2, 3, 4, r0 */ \ + "vext.32 q7, q2, q3, #1\n" /* ext 1, 2, 3, 4, r1 */ \ + "vext.32 q8, q4, q5, #1\n" /* ext 1, 2, 3, 4, r2 */ \ + "vext.32 q9, q0, q1, #2\n" /* ext 2, 3, 4, 5, r0 */ \ + "vext.32 q10, q2, q3, #2\n" /* ext 2, 3, 4, 5, r1 */ \ + "vext.32 q11, q4, q5, #2\n" /* ext 2, 3, 4, 5, r2 */ \ + \ + "vmax.f32 q1, q0, q2\n" \ + "vmax.f32 q3, q4, q6\n" \ + "vmax.f32 q5, q7, q8\n" \ + "vmax.f32 q6, q9, q10\n" \ + "vmax.f32 q7, q11, q1\n" \ + \ + "vmax.f32 q8, q3, q5\n" \ + "vmax.f32 q9, q6, q7\n" \ + "vld1.32 {d0-d1}, [%[dr0]]!\n" /* load q0, dr0, 0-3*/ \ + "vld1.32 {d4-d5}, [%[dr1]]!\n" /* load q0, dr0, 0-3*/ \ + "vld1.32 {d8-d9}, [%[dr2]]!\n" /* load q0, dr0, 0-3*/ \ + \ + "vmax.f32 q6, q8, q9\n" \ + "subs %[cnt_num], 
%[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.32 {d2}, [%[dr0]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d6}, [%[dr1]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d10}, [%[dr2]]\n" /* load q1, dr0, 4-5*/ \ + \ + "vst1.32 {d12-d13}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "bne 1b\n" /* bne s3_max_loop_mid */ + +#define P3x3S1P1_AVG \ + "vext.32 q6, q0, q1, #1\n" /* ext 1, 2, 3, 4, r0 */ \ + "vext.32 q7, q2, q3, #1\n" /* ext 1, 2, 3, 4, r1 */ \ + "vext.32 q8, q4, q5, #1\n" /* ext 1, 2, 3, 4, r2 */ \ + "vext.32 q9, q15, q0, #3\n" /* ext -1, 0, 1, 2, r0 */ \ + "vext.32 q10, q15, q2, #3\n" /* ext -1, 0, 1, 2, r1 */ \ + "vext.32 q11, q15, q4, #3\n" /* ext -1, 0, 1, 2, r2 */ \ + \ + "vadd.f32 q1, q0, q2\n" \ + "vadd.f32 q3, q4, q6\n" \ + "vadd.f32 q5, q7, q8\n" \ + "vadd.f32 q6, q9, q10\n" \ + "vadd.f32 q7, q11, q1\n" \ + \ + "vadd.f32 q8, q3, q5\n" \ + "vadd.f32 q9, q6, q7\n" \ + \ + "subs %[dr0], %[dr0], #12\n" \ + "subs %[dr1], %[dr1], #12\n" \ + "subs %[dr2], %[dr2], #12\n" \ + "vadd.f32 q10, q8, q9\n" \ + \ + "vld1.32 {d0-d1}, [%[dr0]]!\n" /* load q0, dr0, 0-3*/ \ + "vld1.32 {d4-d5}, [%[dr1]]!\n" /* load q2, dr1, 0-3*/ \ + "vld1.32 {d8-d9}, [%[dr2]]!\n" /* load q4, dr2, 0-3*/ \ + "vmul.f32 q11, q10, %q[vcoef_left]\n" \ + \ + "subs %[cnt_num], %[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.32 {d2}, [%[dr0]]\n" /* load q1, dr0, 4-5*/ \ + "vld1.32 {d6}, [%[dr1]]\n" /* load q3, dr1, 4-5*/ \ + "vld1.32 {d10}, [%[dr2]]\n" /* load q5, dr2, 4-5*/ \ + \ + "vst1.32 {d22-d23}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* jump to end */ \ + "1: \n" /* */ + +#define P3x3S1P0_AVG \ + "1: \n" /* */ \ + "vext.32 q6, q0, q1, #1\n" /* ext 1, 2, 3, 4, r0 */ \ + "vext.32 q7, q2, q3, #1\n" /* ext 1, 2, 3, 4, r1 */ \ + "vext.32 q8, q4, q5, #1\n" /* ext 1, 2, 3, 4, r2 */ \ + "vext.32 q9, q0, q1, #2\n" /* ext 2, 3, 4, 5, r0 */ \ + "vext.32 q10, q2, q3, #2\n" /* ext 2, 3, 4, 5, r1 */ \ + "vext.32 q11, q4, q5, #2\n" /* ext 2, 3, 4, 5, r2 */ \ + \ + "vadd.f32 q1, q0, q2\n" \ + "vadd.f32 q3, q4, q6\n" \ + "vadd.f32 q5, q7, q8\n" \ + "vadd.f32 q6, q9, q10\n" \ + "vadd.f32 q7, q11, q1\n" \ + \ + "vadd.f32 q8, q3, q5\n" \ + "vadd.f32 q9, q6, q7\n" \ + "vld1.32 {d0-d1}, [%[dr0]]!\n" /* load q0, dr0, 0-3*/ \ + "vld1.32 {d4-d5}, [%[dr1]]!\n" /* load q2, dr1, 0-3*/ \ + \ + "vadd.f32 q10, q8, q9\n" \ + "vld1.32 {d8-d9}, [%[dr2]]!\n" /* load q4, dr2, 0-3*/ \ + "vld1.32 {d2}, [%[dr0]]\n" /* load q1, dr0, 4-5*/ \ + \ + "vmul.f32 q11, q10, %q[vcoef]\n" \ + "subs %[cnt_num], %[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.32 {d6}, [%[dr1]]\n" /* load q3, dr1, 4-5*/ \ + "vld1.32 {d10}, [%[dr2]]\n" /* load q5, dr2, 4-5*/ \ + \ + "vst1.32 {d22-d23}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "bne 1b\n" /* bne s3_max_loop_mid */ + +#define P3x3S2_INIT \ + "vld2.f32 {d0-d3}, [%[dr0]]!\n" /* load q0-q1, dr0, 0-7*/ \ + "vld2.f32 {d4-d7}, [%[dr1]]!\n" /* load q2-q3, dr1, 0-7*/ \ + "vld2.f32 {d8-d11}, [%[dr2]]!\n" /* load q4-q5, dr2, 0-7*/ +#define P3x3S2P0_INIT \ + "vld2.f32 {d0-d3}, [%[dr0]]!\n" /* load q0-q1, dr0, 0-7*/ \ + "vld2.f32 {d4-d7}, [%[dr1]]!\n" /* load q2-q3, dr1, 0-7*/ \ + "vld2.f32 {d8-d11}, [%[dr2]]!\n" /* load q4-q5, dr2, 0-7*/ \ + "vld1.f32 {d12-d13}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "vld1.f32 {d14-d15}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "vld1.f32 {d16-d17}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ +#define P3x3S2P1_MAX \ + "vmax.f32 q6, q0, q1\n" \ + "vmax.f32 q7, q2, q3\n" \ + "vmax.f32 q8, q4, q5\n" \ + "vext.32 q0, %q[vmin], q1, #3\n" /* ext 0, 1, 3, 5, r0 */ \ + "vext.32 q2, 
%q[vmin], q3, #3\n" /* ext 0, 1, 3, 5, r1 */ \ + "vext.32 q4, %q[vmin], q5, #3\n" /* ext 0, 1, 3, 5, r2 */ \ + \ + "vmax.f32 q9, q6, q0\n" \ + "vmax.f32 q10, q7, q2\n" \ + "vmax.f32 q11, q8, q4\n" \ + \ + "subs %[dr0], %[dr0], #4\n" \ + "subs %[dr1], %[dr1], #4\n" \ + "subs %[dr2], %[dr2], #4\n" \ + \ + "vmax.f32 q6, q9, q10\n" /* reduce */ \ + "vld2.f32 {d0-d3}, [%[dr0]]!\n" /* load q0-q1, dr0, 0-7*/ \ + "vld2.f32 {d4-d7}, [%[dr1]]!\n" /* load q2-q3, dr1, 0-7*/ \ + "vld2.f32 {d8-d11}, [%[dr2]]!\n" /* load q4-q5, dr2, 0-7*/ \ + \ + "vmax.f32 q10, q6, q11\n" /* reduce */ \ + "subs %[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.f32 {d12-d13}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "vld1.f32 {d14-d15}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "vld1.f32 {d16-d17}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "vst1.32 {d20-d21}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* jump to end */ +#define P3x3S2P0_MAX \ + "1: \n" /* load bias to q2, q3*/ \ + "vmax.f32 q9, q0, q1\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "vmax.f32 q10, q2, q3\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "vmax.f32 q11, q4, q5\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "vext.32 q1, q0, q6, #1\n" /* ext 2, 4, 6, 8, r0 */ \ + "vext.32 q3, q2, q7, #1\n" /* ext 2, 4, 6, 8, r1 */ \ + "vext.32 q5, q4, q8, #1\n" /* ext 2, 4, 6, 8, r2 */ \ + \ + "vmax.f32 q6, q9, q1\n" /* add */ \ + "vmax.f32 q7, q10, q3\n" /* add */ \ + "vmax.f32 q8, q11, q5\n" /* add */ \ + \ + "vmax.f32 q9, q6, q7\n" /* max reduce */ \ + "vld2.f32 {d0-d3}, [%[dr0]]!\n" /* load q0-q1, dr0, 0-7*/ \ + "vld2.f32 {d4-d7}, [%[dr1]]!\n" /* load q2-q3, dr1, 0-7*/ \ + "vld2.f32 {d8-d11}, [%[dr2]]!\n" /* load q4-q5, dr2, 0-7*/ \ + \ + "vmax.f32 q10, q9, q8\n" /* max reduce */ \ + "subs %[cnt_num], %[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.f32 {d12-d13}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "vld1.f32 {d14-d15}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "vld1.f32 {d16-d17}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "vst1.32 {d20-d21}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "bne 1b\n" /* bne s3_max_loop_mid */ + +#define P3x3S2P1_AVG \ + "vadd.f32 q6, q0, q1\n" \ + "vadd.f32 q7, q2, q3\n" \ + "vadd.f32 q8, q4, q5\n" \ + "vext.32 q0, q15, q1, #3\n" /* ext 0, 1, 3, 5, r0 */ \ + "vext.32 q2, q15, q3, #3\n" /* ext 0, 1, 3, 5, r1 */ \ + "vext.32 q4, q15, q5, #3\n" /* ext 0, 1, 3, 5, r2 */ \ + "vadd.f32 q6, q6, q0\n" \ + "vadd.f32 q7, q7, q2\n" \ + "vadd.f32 q8, q8, q4\n" \ + \ + "vadd.f32 q9, q6, q7\n" /* reduce */ \ + "subs %[dr0], %[dr0], #4\n" \ + "subs %[dr1], %[dr1], #4\n" \ + "subs %[dr2], %[dr2], #4\n" \ + \ + "vadd.f32 q10, q9, q8\n" /* reduce */ \ + "vld2.f32 {d0-d3}, [%[dr0]]!\n" /* load q0-q1, dr0, 0-7*/ \ + "vld2.f32 {d4-d7}, [%[dr1]]!\n" /* load q2-q3, dr1, 0-7*/ \ + "vld2.f32 {d8-d11}, [%[dr2]]!\n" /* load q4-q5, dr2, 0-7*/ \ + \ + "vmul.f32 q11, q10, %q[vcoef_left]\n" \ + "subs %[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "vld1.f32 {d12-d13}, [%[dr0]]\n" /* load d6, dr0, 8,9 */ \ + "vld1.f32 {d14-d15}, [%[dr1]]\n" /* load d7, dr1, 8,9 */ \ + "vld1.f32 {d16-d17}, [%[dr2]]\n" /* load d8, dr2, 8,9 */ \ + \ + "vst1.32 {d22-d23}, [%[dr_out]]!\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* jump to end */ + +#define P3x3S2P0_AVG \ + "1: \n" \ + "vadd.f32 q9, q0, q1\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7, */ \ + "vadd.f32 q10, q2, q3\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7, */ \ + "vadd.f32 q11, q4, q5\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7, */ \ + "vext.32 q1, q0, q6, #1\n" /* ext 2, 4, 6, 8, r0 */ \ + "vext.32 q3, q2, q7, #1\n" /* ext 2, 4, 6, 8, 
r1 */                \
+  "vext.32 q5, q4, q8, #1\n" /* ext 2, 4, 6, 8, r2 */                 \
+                                                                      \
+  "vadd.f32 q9, q9, q1\n"   /* add */                                 \
+  "vadd.f32 q10, q10, q3\n" /* add */                                 \
+  "vadd.f32 q11, q11, q5\n" /* add */                                 \
+                                                                      \
+  "vadd.f32 q9, q9, q10 \n"       /* max reduce */                    \
+  "vld2.f32 {d0-d3}, [%[dr0]]!\n" /* load q0-q1, dr0, 0-7*/           \
+  "vld2.f32 {d4-d7}, [%[dr1]]!\n" /* load q2-q3, dr1, 0-7*/           \
+                                                                      \
+  "vadd.f32 q10, q9, q11 \n"           /* max reduce */               \
+  "vld2.f32 {d8-d11}, [%[dr2]]!\n"     /* load q4-q5, dr2, 0-7*/      \
+  "subs %[cnt_num], %[cnt_num], #1\n"  /* subs cnt_num, #1*/          \
+                                                                      \
+  "vmul.f32 q11, q10, %q[vcoef]\n"                                    \
+  "vld1.f32 {d12-d13}, [%[dr0]]\n" /* load d6, dr0, 8,9 */            \
+  "vld1.f32 {d14-d15}, [%[dr1]]\n" /* load d7, dr1, 8,9 */            \
+  "vld1.f32 {d16-d17}, [%[dr2]]\n" /* load d8, dr2, 8,9 */            \
+                                                                      \
+  "vst1.32 {d22-d23}, [%[dr_out]]!\n" /* store 4 out, dr_out */       \
+  "bne 1b\n"                          /* bne s3_max_loop_mid */
+
+#endif
+
 void pooling_global_max(const float* din,
                         float* dout,
                         int num,
@@ -202,50 +794,47 @@ void pooling_global_max(const float* din,
                         int hin,
                         int win) {
   int size_channel_in = win * hin;
-  int cnt = size_channel_in / 8;
+  auto data_out = static_cast<float*>(dout);
+  auto data_in = static_cast<const float*>(din);
+
+  int cnt = size_channel_in / 16;
+
   for (int n = 0; n < num; ++n) {
-    float* dout_batch = dout + n * chout;
-    const float* din_batch = din + n * chin * size_channel_in;
+    float* data_out_batch = data_out + n * chout;
+    const float* data_in_batch = data_in + n * chin * size_channel_in;
 #pragma omp parallel for
     for (int c = 0; c < chout; ++c) {
-      const float* din_ch = din_batch + c * size_channel_in;
+      const float* data_in_channel = data_in_batch + c * size_channel_in;
       int i = 0;
-      float minval = std::numeric_limits<float>::lowest();
-      float32x4_t vmax = vdupq_n_f32(minval);
+      float32x4_t vmax = vdupq_n_f32(std::numeric_limits<float>::lowest());
+      int size_cnt = cnt;
+      if (cnt > 0) {
 #ifdef __aarch64__
-      for (; i < cnt; i++) {
-        float32x4_t vdin1 = vld1q_f32(din_ch);
-        vmax = vmaxq_f32(vdin1, vmax);
-        float32x4_t vdin2 = vld1q_f32(din_ch + 4);
-        vmax = vmaxq_f32(vmax, vdin2);
-        din_ch += 8;
-      }
+        asm volatile(
+            GLOBAL_INIT GLOBAL_MAX
+            : [data_in_channel] "+r"(data_in_channel),
+              [cnt] "+r"(size_cnt),
+              [vmax] "+w"(vmax)
+            :
+            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6");
 #else
-      int cnt_num = cnt;
-      if (cnt_num > 0) {
         asm volatile(
-            "max_loop:                          @main loop\n"
-            "vld1.f32 {d0-d1}, [%[din_ch]]!     @load q1,din_ch\n"
-            "vmax.f32 %q[vmax], %q[vmax], q0    @max vmax,vmax,din_ch\n"
-            "vld1.f32 {d2-d3}, [%[din_ch]]!     @load 2nd 4 data\n"
-            "vmax.f32 %q[vmax], %q[vmax], q1    @compare 2nd 4 datas\n"
-            "subs %[cnt_num], #1                @cnt_num--\n"
-            "bne max_loop                       @bne cnt_num\n"
-            : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vmax] "+w"(vmax)
+            GLOBAL_INIT GLOBAL_MAX
+            : [data_in_channel] "+r"(data_in_channel),
+              [cnt] "+r"(size_cnt),
+              [vmax] "+w"(vmax)
             :
-            : "cc", "memory", "q0", "q1");
+            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
+#endif  // __aarch64__
+        data_in_channel -= 16;
       }
-#endif  // __aarch64__
       float32x2_t vmax_tmp = vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
-      float tmp1 = vget_lane_f32(vmax_tmp, 0);
-      float tmp2 = vget_lane_f32(vmax_tmp, 1);
-      float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2;
-      for (i = cnt * 8; i < size_channel_in; ++i) {
-        /* code */
-        max_tmp = max_tmp > din_ch[0] ? max_tmp : din_ch[0];
-        din_ch++;
+      float max_tmp = vmax_tmp[0] > vmax_tmp[1] ? vmax_tmp[0] : vmax_tmp[1];
+      for (i = cnt * 16; i < size_channel_in; ++i) {
+        max_tmp = max_tmp > data_in_channel[0] ? max_tmp : data_in_channel[0];
+        data_in_channel++;
      }
-      dout_batch[c] = max_tmp;
+      data_out_batch[c] = max_tmp;
     }
   }
 }
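The asm loop above only maintains a 4-lane running maximum; the scalar result still has to be reduced outside the loop. The patch folds the high half into the low half with vmax_f32 and then compares the two surviving lanes, indexing the vector directly (a GCC vector extension); vget_lane_f32 is the portable spelling. A standalone sketch of that reduction:

#include <arm_neon.h>

// Sketch: reduce a float32x4_t running max to a scalar, matching the
// vmax_f32(vget_low_f32(...), vget_high_f32(...)) tail used above.
static inline float reduce_max_f32x4(float32x4_t vmax) {
  float32x2_t half = vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
  float a = vget_lane_f32(half, 0);
  float b = vget_lane_f32(half, 1);
  return a > b ? a : b;
}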
@@ -260,41 +849,46 @@ void pooling_global_avg(const float* din,
                         int hin,
                         int win) {
   int size_channel_in = win * hin;
-  int cnt = size_channel_in / 4;
+  auto data_out = static_cast<float*>(dout);
+  auto data_in = static_cast<const float*>(din);
+
+  int cnt = size_channel_in / 16;
+
   for (int n = 0; n < num; ++n) {
-    float* dout_batch = dout + n * chout;
-    const float* din_batch = din + n * chin * size_channel_in;
+    float* data_out_batch = data_out + n * chout;
+    const float* data_in_batch = data_in + n * chin * size_channel_in;
 #pragma omp parallel for
     for (int c = 0; c < chout; c++) {
-      const float* din_ch = din_batch + c * size_channel_in;  // in address
+      const float* data_in_channel =
+          data_in_batch + c * size_channel_in;  // in address
       int i = 0;
       float32x4_t vsum = vdupq_n_f32(0.0f);
+      int size_cnt = cnt;
+      if (cnt > 0) {
 #ifdef __aarch64__
-      for (; i < cnt; i++) {
-        vsum = vaddq_f32(vld1q_f32(din_ch), vsum);
-        din_ch += 4;
-      }
+        asm volatile(GLOBAL_INIT GLOBAL_AVG
+                     : [data_in_channel] "+r"(data_in_channel),
+                       [cnt] "+r"(size_cnt),
+                       [vsum] "+w"(vsum)
+                     :
+                     : "cc", "memory", "v0", "v1", "v2", "v3", "v4");
 #else
-      int cnt_num = cnt;
-      if (cnt_num > 0) {
-        asm volatile(
-            "add_loop:                          @main loop\n"
-            "vld1.f32 {d0-d1}, [%[din_ch]]!     @load q1,din_ch\n"
-            "vadd.f32 %q[vsum], %q[vsum], q0    @add vmax,vmax, din_ch\n"
-            "subs %[cnt_num], #1                @cnt_num--\n"
-            "bne add_loop                       @bne num\n"
-            : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vsum] "+w"(vsum)
-            :
-            : "cc", "memory", "q0");
+        asm volatile(GLOBAL_INIT GLOBAL_AVG
+                     : [data_in_channel] "+r"(data_in_channel),
+                       [cnt] "+r"(size_cnt),
+                       [vsum] "+w"(vsum)
+                     :
+                     : "cc", "memory", "q0", "q1", "q2", "q3", "q4");
+#endif  // __aarch64__
+        data_in_channel -= 16;
       }
-#endif  // __aarch64__
       float32x2_t vsum_tmp = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum));
-      float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1);
-      for (i = cnt * 4; i < size_channel_in; i++) {
-        sum += din_ch[0];
-        din_ch++;
+      float sum = vsum_tmp[0] + vsum_tmp[1];
+      for (i = cnt * 16; i < size_channel_in; i++) {
+        sum += data_in_channel[0];
+        data_in_channel++;
       }
-      dout_batch[c] = sum / size_channel_in;
+      data_out_batch[c] = sum / size_channel_in;
     }
   }
 }
@@ -308,132 +902,74 @@ void pooling2x2s2_max(const float* din,
                       int chin,
                       int hin,
                       int win) {
-  int kernel = 2;
-  int stride = 2;
-  int padding = 0;
   int size_channel_out = wout * hout;
   int size_channel_in = win * hin;
+  auto data_out = static_cast<float*>(dout);
+  auto data_in = static_cast<const float*>(din);
+
+  const int K = 2;
+  const int P = 0;
+  const int S = 2;
+
+  int w_unroll_size = wout / 4;
+  int w_unroll_remian = wout - w_unroll_size * 4;
 
-  int w_needed = (wout << 1);
-  int h_needed = (hout << 1);
-  int w_limit = w_needed > win ? win : w_needed;
-  int h_limit = h_needed > hin ?
hin : h_needed; - int w_even = (w_limit >> 1) << 1; - int h_even = (h_limit >> 1) << 1; - int w_unroll_size = (w_even >> 3) << 3; - // int w_unroll_remain = w_even - w_unroll_size; - int w_in_2 = win << 1; for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; - int h = 0; - for (; h < h_even; h += 2) { - int w = 0; -#ifdef __aarch64__ - for (; w < w_unroll_size; w += 8) { - float32x4_t dr00 = vld1q_f32(&r0[w]); - float32x4_t dr01 = vld1q_f32(&r0[w + 4]); - float32x4_t dr10 = vld1q_f32(&r1[w]); - float32x4_t dr11 = vld1q_f32(&r1[w + 4]); - float32x4_t dmax1 = vmaxq_f32(dr00, dr10); - float32x4_t dmax2 = vmaxq_f32(dr01, dr11); -#ifdef __aarch64__ - float32x4_t dmax = vpmaxq_f32(dmax1, dmax2); -#else - float32x2_t dmaxl = - vpmax_f32(vget_low_f32(dmax1), vget_high_f32(dmax1)); - float32x2_t dmaxh = - vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2)); - float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); -#endif - vst1q_f32(&dout_ch[w >> 1], dmax); + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + if (h * S + K - P > hin) { + dr1 = r0; } -#else - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - int cnt_num = w_unroll_size >> 3; - if (cnt_num > 0) { + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ asm volatile( - "s2_max_loop: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" - "vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n" - "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" - "vmax.f32 q1, q1, q3 @max q1,q1,q2\n" - "vpmax.f32 d4, d0, d1 @max d4,d0,d1\n" - "vpmax.f32 d5, d2, d3 @max d5,d2,d3\n" - "vst1.f32 {d4-d5}, [%[dr_out]]! 
@vst1 q2,dr_out\n"
-            "subs %[cnt_num], #1              @cnt_num--\n"
-            "bne s2_max_loop                  @bne cnt_num\n"
+            P2x2S2_INIT P2x2S2P0_MAX
             : [dr0] "+r"(dr0),
               [dr1] "+r"(dr1),
               [dr_out] "+r"(dr_out),
               [cnt_num] "+r"(cnt_num)
             :
-            : "cc", "memory", "q0", "q1", "q2", "q3");
-      }
-      w = w_unroll_size;
-#endif  // __aarch64__
-      for (; w < w_even; w += 2) {
-        dout_ch[w >> 1] =
-            std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1]));
-      }
-      for (; w < w_limit; ++w) {  // run 0 or 1 time
-        dout_ch[w >> 1] = std::max(r0[w], r1[w]);
-      }
-      r0 += w_in_2;  // << 1;
-      r1 += w_in_2;  // << 1;
-      dout_ch += wout;
-    }
-    // process remain row (odd, last row)
-    for (; h < h_limit; h++) {  // run 0 or 1 time
-      int w = 0;
-#ifdef __aarch64__
-      for (; w < w_unroll_size; w += 8) {
-        float32x4_t dr00 = vld1q_f32(&r0[w]);
-        float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
-#ifdef __aarch64__
-        float32x4_t dmax = vpmaxq_f32(dr00, dr01);
-#else
-        float32x2_t dmaxl =
-            vpmax_f32(vget_low_f32(dr00), vget_high_f32(dr00));
-        float32x2_t dmaxh =
-            vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01));
-        float32x4_t dmax = vcombine_f32(dmaxl, dmaxh);
-#endif
-        vst1q_f32(&dout_ch[w >> 1], dmax);
-      }
+            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6");
 #else
-      float* dr_out = dout_ch;
-      const float* dr0 = r0;
-      int cnt_num = w_unroll_size >> 3;
-      if (cnt_num > 0) {
          asm volatile(
-            "s2_max_loop1:                    @main loop\n"
-            "vld1.f32 {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
-            "vpmax.f32 d4, d0, d1             @max d4,d0,d1\n"
-            "vpmax.f32 d5, d2, d3             @max d5,d2,d3\n"
-            "vst1.f32 {d4-d5}, [%[dr_out]]!   @vst1 q2,dr_out\n"
-            "subs %[cnt_num], #1              @cnt_num--\n"
-            "bne s2_max_loop1                 @bne cnt_num\n"
-            : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num)
+            P2x2S2_INIT P2x2S2P0_MAX
+            : [dr0] "+r"(dr0),
+              [dr1] "+r"(dr1),
+              [dr_out] "+r"(dr_out),
+              [cnt_num] "+r"(cnt_num)
             :
-            : "cc", "memory", "q0", "q1", "q2");
-      }
-      w = w_unroll_size;
-#endif  // __aarch64__
-      for (; w < w_even; w += 2) {
-        dout_ch[w >> 1] = std::max(r0[w], r0[w + 1]);
+            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
+#endif
+          dr0 -= 8;
+          dr1 -= 8;
        }
-      for (; w < w_limit; ++w) {  // run 0 or 1 time
-        dout_ch[w >> 1] = r0[w];
+        // deal with right pad
+        int rem = win - (w_unroll_size * 4) * S;
+        int wstart = 0;
+        for (int j = 0; j < w_unroll_remian; ++j) {
+          int wend = std::min(wstart + K, rem);
+          float tmp = dr0[wstart];
+          for (int i = wstart; i < wend; i++) {
+            tmp = std::max(tmp, dr0[i]);
+            tmp = std::max(tmp, dr1[i]);
+          }
+          *(dr_out++) = tmp;
+          wstart += S;
       }
+        r0 = r1 + win;
+        r1 = r0 + win;
+        data_out_channel += wout;
+      }
     }
   }
 }
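The scalar right-pad tails above (and in pooling2x2s2_avg just below) cover the columns the 4-wide asm loop did not reach, clipping the last window to the valid input. In the average case the divisor must follow the clipped window: two rows times (wend - wstart) columns, hence coef = 0.5f / (wend - wstart) rather than a fixed 1/4. A standalone sketch of that tail (the function name is illustrative):

// Sketch: right-edge tail of 2x2s2 average pooling over rows r0/r1, with
// the window clipped to columns [wstart, wend). If only one column remains,
// the divisor is 2 * 1 (coef = 0.5f / 1), not the full kernel size 4.
static float pool2x2s2_avg_tail_sketch(const float* r0, const float* r1,
                                       int wstart, int wend) {
  float sum = 0.f;
  for (int i = wstart; i < wend; ++i) {
    sum += r0[i] + r1[i];
  }
  return sum * (0.5f / (wend - wstart));  // 2 rows * (wend - wstart) cols
}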
@@ -449,145 +985,76 @@ void pooling2x2s2_avg(const float* din,
                       int hin,
                       int win,
                       bool exclusive) {
-  int kernel = 2;
-  int stride = 2;
-  int padding = 0;
   int size_channel_out = wout * hout;
   int size_channel_in = win * hin;
+  auto data_out = static_cast<float*>(dout);
+  auto data_in = static_cast<const float*>(din);
+
+  const int K = 2;
+  const int P = 0;
+  const int S = 2;
+
+  int w_unroll_size = wout / 4;
+  int w_unroll_remian = wout - w_unroll_size * 4;
+  float32x4_t vcoef = vdupq_n_f32(0.25f);  // divided by 4
 
-  int w_needed = (wout << 1);
-  int h_needed = (hout << 1);
-  int w_limit = w_needed > win ? win : w_needed;
-  int h_limit = h_needed > hin ? hin : h_needed;
-  int w_even = (w_limit >> 1) << 1;
-  int h_even = (h_limit >> 1) << 1;
-  int w_unroll_size = (w_even >> 3) << 3;
-  // int w_unroll_remain = w_even - w_unroll_size;
-  int w_in_2 = win << 1;
-  const float coef = 1.f / 4.f;
-  const float coef_1 = exclusive ? 1.f : coef;
-  const float coef_2 = exclusive ? 1.f / 2.f : coef;
-  float32x4_t vcoef = vdupq_n_f32(coef);
-  float32x4_t vcoef_1 = vdupq_n_f32(coef_1);
-  float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
   for (int n = 0; n < num; ++n) {
-    float* dout_batch = dout + n * chout * size_channel_out;
-    const float* din_batch = din + n * chin * size_channel_in;
+    float* data_out_batch = data_out + n * chout * size_channel_out;
+    const float* data_in_batch = data_in + n * chin * size_channel_in;
 #pragma omp parallel for
     for (int c = 0; c < chout; c++) {
-      float* dout_ch = dout_batch + c * size_channel_out;
-      const float* din_ch = din_batch + c * size_channel_in;
-      const float* r0 = din_ch;
+      float* data_out_channel = data_out_batch + c * size_channel_out;
+      const float* data_in_channel = data_in_batch + c * size_channel_in;
+      const float* r0 = data_in_channel;
       const float* r1 = r0 + win;
-      int h = 0;
-      for (; h < h_even; h += 2) {
-        int w = 0;
-#ifdef __aarch64__
-        for (; w < w_unroll_size; w += 8) {
-          float32x4_t dr00 = vld1q_f32(&r0[w]);
-          float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
-          float32x4_t dr10 = vld1q_f32(&r1[w]);
-          float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
-          float32x4_t dsum1 = vaddq_f32(dr00, dr10);
-          float32x4_t dsum2 = vaddq_f32(dr01, dr11);
-#ifdef __aarch64__
-          float32x4_t dsum = vpaddq_f32(dsum1, dsum2);
-#else
-          float32x2_t dsuml =
-              vpadd_f32(vget_low_f32(dsum1), vget_high_f32(dsum1));
-          float32x2_t dsumh =
-              vpadd_f32(vget_low_f32(dsum2), vget_high_f32(dsum2));
-          float32x4_t dsum = vcombine_f32(dsuml, dsumh);
-#endif
-          float32x4_t res = vmulq_f32(dsum, vcoef);
-          vst1q_f32(&dout_ch[w >> 1], res);
+      for (int h = 0; h < hout; h++) {
+        float* dr_out = data_out_channel;
+        auto dr0 = r0;
+        auto dr1 = r1;
+        if (h * S + K - P > hin) {
+          dr1 = r0;
         }
-#else
-        float* dr_out = dout_ch;
-        const float* dr0 = r0;
-        const float* dr1 = r1;
-        int cnt_num = w_unroll_size >> 3;
-        if (cnt_num > 0) {
+        int cnt_num = w_unroll_size;
+        if (w_unroll_size > 0) {
+#ifdef __aarch64__
          asm volatile(
-              "1:                               @main loop\n"
-              "vld1.f32 {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
-              "vld1.f32 {d4-d7}, [%[dr1]]!      @load q1,dr1\n"
-              "vadd.f32 q0, q0, q2              @add q0,q0,q2\n"
-              "vadd.f32 q1, q1, q3              @add q1,q1,q2\n"
-              "vpadd.f32 d4, d0, d1             @add d4,d0,d1\n"
-              "vpadd.f32 d5, d2, d3             @add d5,d2,d3\n"
-              "vmul.f32 q2, q2, %q[vcoef]       @mul q2,q2,vcoef\n"
-              "vst1.f32 {d4-d5}, [%[dr_out]]!
@vst1 q2,dr_out\n"
-              "subs %[cnt_num], #1              @cnt_num--\n"
-              "bne 1b                           @bne cnt_num\n"
+              P2x2S2_INIT P2x2S2P0_AVG
              : [dr0] "+r"(dr0),
                [dr1] "+r"(dr1),
                [dr_out] "+r"(dr_out),
-                [vcoef] "+w"(vcoef),
                [cnt_num] "+r"(cnt_num)
-              : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "w"(vcoef)
-              : "cc", "memory", "q0", "q1", "q2", "q3");
-        }
-        w = w_unroll_size;
-#endif  // __aarch64__
-        for (; w < w_even; w += 2) {
-          dout_ch[w >> 1] = (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) * coef;
-        }
-        for (; w < w_limit; ++w) {  // run 0 or 1 time
-          dout_ch[w >> 1] = (r0[w] + r1[w]) * coef_2;
-        }
-        r0 += w_in_2;  // << 1;
-        r1 += w_in_2;  // << 1;
-        dout_ch += wout;
-      }
-      // process remain row (odd, last row)
-      for (; h < h_limit; h++) {  // run 0 or 1 time
-        int w = 0;
-#ifdef __aarch64__
-        for (; w < w_unroll_size; w += 8) {
-          float32x4_t dr00 = vld1q_f32(&r0[w]);
-          float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
-#ifdef __aarch64__
-          float32x4_t dsum = vpaddq_f32(dr00, dr01);
+              : [vcoef] "w"(vcoef)
+              : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6");
 #else
-          float32x2_t dsuml =
-              vpadd_f32(vget_low_f32(dr00), vget_high_f32(dr00));
-          float32x2_t dsumh =
-              vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01));
-          float32x4_t dsum = vcombine_f32(dsuml, dsumh);
-#endif
-          float32x4_t res = vmulq_f32(dsum, vcoef_2);
-          vst1q_f32(&dout_ch[w >> 1], res);
-        }
-#else
-        float* dr_out = dout_ch;
-        const float* dr0 = r0;
-        int cnt_num = w_unroll_size >> 3;
-        if (cnt_num > 0) {
          asm volatile(
-              "1:                               @main loop\n"
-              "vld1.f32 {d0-d3}, [%[dr0]]!      @load q0,dr0\n"
-              "vpadd.f32 d4, d0, d1             @add d4,d0,d1\n"
-              "vpadd.f32 d5, d2, d3             @add d5,d2,d3\n"
-              "vmul.f32 q2, q2, %q[vcoef_2]     @mul q2,q2,vcoef_2\n"
-              "vst1.f32 {d4-d5}, [%[dr_out]]!   @vst1 q2,dr_out\n"
-              "subs %[cnt_num], #1              @cnt_num--\n"
-              "bne 1b                           @bne cnt_num\n"
+              P2x2S2_INIT P2x2S2P0_AVG
              : [dr0] "+r"(dr0),
+                [dr1] "+r"(dr1),
                [dr_out] "+r"(dr_out),
-                [vcoef_2] "+w"(vcoef_2),
                [cnt_num] "+r"(cnt_num)
-              : "r"(dr0), "r"(dr_out), "r"(cnt_num), "w"(vcoef_2)
-              : "cc", "memory", "q0", "q1", "q2");
-        }
-        w = w_unroll_size;
-#endif  // __aarch64__
-        for (; w < w_even; w += 2) {
-          dout_ch[w >> 1] = (r0[w] + r0[w + 1]) * coef_2;
+              : [vcoef] "w"(vcoef)
+              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
+#endif
+          dr0 -= 8;
+          dr1 -= 8;
        }
-        for (; w < w_limit; ++w) {  // run 0 or 1 time
-          dout_ch[w >> 1] = r0[w] * coef_1;
+        // deal with right pad
+        int rem = win - (w_unroll_size * 4) * S;
+        int wstart = 0;
+        for (int j = 0; j < w_unroll_remian; ++j) {
+          int wend = std::min(wstart + K, rem);
+          float coef = 0.5f / (wend - wstart);
+          float tmp = 0.f;
+          for (int i = wstart; i < wend; i++) {
+            tmp += dr0[i] + dr1[i];
+          }
+          *(dr_out++) = tmp * coef;
+          wstart += S;
        }
+
+        r0 = r1 + win;
+        r1 = r0 + win;
+        data_out_channel += wout;
      }
    }
  }
}
@@ -602,200 +1069,96 @@ void pooling3x3s1p1_max(const float* din,
                         int chin,
                         int hin,
                         int win) {
-  int kernel = 3;
-  int stride = 1;
-  int padding = 1;
   int size_channel_out = wout * hout;
   int size_channel_in = win * hin;
+  auto data_out = static_cast<float*>(dout);
+  auto data_in = static_cast<const float*>(din);
+
+  const int K = 3;
+  const int P = 1;
+  const int S = 1;
+  const int WUNROLL = 4;
+
+  int w_unroll_size = wout / WUNROLL;
+  int w_unroll_remian = wout - w_unroll_size * WUNROLL;
+  if (w_unroll_remian == 0) {
+    w_unroll_size -= 1;
+    w_unroll_remian = wout - w_unroll_size * WUNROLL;
+  }
+
+  float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest());
 
-  int w_unroll_size = ((win - 2) >> 2) << 2;
-  int w_unroll_remain = win - 2 - w_unroll_size;
-  const float minval =
std::numeric_limits::lowest(); for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; const float* r2 = r1 + win; - int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4 - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - const float* dr2 = r2; - int w = 0; - int cnt = 1; - // left - dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); -// first row with zero pad -#ifdef __aarch64__ - for (; w < w_unroll_size; w += 4) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_34_56 = - vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); - float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); - vst1q_f32(&dout_ch[cnt], vmax); - cnt += 4; - } - -#else - dr_out = dr_out + 1; - if (cnt_num > 0) { - asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" - "vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n" - "vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q5, q6, #1 @vext max_2345\n" - "vext.f32 q2, q5, q6, #2 @vext max_3456\n" - "vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n" - "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" - "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" - "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" - "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" - "sub %[dr0], #8 @sub w,8\n" - "sub %[dr1], #8 @sub w,8\n" - // swap - "vmov.f32 s0, s17 @mov\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s0 @mov\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" - "bne 1b @bne s1_max_loop\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); - } - -#endif - // remain - w = w_unroll_size; - for (int j = 0; j < w_unroll_remain; j++) { - float tmp_max = std::max(r0[j + w], r1[j + w]); - tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); - tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); - dout_ch[j + w + 1] = tmp_max; - } - // right - float tmp = std::max(r0[win - 2], r1[win - 2]); - tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); - dout_ch[wout - 1] = tmp; - - // r0 = r1; - // r1 = r0 + w_in; - // r2 = r1 + w_in; - dout_ch += wout; - int h = 0; - for (; h < hin - 2; h += 1) { - // deal with left pad - float maxr0 = std::max(r0[0], r0[1]); - float maxr1 = std::max(r1[0], r1[1]); - float maxr2 = std::max(r2[0], r2[1]); - dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); -#ifdef __aarch64__ - w = 0; - cnt = 1; - for (; w < w_unroll_size; w += 4) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr2_1234 = vld1q_f32(&r2[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); - - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_34_56 = - vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); - float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); - vst1q_f32(&dout_ch[cnt], vmax); - cnt += 4; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h == 0) { + dr0 = r0; + dr1 = r0; + dr2 = r1; + } else { + r0 = r1; + r1 = r2; + r2 = r1 + win; } + if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = dr0; + case 1: + dr2 = dr0; + default: + break; + } + } + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + /* preocess left */ + P3x3S1_INIT P3x3S1P1_MAX P3x3S1P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vmin] "w"(vmin) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v31"); #else - dr_out = dout_ch + 1; - dr0 = r0; - dr1 = r1; - dr2 = r2; - cnt_num = w_unroll_size >> 2; - if (cnt_num > 0) { asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d8-d9}, [%[dr2]]! 
@load d4-d7,dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" - "vmax.f32 q7, q0, q2 @max r0_1234,r1_1234\n" - "vmax.f32 d16, d2, d6 @max r0_5678,r1_5678\n" - "vmax.f32 q3, q7, q4 @max r0_1234,r1_1234\n" - "vmax.f32 d12, d16, d10 @max r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q3, q6, #1 @vext max_2345\n" - "vext.f32 q2, q3, q6, #2 @vext max_3456\n" - "vpmax.f32 d2, d6, d7 @pmax d4,max_1234,max_1234\n" - "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" - "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" - "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" - "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" - "sub %[dr0], #8 @sub w,8\n" - "sub %[dr1], #8 @sub w,8\n" - "sub %[dr2], #8 @sub w,8\n" - // swap - "vmov.f32 s0, s17 @mov\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s0 @mov\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "bne 1b @bne s1_max_loop\n" + /* preocess left */ + P3x3S1_INIT P3x3S1P1_MAX P3x3S1P0_MAX "2: \n" /* end */ : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : [vmin] "w"(vmin) : "cc", "memory", "q0", @@ -806,114 +1169,35 @@ void pooling3x3s1p1_max(const float* din, "q5", "q6", "q7", - "q8"); - } + "q8", + "q9", + "q10", + "q11", + "q15"); #endif - // remain - w = w_unroll_size; - for (int j = 0; j < w_unroll_remain; j++) { - float tmp_max = std::max(r0[j + w], r1[j + w]); - tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); - tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); - tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1])); - tmp_max = std::max(tmp_max, r2[j + w + 2]); - dout_ch[j + w + 1] = tmp_max; + dr0 -= 4; + dr1 -= 4; + dr2 -= 4; } - // right - tmp = std::max(r0[win - 2], r1[win - 2]); - tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); - tmp = std::max(tmp, std::max(r2[win - 2], r2[win - 1])); - dout_ch[wout - 1] = tmp; - - r0 = r1; - r1 = r2; - r2 = r1 + win; - dout_ch += wout; - } - - // the last two line - float maxr0 = std::max(r0[0], r0[1]); - float maxr1 = std::max(r1[0], r1[1]); - dout_ch[0] = std::max(maxr0, maxr1); -#ifdef __aarch64__ - w = 0; - cnt = 1; - for (; w < w_unroll_size; w += 4) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_34_56 = - vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); - float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); - vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 
3);
- vst1q_f32(&dout_ch[cnt], vmax);
- cnt += 4;
- }
-#else
- dr_out = dout_ch + 1;
- dr0 = r0;
- dr1 = r1;
- cnt_num = w_unroll_size >> 2;
- if (cnt_num > 0) {
- asm volatile(
- "1: @main loop\n"
- "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
- "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
- "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
- "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
- "vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n"
- "vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n"
- //"vmov.f32 s7,s6 @mov s7,s6\n"
- "vext.f32 q0, q5, q6, #1 @vext max_2345\n"
- "vext.f32 q2, q5, q6, #2 @vext max_3456\n"
- "vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n"
- "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n"
- "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n"
- "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n"
- "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n"
- "sub %[dr0], #8 @sub w,8\n"
- "sub %[dr1], #8 @sub w,8\n"
- // swap
- "vmov.f32 s0, s17 @mov\n"
- "vmov.f32 s17, s18 @mov\n"
- "vmov.f32 s18, s0 @mov\n"
- "subs %[cnt_num], #1 @subs cnt_num,#1\n"
- "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
- "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
- "bne 1b @bne s1_max_loop\n"
- : [dr0] "+r"(dr0),
- [dr1] "+r"(dr1),
- [dr_out] "+r"(dr_out),
- [cnt_num] "+r"(cnt_num)
- : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
- : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
- }
-#endif
- // remian
- w = w_unroll_size;
- for (int j = 0; j < w_unroll_remain; j++) {
- float tmp_max = std::max(r0[j + w], r1[j + w]);
- tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
- tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
- dout_ch[j + w + 1] = tmp_max;
+ // deal with right pad
+ int wstart = w_unroll_size * 4 * S - P;
+ for (int j = 0; j < w_unroll_remian; ++j) {
+ int wend = std::min(wstart + K, win);
+ int st = wstart > 0 ? wstart : 0;
+ float tmp = dr0[0];
+ for (int i = 0; i < wend - st; i++) {
+ tmp = std::max(tmp, dr0[i]);
+ tmp = std::max(tmp, dr1[i]);
+ tmp = std::max(tmp, dr2[i]);
+ }
+ *(dr_out++) = tmp;
+ dr0 += S - (st - wstart);
+ dr1 += S - (st - wstart);
+ dr2 += S - (st - wstart);
+ wstart += S;
+ }
+ data_out_channel += wout;
}
- tmp = std::max(r0[win - 2], r1[win - 2]);
- tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
- dout_ch[wout - 1] = tmp;
}
}
}
@@ -928,288 +1212,176 @@ void pooling3x3s1p1_avg(const float* din,
 int hin,
 int win,
 bool exclusive) {
- int kernel = 3;
- int stride = 1;
- int padding = 1;
 int size_channel_out = wout * hout;
 int size_channel_in = win * hin;
+ auto data_out = static_cast<float*>(dout);
+ auto data_in = static_cast<const float*>(din);
+
+ const int K = 3;
+ const int P = 1;
+ const int S = 1;
+ const int WUNROLL = 4;
+
+ int w_unroll_size = wout / WUNROLL;
+ int w_unroll_remian = wout - w_unroll_size * WUNROLL;
+ if (w_unroll_remian == 0) {
+ w_unroll_size -= 1;
+ w_unroll_remian = wout - w_unroll_size * WUNROLL;
+ }
+
+ auto zero_ptr =
+ static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
+ memset(zero_ptr, 0, win * sizeof(float));
- int w_unroll_size = ((win - 2) >> 2) << 2;
- int w_unroll_remain = win - 2 - w_unroll_size;
- const float coef = 1.f / 9.f;
- const float coef_2 = exclusive ? 1.f / 2.f : coef;
- const float coef_4 = exclusive ? 1.f / 4.f : coef;
- const float coef_6 = exclusive ? 
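/* [editor's note -- illustrative sketch, not part of the patch] The
   coef_* constants here encode the two flavours of average pooling:
   with exclusive == true each window is divided by the number of input
   elements it actually covers (1/4 in a corner, 1/6 on an edge), while
   with exclusive == false the divisor stays at kernel_h * kernel_w = 9
   even where zero padding fills part of the window. A minimal scalar
   helper (hypothetical name, assumes the window bounds are already
   clipped to the input):

     float avg_coef(bool exclusive, int hs, int he, int ws, int we) {
       if (exclusive) {
         return 1.f / ((he - hs) * (we - ws));  // clipped window area
       }
       return 1.f / 9.f;  // fixed 3x3 divisor; padded zeros still count
     }
*/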
1.f / 6.f : coef; - float32x4_t vcoef = vdupq_n_f32(coef); - float32x4_t vcoef_2 = vdupq_n_f32(coef_2); - float32x4_t vcoef_4 = vdupq_n_f32(coef_4); - float32x4_t vcoef_6 = vdupq_n_f32(coef_6); for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; const float* r2 = r1 + win; - int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4 - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - const float* dr2 = r2; - int w = 0; - int cnt = 1; - // left - dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; -// first row with zero pad -#ifdef __aarch64__ - for (; w < w_unroll_size; w += 4) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); - vsum = vaddq_f32(vsum, vsum_3456); - vsum = vmulq_f32(vsum, vcoef_6); - vst1q_f32(&dout_ch[cnt], vsum); - cnt += 4; - } -#else - dr_out = dr_out + 1; - if (cnt_num > 0) { - asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" - "vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n" - "vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q5, q6, #1 @vext max_2345\n" - "vext.f32 q2, q5, q6, #2 @vext max_3456\n" - "vadd.f32 q1, q5, q0 @add 1234+2345\n" - "vadd.f32 q1, q1, q2 @add + 3456\n" - "vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n" - "sub %[dr0], #8 @sub w,8\n" - "sub %[dr1], #8 @sub w,8\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" - "bne 1b @bne s1_max_loop\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [vcoef_6] "+w"(vcoef_6) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); - } - -#endif - // remain - w = w_unroll_size; - for (int j = 0; j < w_unroll_remain; j++) { - float tmp_sum = r0[j + w] + r1[j + w]; - tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); - tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); - dout_ch[j + w + 1] = tmp_sum * coef_6; - } - // right - float tmp = r0[win - 2] + r1[win - 2]; - tmp += (r0[win - 1] + r1[win - 1]); - dout_ch[wout - 1] = tmp * coef_4; - - // r0 = r1; - // r1 = r0 + w_in; - // r2 = r1 + w_in; - dout_ch += wout; - int h = 0; - for (; h < hin - 2; h += 1) { - // deal with left pad - float maxr0 = r0[0] + r0[1]; - float maxr1 = r1[0] + r1[1]; - float maxr2 = r2[0] + r2[1]; - dout_ch[0] = (maxr0 + maxr1 + maxr2) * coef_6; -#ifdef __aarch64__ - w = 0; - cnt = 1; - for (; w < w_unroll_size; w += 4) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr2_1234 = vld1q_f32(&r2[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); - - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); - vsum = vaddq_f32(vsum, vsum_3456); - vsum = vmulq_f32(vsum, vcoef); - vst1q_f32(&dout_ch[cnt], vsum); - cnt += 4; - } -#else - dr_out = dout_ch + 1; - dr0 = r0; - dr1 = r1; - dr2 = r2; - cnt_num = w_unroll_size >> 2; - if (cnt_num > 0) { - asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7,dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d10}, [%[dr2]]! @load d4-d7,dr1\n" - "vadd.f32 q7, q0, q2 @max r0_1234,r1_1234\n" - "vadd.f32 d16, d2, d6 @max r0_5678,r1_5678\n" - "vadd.f32 q3, q7, q4 @max r0_1234,r1_1234\n" - "vadd.f32 d12, d16, d10 @max r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q3, q6, #1 @vext max_2345\n" - "vext.f32 q2, q3, q6, #2 @vext max_3456\n" - "vadd.f32 q1, q3, q0 @add 1234+2345\n" - "vadd.f32 q1, q1, q2 @add+3456\n" - "vmul.f32 q4, q1, %q[vcoef] @mul*1/9.f\n" - "sub %[dr0], #8 @sub w,8\n" - "sub %[dr1], #8 @sub w,8\n" - "sub %[dr2], #8 @sub w,8\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" - "bne 1b @bne s1_max_loop\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr2] "+r"(dr2), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [vcoef] "+w"(vcoef) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8"); + for (int h = 0; h < hout; h++) { + float coef_h = 1.f / 3; + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h == 0) { + if (exclusive) { + coef_h = 0.5f; + } + dr0 = zero_ptr; + dr1 = r0; + dr2 = r1; + } else { + r0 = r1; + r1 = r2; + r2 = r1 + win; } -#endif - // remain - w = w_unroll_size; - for (int j = 0; j < w_unroll_remain; j++) { - float tmp_sum = r0[j + w] + r1[j + w]; - tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); - tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); - tmp_sum += (r2[j + w + 1] + r2[j + w + 2]); - tmp_sum += r2[j + w]; - dout_ch[j + w + 1] = tmp_sum * coef; + if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = zero_ptr; + dr2 = zero_ptr; + if (exclusive) { + coef_h = 1.f; + } else { + coef_h = 0.5f; + } + break; + case 1: + dr2 = zero_ptr; + if (exclusive) { + if (fabsf(coef_h - 0.5f) < 1e-6f) { + coef_h = 1.f; + } else { + coef_h = 0.5f; + } + } else { + coef_h = 1.f / 3; + } + default: + break; + } } - // right - tmp = r0[win - 2] + r1[win - 2]; - tmp += (r0[win - 1] + r1[win - 1]); - tmp += (r2[win - 2] + r2[win - 1]); - dout_ch[wout - 1] = tmp * coef_6; - - r0 = r1; - r1 = r2; - r2 = r1 + win; - dout_ch += wout; - } - - // last line - float maxr0 = (r0[0] + r0[1]); - float maxr1 = (r1[0] + r1[1]); - dout_ch[0] = (maxr0 + maxr1) * coef_4; + float32x4_t vcoef = vdupq_n_f32(coef_h / 3); + float coef_left_most = exclusive ? coef_h / 2 : coef_h / 3; + float coef_left[4] = { + coef_left_most, coef_h / 3, coef_h / 3, coef_h / 3}; + float32x4_t vcoef_left = vld1q_f32(coef_left); + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { #ifdef __aarch64__ - w = 0; - cnt = 1; - for (; w < w_unroll_size; w += 4) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); - vsum = vaddq_f32(vsum, vsum_3456); - vsum = vmulq_f32(vsum, vcoef_6); - vst1q_f32(&dout_ch[cnt], vsum); - cnt += 4; - } + asm volatile("movi v31.4s, #0\n" + /* preocess left */ + P3x3S1_INIT P3x3S1P1_AVG P3x3S1P0_AVG "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), [vcoef_left] "w"(vcoef_left) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v31"); #else - dr_out = dout_ch + 1; - dr0 = r0; - dr1 = r1; - cnt_num = w_unroll_size >> 2; - if (cnt_num > 0) { - asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6}, [%[dr1]]! 
@load d4-d7,dr1\n"
- "vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n"
- "vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n"
- //"vmov.f32 s7,s6 @mov s7,s6\n"
- "vext.f32 q0, q5, q6, #1 @vext max_2345\n"
- "vext.f32 q2, q5, q6, #2 @vext max_3456\n"
- "vadd.f32 q1, q5, q0 @add 1234+2345\n"
- "vadd.f32 q1, q1, q2 @add + 3456\n"
- "vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n"
- "sub %[dr0], #8 @sub w,8\n"
- "sub %[dr1], #8 @sub w,8\n"
- "subs %[cnt_num], #1 @subs cnt_num,#1\n"
- "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
- "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
- "bne 1b @bne s1_max_loop\n"
- : [dr0] "+r"(dr0),
- [dr1] "+r"(dr1),
- [dr_out] "+r"(dr_out),
- [cnt_num] "+r"(cnt_num),
- [vcoef_6] "+w"(vcoef_6)
- : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
- : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
- }
+ asm volatile("vmov.i32 q15, #0\n"
+ /* process left */
+ P3x3S1_INIT P3x3S1P1_AVG P3x3S1P0_AVG "2: \n" /* end */
+ : [dr0] "+r"(dr0),
+ [dr1] "+r"(dr1),
+ [dr2] "+r"(dr2),
+ [dr_out] "+r"(dr_out),
+ [cnt_num] "+r"(cnt_num)
+ : [vcoef] "w"(vcoef), [vcoef_left] "w"(vcoef_left)
+ : "cc",
+ "memory",
+ "q0",
+ "q1",
+ "q2",
+ "q3",
+ "q4",
+ "q5",
+ "q6",
+ "q7",
+ "q8",
+ "q9",
+ "q10",
+ "q11",
+ "q15");
#endif
- // remain
- w = w_unroll_size;
- for (int j = 0; j < w_unroll_remain; j++) {
- float tmp_sum = r0[j + w] + r1[j + w];
- tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
- tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
- dout_ch[j + w + 1] = tmp_sum * coef_6;
+ dr0 -= 4;
+ dr1 -= 4;
+ dr2 -= 4;
+ }
+ // deal with right pad
+ int wstart = w_unroll_size * 4 * S - P;
+ for (int j = 0; j < w_unroll_remian; ++j) {
+ int wend = wstart + K; // std::min(wstart + K, win);
+ float coef = coef_h / 3.f;
+ int st = wstart > 0 ? wstart : 0;
+ if (wstart + K > win) {
+ wend = win;
+ if (!exclusive && wstart + K - win == 2) {
+ coef = coef_h / 2;
+ }
+ }
+ if (exclusive) {
+ coef = coef_h / (wend - st);
+ }
+ float tmp = 0.f;
+ for (int i = 0; i < wend - st; i++) {
+ tmp += dr0[i] + dr1[i] + dr2[i];
+ }
+ *(dr_out++) = tmp * coef;
+ dr0 += S - (st - wstart);
+ dr1 += S - (st - wstart);
+ dr2 += S - (st - wstart);
+ wstart += S;
+ }
+ data_out_channel += wout;
}
- // right
- tmp = r0[win - 2] + r1[win - 2];
- tmp += (r0[win - 1] + r1[win - 1]);
- dout_ch[wout - 1] = tmp * coef_4;
}
}
+ TargetFree(TARGET(kARM), zero_ptr);
}
-void pooling3x3s2p1_max(const float* din,
+void pooling3x3s1p0_max(const float* din,
 float* dout,
 int num,
 int chout,
@@ -1218,300 +1390,87 @@ void pooling3x3s2p1_max(const float* din,
 int chin,
 int hin,
 int win) {
- int kernel = 3;
- int stride = 2;
- int padding = 1;
 int size_channel_out = wout * hout;
 int size_channel_in = win * hin;
+ auto data_out = static_cast<float*>(dout);
+ auto data_in = static_cast<const float*>(din);
+
+ const int K = 3;
+ const int P = 0;
+ const int S = 1;
+ const int WUNROLL = 4;
+
+ int w_unroll_size = wout / WUNROLL;
+ int w_unroll_remian = wout - w_unroll_size * WUNROLL;
+ if (w_unroll_remian == 0) {
+ w_unroll_size -= 1;
+ w_unroll_remian = wout - w_unroll_size * WUNROLL;
+ }
+
+ float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest());
- int w_needed = (wout << 1) + 1;
- int h_needed = (hout << 1) + 1;
- int w_limit = w_needed > win ? win : w_needed;
- int h_limit = h_needed > hin ? 
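/* [editor's note -- illustrative sketch, not part of the patch] The
   w_needed/h_needed values in the code being removed are the input
   extents that a 3x3, stride-2, pad-1 output grid reaches into: the
   last of wout windows starts at column (wout - 1) * 2 - 1 and ends
   three columns later, so w_needed = 2 * wout + 1, one pad column
   included. The forward relation the new kernels rely on is the usual

     wout = (win + 2 * pad - kernel) / stride + 1

   e.g. win = 8, pad = 1, kernel = 3, stride = 2 gives wout = 4. */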
hin : h_needed; - int w_even = (w_limit >> 1) << 1; - int h_even = (h_limit >> 1) << 1; - int w_unroll_size = ((w_even - 1) >> 3) << 3; - int w_unroll_remain = w_even - 1 - w_unroll_size; - int w_remain = w_needed - w_limit - padding; - int h_remain = h_needed - h_limit - padding; - int w_in_2 = win << 1; - float minval = std::numeric_limits::lowest(); for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; const float* r2 = r1 + win; - int cnt_num = w_unroll_size >> 3; - int cnt_num_remain = w_unroll_remain >> 1; - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - const float* dr2 = r2; - int w = 1; - int cnt = 1; - dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); -// first row with zero pad -#if __aarch64__ - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_56_78 = - vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); - float32x2_t vmax_67_89 = - vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&dout_ch[cnt], vmax_123_345); - vst1_f32(&dout_ch[cnt + 2], vmax_567_789); - cnt += 4; - } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - vmax2 = vpmax_f32(vmax2, vmax2); - dout_ch[cnt] = vget_lane_f32(vmax2, 0); - cnt++; - } -#else - dr0 = dr0 + 1; - dr1 = dr1 + 1; - dr_out = dr_out + 1; - // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << - // cnt_num_remain; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! 
@load d4-d7,dr1\n" - "vmax.f32 q6, q0, q3 @max r0_1234,r1_1234\n" - "vmax.f32 q7, q1, q4 @max r0_5678,r1_5678\n" - "vmax.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q7, q8, #1 @vext max_6789\n" - "vpmax.f32 d4, d12, d13 @pmax d4,vmax_1234,vmax_1234\n" - "vpmax.f32 d6, d14, d15 @pmax d6,vmax_5678,vmax_5678\n" - "vpmax.f32 d5, d0, d1 @pmax d5,vmax_2345,vmax_2345\n" - "vpmax.f32 d7, d2, d3 @pmax d7,vmax_6789,vmax_6789\n" - "vmax.f32 d8, d4, d5 @max d2,vmax_12_34,vmax_23_45\n" - "vmax.f32 d9, d6, d7 @max d2,vmax_56_78,vmax_67_89\n" - "sub %[dr0], #16 @add w,8\n" - "sub %[dr1], #16 @add w, 8\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "subs %[cnt_num], #1 @subs cnt_num, #1\n" - "bne 1b @bne s3_max_loop\n" - "3: @loop \n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" - "ble 4f @ble exit\n" - "2: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vmov.f32 s3,s2 @movs3,s2\n" - "vmov.f32 s7,s6 @movs7,s6\n" - "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" - "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" - "bne 2b @bne s3_max_loop_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9"); - } -#endif - // int w = w_even - 1; - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp = r0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp = std::max(tmp, std::max(r0[i], r1[i])); - } - dout_ch[w_even >> 1] = tmp; - // cnt ++; - } - - r0 = r1; - r1 = r0 + win; - r2 = r1 + win; - dout_ch += wout; - int h = 2; - for (; h < h_even; h += 2) { - // deal with left pad - float maxr0 = std::max(r0[0], r0[1]); - float maxr1 = std::max(r1[0], r1[1]); - float maxr2 = std::max(r2[0], r2[1]); - dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); -#if __aarch64__ - w = 1; - cnt = 1; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vr2_1234 = vld1q_f32(&r2[w]); - float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); - float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); - float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); - vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - 
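/* [editor's note -- illustrative sketch, not part of the patch] The
   vpmax/vmax ladder in these removed intrinsics folds three
   max-combined rows into stride-2 window maxima: vpmax_f32 takes
   pairwise maxima of adjacent lanes, so
     vpmax(max_1234) -> { max(x1,x2), max(x3,x4) }
     vpmax(max_2345) -> { max(x2,x3), max(x4,x5) }
   and a final vmax of the two yields { max(x1..x3), max(x3..x5) },
   i.e. two adjacent stride-2 outputs. Scalar equivalent for one output
   column x, with the three rows already merged into m[]:

     out[x] = std::max(m[2 * x], std::max(m[2 * x + 1], m[2 * x + 2]));
*/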
float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_56_78 = - vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); - float32x2_t vmax_67_89 = - vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&dout_ch[cnt], vmax_123_345); - vst1_f32(&dout_ch[cnt + 2], vmax_567_789); - cnt += 4; - } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - float32x4_t vr2 = vld1q_f32(&r2[w]); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - vr2 = vsetq_lane_f32(minval, vr2, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - vmax1 = vmaxq_f32(vmax1, vr2); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - float32x2_t vmax = vpmax_f32(vmax2, vmax2); - dout_ch[cnt] = vget_lane_f32(vmax, 0); - cnt++; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = dr0; + case 1: + dr2 = dr0; + default: + break; + } } + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + /* preocess left */ + P3x3S1_INIT P3x3S1P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vmin] "w"(vmin) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v31"); #else - dr_out = dout_ch + 1; - dr0 = (r0 + 1); - dr1 = (r1 + 1); - dr2 = (r2 + 1); - cnt_num = w_unroll_size >> 3; - cnt_num_remain = w_unroll_remain >> 1; - if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n" - "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" - "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" - "vmax.f32 q11, q2, q5 @max q1,q1,q3\n" - "vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n" - "vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n" - "vmax.f32 q1, q11, q8 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q4, q0, q3, #1 @vext 2345\n" - "vext.f32 q2, q3, q1, #1 @vext 6789\n" - "vpmax.f32 d10, d0, d1 @pmax d10,vmax_1234,vmax_1234\n" - "vpmax.f32 d12, d6, d7 @pmax d12,vmax_5678,vmax_5678\n" - "vpmax.f32 d11, d8, d9 @pmax d11,vmax_2345,vmax_2345\n" - "vpmax.f32 d13, d4, d5 @pmax d13,vmax_6789,vmax_6789\n" - "vmax.f32 d0, d10, d11 @pmax d0,vmax_12_34,vmax_23_45\n" - "vmax.f32 d1, d12, d13 @pmax d1,vmax_56_78,vmax_67_89\n" - "sub %[dr0], #16 @add w,8\n" - "sub %[dr1], #16 @add w,8\n" - "sub %[dr2], #16 @add w,8\n" - "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "bne 1b @bne s3_max_loop_mid\n" - "3: @loop \n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" - "ble 4f @ble exit1\n" - "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! 
@load d2-d3,dr1\n" - "vmov.f32 s3,s2 @movs3,s2\n" - "vmov.f32 s7,s6 @movs7,s6\n" - "vmov.f32 s11,s10 @movs11,s10\n" - "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" - "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" - "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0,d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "sub %[dr2], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" - "bne 2b @bne s3_max_loop_mid_1\n" - "4: @exit\n" + /* preocess left */ + P3x3S1P0_INIT P3x3S1P0_MAX : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain) - : "r"(dr0), - "r"(dr1), - "r"(dr2), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) + [cnt_num] "+r"(cnt_num) + : [vmin] "w"(vmin) : "cc", "memory", "q0", @@ -1526,280 +1485,205 @@ void pooling3x3s2p1_max(const float* din, "q9", "q10", "q11", - "q12"); - } + "q15"); #endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp = r0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { - tmp = std::max(tmp, std::max(r0[i], r1[i])); - tmp = std::max(tmp, r2[i]); + dr0 -= 4; + dr1 -= 4; + dr2 -= 4; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? wstart : 0; + float tmp = dr0[0]; + for (int i = 0; i < wend - st; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + tmp = std::max(tmp, dr2[i]); } - dout_ch[w_even >> 1] = tmp; - // cnt ++; + *(dr_out++) = tmp; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + dr2 += S - (st - wstart); + wstart += S; } - r0 = r2; - r1 = r0 + win; + r0 = r1; + r1 = r2; r2 = r1 + win; - dout_ch += wout; + data_out_channel += wout; } + } + } +} - if (h_remain > 0) { - // deal with bottom pad - // first row with zero pad - int hstart = (h >> 1) * stride - padding; - int hend = std::min(std::min(hstart + kernel, hin + padding), hin); - if (hstart == hend - 1) { // only one lline - dout_ch[0] = std::max(r0[0], r0[1]); -#if __aarch64__ - w = 1; - cnt = 1; - for (; w < w_unroll_size; w += 8) { - float32x4_t vmax_1234 = vld1q_f32(&r0[w]); - float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_56_78 = - vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); - float32x2_t vmax_67_89 = - vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&dout_ch[cnt], vmax_123_345); - vst1_f32(&dout_ch[cnt + 2], vmax_567_789); - cnt += 4; - } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - vr0 = vsetq_lane_f32(minval, vr0, 3); - float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0)); - vmax = vpmax_f32(vmax, vmax); - dout_ch[cnt] = vget_lane_f32(vmax, 0); - cnt++; +void pooling3x3s1p0_avg(const float* din, + float* 
dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 3; + const int P = 0; + const int S = 1; + const int WUNROLL = 4; + + int w_unroll_size = wout / WUNROLL; + int w_unroll_remian = wout - w_unroll_size * WUNROLL; + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * WUNROLL; + } + + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + memset(zero_ptr, 0, win * sizeof(float)); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + for (int h = 0; h < hout; h++) { + float coef_h = 1.f / 3; + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = zero_ptr; + dr2 = zero_ptr; + if (exclusive) { + coef_h = 1.f; + } else { + coef_h = 0.5f; + } + break; + case 1: + dr2 = zero_ptr; + if (exclusive) { + if (fabsf(coef_h - 0.5f) < 1e-6f) { + coef_h = 1.f; + } else { + coef_h = 0.5f; + } + } else { + coef_h = 1.f / 3; + } + default: + break; } + } + float32x4_t vcoef = vdupq_n_f32(coef_h / 3); + float coef_left_most = exclusive ? coef_h / 2 : coef_h / 3; + float coef_left[4] = { + coef_left_most, coef_h / 3, coef_h / 3, coef_h / 3}; + float32x4_t vcoef_left = vld1q_f32(coef_left); + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile("movi v31.4s, #0\n" P3x3S1_INIT P3x3S1P0_AVG + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), [vcoef_left] "w"(vcoef_left) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v31"); #else - dr_out = dout_ch + 1; - dr0 = (r0 + 1); - cnt_num = w_unroll_size >> 3; - cnt_num_remain = w_unroll_remain >> 1; - // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << - // cnt_num_remain; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3,dr0\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" - "vext.f32 q4, q0, q1, #1 @vmax_2345\n" - "vext.f32 q5, q1, q2, #1 @vmax_6789\n" - "vpmax.f32 d12, d0, d1 @vmax_12_34\n" - "vpmax.f32 d14, d2, d3 @vmax_56_78\n" - "vpmax.f32 d13, d8, d9 @vmax_23_45\n" - "vpmax.f32 d15, d10, d11 @vmax_67_89\n" - "vmax.f32 d0, d12, d13 @12_34,23_45\n" - "vmax.f32 d1, d14, d15 @56_78,67_89\n" - "sub %[dr0], #16 @add w,6\n" - "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! 
@load d0-d1,dr0\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "sub %[dr0], #8 @add w,2\n" - "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" - "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain) - : "r"(dr0), - "r"(dr1), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8"); - } + asm volatile("vmov.i32 q15, #0\n" P3x3S1P0_INIT P3x3S1P0_AVG + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), [vcoef_left] "w"(vcoef_left) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); #endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp = r0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { - tmp = std::max(tmp, r0[i]); + dr0 -= 4; + dr1 -= 4; + dr2 -= 4; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = wstart + K; // std::min(wstart + K, win); + float coef = coef_h / 3.f; + int st = wstart > 0 ? wstart : 0; + if (wstart + K > win) { + wend = win; + if (!exclusive && wstart + K - win == 2) { + coef = coef_h / 2; } - dout_ch[w_even >> 1] = tmp; - } - } else { // two lines - dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); -#ifdef __aarch64__ - w = 1; - cnt = 1; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_56_78 = - vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); - float32x2_t vmax_67_89 = - vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&dout_ch[cnt], vmax_123_345); - vst1_f32(&dout_ch[cnt + 2], vmax_567_789); - cnt += 4; } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - vmax2 = vpmax_f32(vmax2, vmax2); - dout_ch[cnt] = vget_lane_f32(vmax2, 0); - cnt++; + if (exclusive) { + coef = coef_h / (wend - st); } -#else - dr_out = dout_ch + 1; - dr0 = (r0 + 
1); - dr1 = (r1 + 1); - cnt_num = w_unroll_size >> 3; - cnt_num_remain = w_unroll_remain >> 1; - // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << - // cnt_num_remain; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" - "vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n" - "vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n" - "vmax.f32 q8, q2, q5 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext q0,2345\n" - "vext.f32 q1, q7, q8, #1 @vext q1,6789\n" - "vpmax.f32 d4, d12, d13 @pmax " - "d4,vmax_1234,vmax_1234\n" - "vpmax.f32 d6, d14, d15 @pmax " - "d6,vmax_5678,vmax_5678\n" - "vpmax.f32 d5, d0, d1 @pmax " - "d5,vmax_2345,vmax_2345\n" - "vpmax.f32 d7, d2, d3 @pmax " - "d7,vmax_6789,vmax_6789\n" - "vmax.f32 d8, d4, d5 @max " - "d2,vmax_12_34,vmax_23_45\n" - "vmax.f32 d9, d6, d7 @max " - "d2,vmax_56_78,vmax_67_89\n" - "sub %[dr0], #16 @add w,8\n" - "sub %[dr1], #16 @add w,8\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vmov.f32 s7,s6 @movs7, s6\n" - "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" - "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" - "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" - "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain) - : "r"(dr0), - "r"(dr1), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9"); - } -#endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp = r0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp = std::max(tmp, std::max(r0[i], r1[i])); - } - dout_ch[w_even >> 1] = tmp; + float tmp = 0.f; + for (int i = 0; i < wend - st; i++) { + tmp += dr0[i] + dr1[i] + dr2[i]; } + *(dr_out++) = tmp * coef; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + dr2 += S - (st - wstart); + wstart += S; } + r0 = r1; + r1 = r2; + r2 = r1 + win; + data_out_channel += wout; } } } + TargetFree(TARGET(kARM), zero_ptr); } -void pooling3x3s2p1_avg(const float* din, +void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout, @@ -1807,337 +1691,98 @@ void pooling3x3s2p1_avg(const float* din, int wout, int chin, int hin, - int win, - bool exclusive) { - int kernel = 3; - int stride = 2; - int padding = 1; + int win) { int size_channel_out = wout * hout; int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 3; + const int P = 1; + const int S = 2; + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } + float32x4_t vmin = vdupq_n_f32(std::numeric_limits::lowest()); - int w_needed = (wout << 1) + 1; - int h_needed = (hout << 1) + 1; - int w_limit = w_needed > win ? win : w_needed; - int h_limit = h_needed > hin ? hin : h_needed; - int w_even = (w_limit >> 1) << 1; - int h_even = (h_limit >> 1) << 1; - int w_unroll_size = ((w_even - 1) >> 3) << 3; - int w_unroll_remain = w_even - 1 - w_unroll_size; - int w_remain = w_needed - w_limit - padding; - int h_remain = h_needed - h_limit - padding; - int w_in_2 = win << 1; - const float coef = 1.f / 9.f; - const float coef_1 = exclusive ? 1.f : coef; - const float coef_2 = exclusive ? 1.f / 2.f : coef; - const float coef_3 = exclusive ? 1.f / 3.f : coef; - const float coef_4 = exclusive ? 1.f / 4.f : coef; - const float coef_6 = exclusive ? 
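/* [editor's note -- illustrative sketch, not part of the patch] Every
   rewritten kernel sizes its vector loop the same way (the names below
   are the patch's own):

     int w_unroll_size = wout / 4;   // 4 output columns per NEON pass
     int w_unroll_remian = wout - w_unroll_size * 4;
     if (w_unroll_remian == 0) {     // never leave the scalar tail empty
       w_unroll_size -= 1;
       w_unroll_remian = wout - w_unroll_size * 4;  // == 4 here
     }

   Keeping at least one output in the scalar tail lets the vector body
   run without right-edge checks; the tail re-clips each window against
   win, as in the "deal with right pad" loops. */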
1.f / 6.f : coef; - float32x4_t vcoef = vdupq_n_f32(coef); - float32x4_t vcoef_1 = vdupq_n_f32(coef_1); - float32x4_t vcoef_2 = vdupq_n_f32(coef_2); - float32x4_t vcoef_3 = vdupq_n_f32(coef_3); - float32x4_t vcoef_4 = vdupq_n_f32(coef_4); - float32x4_t vcoef_6 = vdupq_n_f32(coef_6); for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; const float* r2 = r1 + win; - int cnt_num = w_unroll_size >> 3; - int cnt_num_remain = w_unroll_remain >> 1; - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - const float* dr2 = r2; - int w = 1; - int cnt = 1; - float32x4_t vzero = vdupq_n_f32(0.f); - dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; -// first row with zero pad -#ifdef __aarch64__ - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); - - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); - float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); - float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); - vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); - float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); - vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); - vst1q_f32(&dout_ch[cnt], vrst); - cnt += 4; - } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - vr0 = vsetq_lane_f32(0.f, vr0, 3); - vr1 = vsetq_lane_f32(0.f, vr1, 3); - float32x4_t vsum1 = vaddq_f32(vr0, vr1); - float32x2_t vsum2 = - vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); - vsum2 = vpadd_f32(vsum2, vsum2); - float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); - dout_ch[cnt] = vget_lane_f32(vrst, 0); - cnt++; - } -#else - dr0 = dr0 + 1; - dr1 = dr1 + 1; - dr_out = dr_out + 1; - // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << - // cnt_num_remain; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! 
@load d4-d7,dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" - "vadd.f32 q6, q0, q3 @max r0_1234,r1_1234\n" - "vadd.f32 q7, q1, q4 @max r0_5678,r1_5678\n" - "vadd.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, 2345\n" - "vadd.f32 q5, q7, q1 @add 5678, 4567\n" - "vadd.f32 q4, q4, q2 @add 3456, sum1\n" - "vadd.f32 q5, q5, q3 @add 6789, sum2\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s21 @mov\n" - "vmov.f32 s19, s23 @mov\n" - "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" - "sub %[dr0], #16 @add w,8\n" - "sub %[dr1], #16 @add w,8\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "bne 1b @bne s3_max_loop\n" - "3: @loop\n" - "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" - "ble 4f @ble exit\n" - "2: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" - "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0,d0\n" - "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "bne 2b @bne s3_max_loop_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain), - [vcoef_6] "+w"(vcoef_6), - [vzero] "+w"(vzero) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9"); - } -#endif - // int w = w_even - 1; - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp1 = 0.f; // std::numeric_limits::min(); - float tmp2 = exclusive ? 
1.0f / (2.f * (wend - wstart)) : coef; - for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp1 += (r0[i] + r1[i]); + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h == 0) { + dr0 = r0; + dr1 = r0; + dr2 = r1; + r0 = r1; + r1 = r2; + r2 = r1 + win; + } else { + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; } - dout_ch[w_even >> 1] = tmp1 * tmp2; - // cnt ++; - } - - r0 = r1; - r1 = r0 + win; - r2 = r1 + win; - dout_ch += wout; - int h = 2; - for (; h < h_even; h += 2) { - // deal with left pad - float sum0 = r0[0] + r0[1]; - float sum1 = r1[0] + r1[1]; - float sum2 = r2[0] + r2[1]; - dout_ch[0] = (sum0 + sum1 + sum2) * coef_6; -#ifdef __aarch64__ - w = 1; - cnt = 1; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vr2_1234 = vld1q_f32(&r2[w]); - float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); - float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); - vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); - vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); - vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112); - - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); - float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); - float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); - vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); - float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); - vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); - vst1q_f32(&dout_ch[cnt], vrst); - cnt += 4; - } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - float32x4_t vr2 = vld1q_f32(&r2[w]); - vr0 = vsetq_lane_f32(0.f, vr0, 3); - vr1 = vsetq_lane_f32(0.f, vr1, 3); - vr2 = vsetq_lane_f32(0.f, vr2, 3); - float32x4_t vsum1 = vaddq_f32(vr0, vr1); - vsum1 = vaddq_f32(vsum1, vr2); - float32x2_t vsum2 = - vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); - float32x2_t vsum = vpadd_f32(vsum2, vsum2); - dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef; - cnt++; + if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = dr0; + case 1: + dr2 = dr0; + default: + break; + } } + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + /* preocess left */ + P3x3S2_INIT P3x3S2P1_MAX P3x3S2P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vmin] "w"(vmin) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v31"); #else - dr_out = dout_ch + 1; - dr0 = (r0 + 1); - dr1 = (r1 + 
1); - dr2 = (r2 + 1); - cnt_num = w_unroll_size >> 3; - cnt_num_remain = w_unroll_remain >> 1; - if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n" - "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" - "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" - "vadd.f32 q11, q2, q5 @max q1,q1,q3\n" - "vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n" - "vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n" - "vadd.f32 q8, q11, q8 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234,2345\n" - "vadd.f32 q5, q7, q1 @add 5678,4567\n" - "vadd.f32 q4, q4, q2 @add 3456,sum1\n" - "vadd.f32 q5, q5, q3 @add 6789,sum2\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s21 @mov\n" - "vmov.f32 s19, s23 @mov\n" - "vmul.f32 q4, q4, %q[vcoef] @mul\n" - "sub %[dr0], #16 @add w,8\n" - "sub %[dr1], #16 @add w,8\n" - "sub %[dr2], #16 @add w, 8\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "bne 1b @bne s3_max_loop_mid\n" - "3: @loop\n" - "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" - "ble 4f @ble exit1\n" - "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" - "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" - "vadd.f32 q0, q0, q2 @add q0,q0,q1\n" - "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "sub %[dr2], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "bne 2b @bne s3_max_loop_mid_1\n" - "4: @exit\n" + /* preocess left */ + P3x3S2_INIT P3x3S2P1_MAX P3x3S2P0_MAX "2: \n" /* end */ : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain), - [vcoef] "+w"(vcoef), - [vzero] "+w"(vzero) - : "r"(dr0), - "r"(dr1), - "r"(dr2), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) + [cnt_num] "+r"(cnt_num) + : [vmin] "w"(vmin) : "cc", "memory", "q0", @@ -2152,290 +1797,214 @@ void pooling3x3s2p1_avg(const float* din, "q9", "q10", "q11", - "q12"); - } + "q15"); #endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp1 = 0.f; - float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef; - for (int i = wstart; i < wend; i++) { - tmp1 += (r0[i] + r1[i] + r2[i]); + dr0 -= 8; + dr1 -= 8; + dr2 -= 8; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? 
wstart : 0; + float tmp = dr0[0]; + for (int i = 0; i < wend - st; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + tmp = std::max(tmp, dr2[i]); } - dout_ch[w_even >> 1] = tmp1 * tmp2; - // cnt ++; + *(dr_out++) = tmp; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + dr2 += S - (st - wstart); + wstart += S; } - r0 = r2; - r1 = r0 + win; - r2 = r1 + win; - dout_ch += wout; + data_out_channel += wout; } + } + } +} - if (h_remain > 0) { - // deal with bottom pad - // first row with zero pad - int hstart = (h >> 1) * stride - padding; - int hend = std::min(std::min(hstart + kernel, hin + padding), hin); - if (hstart == hend - 1) { // only one line - dout_ch[0] = (r0[0] + r0[1]) * coef_2; -#ifdef __aarch64__ - w = 1; - cnt = 1; - for (; w < w_unroll_size; w += 8) { - float32x4_t vsum_1234 = vld1q_f32(&r0[w]); - float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]); - - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); - float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); - float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); - vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); - float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); - vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); - vsum_123_345 = vsetq_lane_f32( - vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); - vsum_123_345 = vsetq_lane_f32( - vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); - vsum_123_345 = vsetq_lane_f32( - vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_3); - vst1q_f32(&dout_ch[cnt], vrst); - cnt += 4; +void pooling3x3s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 3; + const int P = 1; + const int S = 2; + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } + + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + memset(zero_ptr, 0, win * sizeof(float)); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + for (int h = 0; h < hout; h++) { + float coef_h = 1.f / 3; + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h == 0) { + if (exclusive) { + coef_h = 0.5f; } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - vr0 = vsetq_lane_f32(0.f, vr0, 3); - float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0)); - vsum = vpadd_f32(vsum, vsum); - dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef_3; - cnt++; + dr0 = zero_ptr; + dr1 = r0; + dr2 = r1; + r0 = r1; + r1 = r2; + r2 = r1 + win; + } else { + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + } + 
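/* [editor's note -- illustrative sketch, not part of the patch] The
   check below handles the bottom border without branching inside the
   inner loop: when the 3-row window starting at h * S - P hangs past
   the input, the missing rows are redirected to the prepared all-zero
   row and coef_h is reduced to match the rows that remain.
   Conceptually:

     int rows_past_end = h * S + K - P - hin;  // 1 or 2 for K == 3
     if (rows_past_end >= 2) dr1 = zero_ptr;   // only row 0 is real
     if (rows_past_end >= 1) dr2 = zero_ptr;   // row 2 is below input
*/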
if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = zero_ptr; + dr2 = zero_ptr; + if (exclusive) { + coef_h = 1.f; + } else { + coef_h = 0.5f; + } + break; + case 1: + dr2 = zero_ptr; + if (exclusive) { + if (fabsf(coef_h - 0.5f) < 1e-6f) { + coef_h = 1.f; + } else { + coef_h = 0.5f; + } + } else { + coef_h = 1.f / 3; + } + default: + break; } + } + float32x4_t vcoef = vdupq_n_f32(coef_h / 3); + float coef_left_most = exclusive ? coef_h / 2 : coef_h / 3; + float coef_left[4] = { + coef_left_most, coef_h / 3, coef_h / 3, coef_h / 3}; + float32x4_t vcoef_left = vld1q_f32(coef_left); + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile("movi v31.4s, #0\n" + /* preocess left */ + P3x3S2_INIT P3x3S2P1_AVG P3x3S2P0_AVG "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), [vcoef_left] "w"(vcoef_left) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11", + "v31"); #else - dr_out = dout_ch + 1; - dr0 = (r0 + 1); - cnt_num = w_unroll_size >> 3; - cnt_num_remain = w_unroll_remain >> 1; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d12-d15}, [%[dr0]]! @load d0-d3,dr0\n" - "vld1.f32 {d16-d17}, [%[dr0]]! @load d0-d3,dr0\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234,2345\n" - "vadd.f32 q5, q7, q1 @add 5678,4567\n" - "vadd.f32 q4, q4, q2 @add 3456,sum1\n" - "vadd.f32 q5, q5, q3 @add 6789,sum2\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s21 @mov\n" - "vmov.f32 s19, s23 @mov\n" - "vmul.f32 q4, q4, %q[vcoef_3] @mul\n" - "sub %[dr0], #16 @add w,6\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop\n" - "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" - "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" - "vmul.f32 d0, d0, %e[vcoef_3] @mul\n" - "sub %[dr0], #8 @add w,2\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" - "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain), - [vcoef_3] "+w"(vcoef_3), - [vzero] "+w"(vzero) - : "r"(dr0), - "r"(dr1), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8"); - } + asm volatile("vmov.i32 q15, #0\n" + /* preocess left */ + P3x3S2_INIT P3x3S2P1_AVG P3x3S2P0_AVG "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), [vcoef_left] "w"(vcoef_left) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q15"); #endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp1 = 0.f; - float tmp2 = exclusive ? 1.0f / (1.f * (wend - wstart)) : coef; - for (int i = wstart; i < wend; i++) { - tmp1 += r0[i]; + dr0 -= 8; + dr1 -= 8; + dr2 -= 8; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = wstart + K; // std::min(wstart + K, win); + float coef = coef_h / 3.f; + if (wstart + K > win) { + wend = win; + if (!exclusive && wstart + K - win == 2) { + coef = coef_h / 2; } - dout_ch[w_even >> 1] = tmp1 * tmp2; } - } else { // two lines - dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; -#ifdef __aarch64__ - w = 1; - cnt = 1; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); - float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); - float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); - vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); - float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); - vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); - vsum_123_345 = vsetq_lane_f32( - vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); - vsum_123_345 = vsetq_lane_f32( - vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); - vsum_123_345 = vsetq_lane_f32( - vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); - vst1q_f32(&dout_ch[cnt], vrst); - cnt += 4; + int st = wstart > 0 ? 
wstart : 0; + if (exclusive) { + coef = coef_h / (wend - st); } - for (; w < w_even - 1; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - vr0 = vsetq_lane_f32(0.f, vr0, 3); - vr1 = vsetq_lane_f32(0.f, vr1, 3); - float32x4_t vsum1 = vaddq_f32(vr0, vr1); - float32x2_t vsum2 = - vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); - vsum2 = vpadd_f32(vsum2, vsum2); - float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); - dout_ch[cnt] = vget_lane_f32(vrst, 0); - cnt++; - } -#else - dr_out = dout_ch + 1; - dr0 = (r0 + 1); - dr1 = (r1 + 1); - cnt_num = w_unroll_size >> 3; - cnt_num_remain = w_unroll_remain >> 1; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" - "vadd.f32 q6, q0, q3 @add q0,q0,q2 1234\n" - "vadd.f32 q7, q1, q4 @add q1,q1,q3 5678\n" - "vadd.f32 q8, q2, q5 @add q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234,2345\n" - "vadd.f32 q5, q7, q1 @add 5678,4567\n" - "vadd.f32 q4, q4, q2 @add 3456,sum1\n" - "vadd.f32 q5, q5, q3 @add 6789,sum2\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s21 @mov\n" - "vmov.f32 s19, s23 @mov\n" - "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" - "sub %[dr0], #16 @add w,8\n" - "sub %[dr1], #16 @add w,8\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop\n" - "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" - "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" - "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain), - [vcoef_6] "+w"(vcoef_6), - [vzero] "+w"(vzero) - : "r"(dr0), - "r"(dr1), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9"); - } -#endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp1 = 0.f; - float tmp2 = exclusive ? 
1.0f / (2.f * (wend - wstart)) : coef; - for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp1 += (r0[i] + r1[i]); - } - dout_ch[w_even >> 1] = tmp1 * tmp2; + float tmp = 0.f; + for (int i = 0; i < wend - st; i++) { + tmp += dr0[i] + dr1[i] + dr2[i]; } + *(dr_out++) = tmp * coef; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + dr2 += S - (st - wstart); + wstart += S; } + data_out_channel += wout; } } } + TargetFree(TARGET(kARM), zero_ptr); } void pooling3x3s2p0_max(const float* din, @@ -2447,357 +2016,117 @@ void pooling3x3s2p0_max(const float* din, int chin, int hin, int win) { - int kernel = 3; - int stride = 2; - int padding = 0; + const int K = 3; + const int P = 0; + const int S = 2; + int size_channel_out = wout * hout; int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + if (w_unroll_remian == 0 && w_unroll_size * 4 * S + K > win) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } - int w_needed = (wout << 1) + 1; - int h_needed = (hout << 1) + 1; - int w_limit = w_needed > win ? win : w_needed; - int h_limit = h_needed > hin ? hin : h_needed; - int w_even = ((w_limit - 1) >> 1) << 1; - int h_even = ((h_limit - 1) >> 1) << 1; - int w_unroll_size = (w_even >> 3) << 3; - int w_unroll_remain = w_even - w_unroll_size; - int w_remain = w_needed - w_limit; - int h_remain = h_needed - h_limit; - int w_in_2 = win << 1; - float minval = std::numeric_limits::lowest(); for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; const float* r2 = r1 + win; - // w = w_in - 8; - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - const float* dr2 = r2; - int w = 0; - int cnt = 0; - // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], - // r1[1])); - // first row with zero pad - // r0 = r1; - // r1 = r0 + w_in; - // r2 = r1 + w_in; - // dout_channel += w_out; - int h = 0; - for (; h < h_even; h += 2) { - // deal with left pad - float maxr0 = std::max(r0[0], r0[1]); - float maxr1 = std::max(r1[0], r1[1]); - float maxr2 = std::max(r2[0], r2[1]); -// dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); -#ifdef __aarch64__ - w = 0; - cnt = 0; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vr2_1234 = vld1q_f32(&r2[w]); - float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); - float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - 
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); - float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); - vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t vmax_56_78 = - vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); - float32x2_t vmax_67_89 = - vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&dout_ch[cnt], vmax_123_345); - vst1_f32(&dout_ch[cnt + 2], vmax_567_789); - cnt += 4; - } - for (; w < w_even; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - float32x4_t vr2 = vld1q_f32(&r2[w]); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - vr2 = vsetq_lane_f32(minval, vr2, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - vmax1 = vmaxq_f32(vmax1, vr2); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - float32x2_t vmax = vpmax_f32(vmax2, vmax2); - dout_ch[cnt] = vget_lane_f32(vmax, 0); - cnt++; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + auto dr2 = r2; + if (h * S + K - P > hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = r0; + case 1: + dr2 = r0; + default: + break; + } } + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); #else - dr_out = dout_ch; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - dr2 = r2; // (r2 + 1); - int cnt_num = w_unroll_size >> 3; - int cnt_num_remain = w_unroll_remain >> 1; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d16}, [%[dr2]]! @load d4-d7,dr1\n" - "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" - "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" - "vmax.f32 d22, d4, d10 @max q1,q1,q3\n" - "vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n" - "vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n" - "vmax.f32 d2, d22, d16 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q4, q0, q3, #1 @vext 2345\n" - "vext.f32 q2, q3, q1, #1 @vext 6789\n" - "vpmax.f32 d10, d0, d1 @pmax " - "d10,vmax_1234,vmax_1234\n" - "vpmax.f32 d12, d6, d7 @pmax " - "d12,vmax_5678,vmax_5678\n" - "vpmax.f32 d11, d8, d9 @pmax " - "d11,vmax_2345,vmax_2345\n" - "vpmax.f32 d13, d4, d5 @pmax " - "d13,vmax_6789,vmax_6789\n" - "vmax.f32 d0, d10, d11 @pmax " - "d0,vmax_12_34,vmax_23_45\n" - "vmax.f32 d1, d12, d13 @pmax " - "d1,vmax_56_78,vmax_67_89\n" - "sub %[dr0], #8 @add w,8\n" - "sub %[dr1], #8 @add w,8\n" - "sub %[dr2], #8 @add w,8\n" - "vst1.f32 d0, [%[dr_out]]! 
@vst1 d0,dr_out\n" - "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" - "subs %[cnt_num], #1 @cnt_num--\n" - "bne 1b @bne s3_max_loop_mid\n" - "3: @loop\n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" - "ble 4f @ble exit1\n" - "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" - "vmov.f32 s3,s2 @movs3,s2\n" - "vmov.f32 s7,s6 @movs7,s6\n" - "vmov.f32 s11,s10 @movs11,s10\n" - "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" - "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" - "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "sub %[dr2], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "bne 2b @bne s3_max_loop_mid_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr2] "+r"(dr2), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain) - : "r"(dr0), - "r"(dr1), - "r"(dr2), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12"); - } + asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); #endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp = r0[wstart]; // std::numeric_limits::min(); + dr0 -= 8; + dr1 -= 8; + dr2 -= 8; + } + // deal with right pad + int rem = win - (w_unroll_size * 4) * S; + int wstart = 0; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, rem); + float tmp = dr0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { - tmp = std::max(tmp, std::max(r0[i], r1[i])); - tmp = std::max(tmp, r2[i]); + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + tmp = std::max(tmp, dr2[i]); } - dout_ch[w_even >> 1] = tmp; - // cnt ++; + *(dr_out++) = tmp; + wstart += S; } + r0 = r2; r1 = r0 + win; r2 = r1 + win; - dout_ch += wout; - } - - if (h_remain > 0) { -// deal with bottom pad -// first row with zero pad -// int hstart = (h >> 1) * stride_h - pad_h; -// int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); -// dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], -// r1[1])); -#ifdef __aarch64__ - w = 0; - cnt = 0; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); - float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); - float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); - float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); - float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); - float32x2_t vmax_12_34 = - vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); - float32x2_t vmax_23_45 = - vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); - float32x2_t 
vmax_56_78 = - vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); - float32x2_t vmax_67_89 = - vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); - float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); - float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&dout_ch[cnt], vmax_123_345); - vst1_f32(&dout_ch[cnt + 2], vmax_567_789); - cnt += 4; - } - for (; w < w_even; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - vmax2 = vpmax_f32(vmax2, vmax2); - dout_ch[cnt] = vget_lane_f32(vmax2, 0); - cnt++; - } -#else - dr_out = dout_ch; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - int cnt_num = w_unroll_size >> 3; - int cnt_num_remain = w_unroll_remain >> 1; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" - "vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n" - "vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n" - "vmax.f32 d16, d4, d10 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7,s6\n" - "vext.f32 q0, q6, q7, #1 @vext q0,2345\n" - "vext.f32 q1, q7, q8, #1 @vext q1,6789\n" - "vpmax.f32 d4, d12, d13 @pmax " - "d4,vmax_1234,vmax_1234\n" - "vpmax.f32 d6, d14, d15 @pmax " - "d6,vmax_5678,vmax_5678\n" - "vpmax.f32 d5, d0, d1 @pmax " - "d5,vmax_2345,vmax_2345\n" - "vpmax.f32 d7, d2, d3 @pmax " - "d7,vmax_6789,vmax_6789\n" - "vmax.f32 d8, d4, d5 @max " - "d2,vmax_12_34,vmax_23_45\n" - "vmax.f32 d9, d6, d7 @max " - "d2,vmax_56_78,vmax_67_89\n" - "sub %[dr0], #8 @add w,8\n" - "sub %[dr1], #8 @add w,8\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "subs %[cnt_num], #1 @subs cnt_num,#1\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vmov.f32 s3,s2 @movs3,s2\n" - "vmov.f32 s7,s6 @movs7,s6\n" - "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" - "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" - "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0],dr_out\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain) - : "r"(dr0), - "r"(dr1), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9"); - } -#endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp = r0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp = std::max(tmp, std::max(r0[i], r1[i])); - } - dout_ch[w_even >> 1] = tmp; - } + data_out_channel += wout; } } } @@ -2813,358 +2142,134 @@ void pooling3x3s2p0_avg(const float* din, int hin, int win, bool exclusive) { - int kernel = 3; - int stride = 2; - int padding = 0; + const int K = 3; + const int P = 0; + const int S = 2; + int size_channel_out = wout * hout; int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + if (w_unroll_remian == 0 && w_unroll_size * 4 * S + K > win) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } + + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + memset(zero_ptr, 0, win * sizeof(float)); - int w_needed = (wout << 1) + 1; - int h_needed = (hout << 1) + 1; - int w_limit = w_needed > win ? win : w_needed; - int h_limit = h_needed > hin ? hin : h_needed; - int w_even = ((w_limit - 1) >> 1) << 1; - int h_even = ((h_limit - 1) >> 1) << 1; - int w_unroll_size = (w_even >> 3) << 3; - int w_unroll_remain = w_even - w_unroll_size; - int w_remain = w_needed - w_limit; - int h_remain = h_needed - h_limit; - int w_in_2 = win << 1; - const float coef = 1.f / 9.f; - const float coef_6 = exclusive ? 
1.f / 6.f : coef; - float32x4_t vcoef = vdupq_n_f32(coef); - float32x4_t vcoef_6 = vdupq_n_f32(coef_6); for (int n = 0; n < num; ++n) { - float* dout_batch = dout + n * chout * size_channel_out; - const float* din_batch = din + n * chin * size_channel_in; + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* dout_ch = dout_batch + c * size_channel_out; - const float* din_ch = din_batch + c * size_channel_in; - const float* r0 = din_ch; + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; const float* r1 = r0 + win; const float* r2 = r1 + win; - // w = w_in - 8; - float* dr_out = dout_ch; - const float* dr0 = r0; - const float* dr1 = r1; - const float* dr2 = r2; - - float32x4_t vzero = vdupq_n_f32(0.f); - - int h = 0; - for (; h < h_even; h += 2) { -// LOG(INFO) << "h: " << h <<", dr0:" << r0 << ", dr1: " << r1 << -// ",dr2: " < hin) { + switch (h * S + K - P - hin) { + case 2: + dr1 = zero_ptr; + dr2 = zero_ptr; + coef_h = 1.f; + break; + case 1: + dr2 = zero_ptr; + coef_h = 0.5f; + break; + default: + break; + } } + float32x4_t vcoef = vdupq_n_f32(coef_h / 3); + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile(P3x3S2P0_INIT P3x3S2P0_AVG + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef) + : "cc", + "memory", + "v0", + "v1", + "v2", + "v3", + "v4", + "v5", + "v6", + "v7", + "v8", + "v9", + "v10", + "v11"); #else - dr_out = dout_ch; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - dr2 = r2; // (r2 + 1); - int cnt_num = w_unroll_size >> 3; - int cnt_num_remain = w_unroll_remain >> 1; - // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << - // cnt_num_remain; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" - "ble 3f @ble exit\n" - "s3_ave_loop_mid_p0: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, dr2\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr2\n" - "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" - "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" - "vadd.f32 d22, d4, d10 @max q1,q1,q3\n" - "vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n" - "vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n" - "vadd.f32 d16, d22, d16 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, 2345\n" - "vadd.f32 q5, q7, q1 @add 5678, 4567\n" - "vadd.f32 q4, q4, q2 @add 3456, sum1\n" - "vadd.f32 q5, q5, q3 @add 6789, sum2\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s21 @mov\n" - "vmov.f32 s19, s23 @mov\n" - "vmul.f32 q4, q4, %q[vcoef] @mul\n" - "sub %[dr0], #8 @add w,8\n" - "sub %[dr1], #8 @add w,8\n" - "sub %[dr2], #8 @add w,8\n" - "subs %[cnt_num], #1 @cnt_num--\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0,dr_out\n" - "bne s3_ave_loop_mid_p0 @bne s3_max_loop_mid\n" - "3: @loop\n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" - "ble 4f @ble exit1\n" - "s3_ave_loop_mid_1_p0: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" - "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" - "vadd.f32 q0, q0, q2 @add q0,q0,q1\n" - "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "sub %[dr2], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "bne s3_ave_loop_mid_1_p0 @bne s3_max_loop_mid_1\n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr2] "+r"(dr2), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain), - [vcoef] "+w"(vcoef), - [vzero] "+w"(vzero) - : "r"(dr0), - "r"(dr1), - "r"(dr2), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11", - "q12"); - } + asm volatile(P3x3S2P0_INIT P3x3S2P0_AVG + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef) + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); #endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp1 = 0.f; - float tmp2 = exclusive ? 
1.0f / (3.f * (wend - wstart)) : coef; + dr0 -= 8; + dr1 -= 8; + dr2 -= 8; + } + // deal with right pad + int rem = win - (w_unroll_size * 4) * S; + int wstart = 0; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, rem); + float coef = coef_h / (wend - wstart); + float tmp = 0.f; for (int i = wstart; i < wend; i++) { - tmp1 += (r0[i] + r1[i] + r2[i]); + tmp += dr0[i]; + tmp += dr1[i]; + tmp += dr2[i]; } - dout_ch[w_even >> 1] = tmp1 * tmp2; - // cnt ++; + tmp *= coef; + *(dr_out++) = tmp; + wstart += S; } + r0 = r2; r1 = r0 + win; r2 = r1 + win; - dout_ch += wout; - } - - if (h_remain > 0) { -// deal with bottom pad -// first row with zero pad -// int hstart = (h >> 1) * stride_h - pad_h; -// int hend = std::min(std::min(hstart + kernel_h, hin + padding_h), -// hin); data_out_channel[0] =(r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + -// r1[2]) / 9.f; -#ifdef __aarch64__ - int w = 0; - int cnt = 0; - for (; w < w_unroll_size; w += 8) { - float32x4_t vr0_1234 = vld1q_f32(&r0[w]); - float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); - float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); - float32x4_t vr1_1234 = vld1q_f32(&r1[w]); - float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); - float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); - - float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); - float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); - float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); - float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); - float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); - float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); - float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); - float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); - vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); - float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); - vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); - vsum_123_345 = - vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); - vst1q_f32(&dout_ch[cnt], vrst); - cnt += 4; - } - for (; w < w_even; w += 2) { - float32x4_t vr0 = vld1q_f32(&r0[w]); - float32x4_t vr1 = vld1q_f32(&r1[w]); - vr0 = vsetq_lane_f32(0.f, vr0, 3); - vr1 = vsetq_lane_f32(0.f, vr1, 3); - float32x4_t vsum1 = vaddq_f32(vr0, vr1); - float32x2_t vsum2 = - vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); - vsum2 = vpadd_f32(vsum2, vsum2); - float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); - dout_ch[cnt] = vget_lane_f32(vrst, 0); - cnt++; - } -#else - dr_out = dout_ch; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - int cnt_num = w_unroll_size >> 3; - int cnt_num_remain = w_unroll_remain >> 1; - // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << - // cnt_num_remain; - if (cnt_num > 0 || cnt_num_remain > 0) { - asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num,0\n" - "ble 2f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n" - "vld1.f32 {d10}, [%[dr1]]! 
@load d4-d7,dr1\n" - "vadd.f32 q6, q0, q3 @max q0,q0,q2 1234\n" - "vadd.f32 q7, q1, q4 @max q1,q1,q3 5678\n" - "vadd.f32 d16, d4, d10 @max q1,q1,q3 9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234,2345\n" - "vadd.f32 q5, q7, q1 @add 5678,4567\n" - "vadd.f32 q4, q4, q2 @add 3456,sum1\n" - "vadd.f32 q5, q5, q3 @add 6789,sum2\n" - "vmov.f32 s17, s18 @mov\n" - "vmov.f32 s18, s21 @mov\n" - "vmov.f32 s19, s23 @mov\n" - "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" - "sub %[dr0], #8 @add w,8\n" - "sub %[dr1], #8 @add w,8\n" - "subs %[cnt_num], #1 @cnt_num--\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" - "bne 1b @bne s3_max_loop_bot\n" - "2: @loop\n" - "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain, 0\n" - "ble 3f @ble exit\n" - "4: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" - "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" - "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" - "sub %[dr0], #8 @add w,6\n" - "sub %[dr1], #8 @add w,6\n" - "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" - "bne 4b @bne s3_max_loop_bot_1\n" - "3: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), - [cnt_num_remain] "+r"(cnt_num_remain), - [vcoef_6] "+w"(vcoef_6), - [vzero] "+w"(vzero) - : "r"(dr0), - "r"(dr1), - "r"(dr_out), - "r"(cnt_num), - "r"(cnt_num_remain) - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); - } - -#endif - if (w_remain > 0) { - // deal with right pad - int wstart = (w_even >> 1) * stride - padding; - int wend = std::min(std::min(wstart + kernel, win + padding), win); - float tmp1 = 0.f; - float tmp2 = exclusive ? 
1.0f / (2.f * (wend - wstart)) : coef; - for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp1 += (r0[i] + r1[i]); - } - dout_ch[w_even >> 1] = tmp1 * tmp2; - } + data_out_channel += wout; } } } + TargetFree(TARGET(kARM), zero_ptr); } } // namespace math diff --git a/lite/backends/arm/math/pooling.h b/lite/backends/arm/math/pooling.h index 8fc9e0c4e0..9288f27bbc 100644 --- a/lite/backends/arm/math/pooling.h +++ b/lite/backends/arm/math/pooling.h @@ -116,6 +116,27 @@ void pooling3x3s2p1_max(const float* din, int hin, int win); +void pooling3x3s1p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win); + +void pooling3x3s1p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive); + void pooling3x3s2p1_avg(const float* din, float* dout, int num, diff --git a/lite/kernels/arm/pool_compute.cc b/lite/kernels/arm/pool_compute.cc index d95d658cf9..500e81118e 100644 --- a/lite/kernels/arm/pool_compute.cc +++ b/lite/kernels/arm/pool_compute.cc @@ -137,6 +137,34 @@ void PoolCompute::Run() { VLOG(3) << "invoking pooling3x3s1p1_avg"; return; } + } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && + kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s1p0_max(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3]); + VLOG(3) << "pooling3x3s1p0_max"; + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s1p0_avg(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + exclusive); + VLOG(3) << "invoking pooling3x3s1p0_avg"; + return; + } } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { if (pooling_type == "max") { diff --git a/lite/tests/math/CMakeLists.txt b/lite/tests/math/CMakeLists.txt index 342901f075..87324375e0 100644 --- a/lite/tests/math/CMakeLists.txt +++ b/lite/tests/math/CMakeLists.txt @@ -5,4 +5,5 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(conv_transpose_compute_test SRCS conv_transpose_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(pool_compute_test SRCS pool_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) endif() diff --git a/lite/tests/math/pool_compute_test.cc b/lite/tests/math/pool_compute_test.cc new file mode 100644 index 0000000000..18267f14b4 --- /dev/null +++ b/lite/tests/math/pool_compute_test.cc @@ -0,0 +1,454 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
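+
+// Standalone test for the ARM pooling kernels: drives
+// paddle::lite::kernels::arm::PoolCompute over flag-controlled and
+// randomized shapes, checks each result against the naive pooling_basic
+// reference below, and reports average/minimum latency and GOPs.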
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#include "lite/tests/utils/naive_math_impl.h"
+#include "lite/tests/utils/tensor_utils.h"
+#include "lite/tests/utils/timer.h"
+
+#ifdef LITE_WITH_ARM
+#include "lite/kernels/arm/pool_compute.h"
+#endif  // LITE_WITH_ARM
+
+DEFINE_int32(power_mode,
+             3,
+             "power mode: "
+             "0 for POWER_HIGH;"
+             "1 for POWER_LOW;"
+             "2 for POWER_FULL;"
+             "3 for NO_BIND");
+DEFINE_int32(threads, 1, "threads num");
+DEFINE_int32(warmup, 0, "warmup times");
+DEFINE_int32(repeats, 1, "repeats times");
+DEFINE_bool(basic_test, false, "do all tests");
+DEFINE_bool(check_result, true, "check the result");
+
+DEFINE_int32(batch, 1, "batch size");
+DEFINE_int32(in_channel, 32, "input channel");
+DEFINE_int32(in_height, 112, "input height");
+DEFINE_int32(in_width, 112, "input width");
+
+DEFINE_int32(kernel_h, 3, "kernel height");
+DEFINE_int32(kernel_w, 3, "kernel width");
+DEFINE_int32(pad_h, 1, "pad height");
+DEFINE_int32(pad_w, 1, "pad width");
+DEFINE_int32(stride_h, 1, "stride height");
+DEFINE_int32(stride_w, 1, "stride width");
+
+DEFINE_bool(ceil_mode, true, "do ceil_mode");
+DEFINE_bool(flag_global, true, "global pooling");
+DEFINE_bool(exclusive, true, "do exclusive");
+DEFINE_bool(adaptive, false, "do adaptive");
+DEFINE_bool(use_quantizer, false, "do use_quantizer");
+
+DEFINE_string(pooling_type, "max", "do max pooling");
+
+typedef paddle::lite::DDim DDim;
+typedef paddle::lite::Tensor Tensor;
+typedef paddle::lite::operators::PoolParam PoolParam;
+using paddle::lite::Timer;
+
+DDim compute_out_dim(const DDim& dim_in,
+                     const paddle::lite::operators::PoolParam& param) {
+  DDim dim_out = dim_in;
+  auto kernel_h = param.ksize[0];
+  auto kernel_w = param.ksize[1];
+  auto h = dim_in[2];
+  auto w = dim_in[3];
+  int pad_h = param.paddings[0];
+  int pad_w = param.paddings[1];
+  int stride_h = param.strides[0];
+  int stride_w = param.strides[1];
+  bool ceil_mode = param.ceil_mode;
+  bool flag_global = param.global_pooling;
+  int hout = 1;
+  int wout = 1;
+  if (!flag_global) {
+    if (!ceil_mode) {
+      hout = (h - kernel_h + 2 * pad_h) / stride_h + 1;
+      wout = (w - kernel_w + 2 * pad_w) / stride_w + 1;
+    } else {
+      hout = (h - kernel_h + 2 * pad_h + stride_h - 1) / stride_h + 1;
+      wout = (w - kernel_w + 2 * pad_w + stride_w - 1) / stride_w + 1;
+    }
+  }
+  dim_out[2] = hout;
+  dim_out[3] = wout;
+  return dim_out;
+}
+
+void pooling_basic(const float* din,
+                   float* dout,
+                   int num,
+                   int chout,
+                   int hout,
+                   int wout,
+                   int chin,
+                   int hin,
+                   int win,
+                   const std::vector<int>& ksize,
+                   const std::vector<int>& strides,
+                   const std::vector<int>& paddings,
+                   bool global_pooling,
+                   bool exclusive,
+                   bool adaptive,
+                   bool ceil_mode,
+                   bool use_quantizer,
+                   const std::string& pooling_type) {
+  // no need to pad input tensor, border is zero pad inside this function
+  memset(dout, 0, num * chout * hout * wout * sizeof(float));
+  int kernel_h = ksize[0];
+  int kernel_w = ksize[1];
+  int stride_h = strides[0];
+  int stride_w = strides[1];
+  int pad_h = paddings[0];
+  int pad_w = paddings[1];
+  int size_channel_in = win * hin;
+  int size_channel_out = wout * hout;
+  if (global_pooling) {
+    if (pooling_type == "max") {  // Pooling_max
+      for (int n = 0; n < num; ++n) {
+        float* dout_batch = dout + n * chout * size_channel_out;
+        const float* din_batch = din + n * chin * size_channel_in;
+#pragma omp parallel for
+        for (int c = 0; c < chout; ++c) {
+          const float* din_ch = din_batch + c * size_channel_in;  // in address
+          float tmp1 = din_ch[0];
+          for (int i = 0; i < size_channel_in; ++i) {
+            float tmp2 = din_ch[i];
+            tmp1 = tmp1 > tmp2 ? tmp1 : tmp2;
+          }
+          dout_batch[c] = tmp1;
+        }
+      }
+    } else if (pooling_type == "avg") {
+      // Pooling_average_include_padding
+      // Pooling_average_exclude_padding
+      for (int n = 0; n < num; ++n) {
+        float* dout_batch = dout + n * chout * size_channel_out;
+        const float* din_batch = din + n * chin * size_channel_in;
+#pragma omp parallel for
+        for (int c = 0; c < chout; ++c) {
+          const float* din_ch = din_batch + c * size_channel_in;  // in address
+          float sum = 0.f;
+          for (int i = 0; i < size_channel_in; ++i) {
+            sum += din_ch[i];
+          }
+          dout_batch[c] = sum / size_channel_in;
+        }
+      }
+    } else {
+      LOG(FATAL) << "unsupported pooling type: " << pooling_type;
+    }
+  } else {
+    for (int ind_n = 0; ind_n < num; ++ind_n) {
+      for (int ind_c = 0; ind_c < chin; ++ind_c) {
+        for (int ind_h = 0; ind_h < hout; ++ind_h) {
+          int sh = ind_h * stride_h;
+          int eh = sh + kernel_h;
+          sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
+          eh = (eh - pad_h) > hin ? hin : eh - pad_h;
+          for (int ind_w = 0; ind_w < wout; ++ind_w) {
+            int sw = ind_w * stride_w;
+            int ew = sw + kernel_w;
+            sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
+            ew = (ew - pad_w) > win ? win : ew - pad_w;
+            float result = static_cast<float>(0);
+            int dst_ind = (ind_n * chout + ind_c) * size_channel_out +
+                          ind_h * wout + ind_w;
+            for (int kh = sh; kh < eh; ++kh) {
+              for (int kw = sw; kw < ew; ++kw) {
+                int src_ind =
+                    (ind_n * chin + ind_c) * size_channel_in + kh * win + kw;
+                if (kh == sh && kw == sw) {
+                  result = din[src_ind];
+                } else {
+                  if (pooling_type == "max") {
+                    result = result >= din[src_ind] ? result : din[src_ind];
+                  } else if (pooling_type == "avg") {
+                    result += din[src_ind];
+                  }
+                }
+              }
+            }
+            if (pooling_type == "avg") {
+              if (exclusive) {
+                int div = (ew - sw) * (eh - sh);
+                div = div > 0 ? div : 1;
+                result /= div;
+              } else {
+                int bh = kernel_h;
+                int bw = kernel_w;
+                if (ew == win) {
+                  bw = sw + kernel_w >= win + pad_w ? win + pad_w
+                                                    : sw + kernel_w;
+                  bw -= sw;
+                  if (sw - pad_w < 0 && sw + kernel_w > win + pad_w) {
+                    bw += pad_w;
+                  }
+                }
+                if (eh == hin) {
+                  bh = sh + kernel_h >= hin + pad_h
+                           ? hin + pad_h
+                           : sh + kernel_h;
+                  bh -= sh;
+                  if (sh - pad_h < 0 && sh + kernel_h > hin + pad_h) {
+                    bh += pad_h;
+                  }
+                }
+                result /= bh * bw;
+              }
+            }
+            dout[dst_ind] = result;
+          }
+        }
+      }
+    }
+  }
+}
+#ifdef LITE_WITH_ARM
+void test_pool_fp32(const std::vector<DDim>& input_dims,
+                    const std::vector<int>& ksize,
+                    const std::vector<int>& strides,
+                    const std::vector<int>& pads,
+                    bool ceil_mode,
+                    bool flag_global,
+                    bool exclusive,
+                    bool adaptive,
+                    bool use_quantizer,
+                    std::string pooling_type,
+                    const std::vector<int>& thread_num,
+                    const std::vector<int>& power_mode) {
+#ifdef LITE_WITH_ARM
+  paddle::lite::DeviceInfo::Init();
+#endif
+  PoolParam param;
+  param.x = new Tensor;
+  param.x->set_precision(PRECISION(kFloat));
+  param.ksize = ksize;
+
+  param.strides = strides;
+  param.paddings = pads;
+  param.ceil_mode = ceil_mode;
+  param.global_pooling = flag_global;
+  param.pooling_type = pooling_type;
+  param.exclusive = exclusive;
+  param.adaptive = adaptive;
+  param.use_quantizer = use_quantizer;
+
+  param.output = new Tensor;
+  param.output->set_precision(PRECISION(kFloat));
+
+  for (auto& cls : power_mode) {
+    for (auto& th : thread_num) {
+      paddle::lite::kernels::arm::PoolCompute pool;
+      std::unique_ptr<paddle::lite::KernelContext> ctx1(
+          new paddle::lite::KernelContext);
+      auto& ctx = ctx1->As<paddle::lite::ARMContext>();
+      ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
+      /// set param and context
+      pool.SetParam(param);
+      pool.SetContext(std::move(ctx1));
+      /// prepare for run
+      pool.PrepareForRun();
+
+      for (auto& dim_in : input_dims) {
+        DDim dim_out = compute_out_dim(dim_in, param);
+        if (dim_out[2] < 1 || dim_out[3] < 1) {
+          continue;
+        }
+        param.x->Resize(dim_in);
+        param.output->Resize(dim_out);
+
+        paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
+        // paddle::lite::fill_tensor_const(*param.x, 1.f);
+        auto din = param.x->data<float>();
+
+        Tensor tout_basic;
+        if (FLAGS_check_result) {
+          LOG(INFO) << "basic compute";
+          tout_basic.set_precision(PRECISION(kFloat));
+          tout_basic.Resize(dim_out);
+          fill_tensor_const(tout_basic, 0.f);
+          auto dout_basic = tout_basic.mutable_data<float>();
+          pooling_basic(din,
+                        dout_basic,
+                        dim_in[0],
+                        dim_out[1],
+                        dim_out[2],
+                        dim_out[3],
+                        dim_in[1],
+                        dim_in[2],
+                        dim_in[3],
+                        ksize,
+                        strides,
+                        pads,
+                        flag_global,
+                        exclusive,
+                        adaptive,
+                        ceil_mode,
+                        use_quantizer,
+                        pooling_type);
+        }
+        LOG(INFO) << "lite compute";
+        /// warm up
+        for (int i = 0; i < FLAGS_warmup; ++i) {
+          pool.Launch();
+        }
+        /// compute
+        Timer t0;
+        for (int i = 0; i < FLAGS_repeats; ++i) {
+          t0.start();
+          pool.Launch();
+          t0.end();
+        }
+
+        double gops = 2.0 * dim_out.production() * ksize[0] * ksize[1];
+        LOG(INFO) << "pool fp32: input shape: " << dim_in
+                  << ", output shape: " << dim_out
+                  << ", running time, avg: " << t0.get_average_ms()
+                  << ", min time: " << t0.get_min_time()
+                  << ", total GOPS: " << 1e-9 * gops
+                  << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
+                  << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
+
+        if (FLAGS_check_result) {
+          double max_ratio = 0;
+          double max_diff = 0;
+          tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff);
+          LOG(INFO) << "compare result, max diff: " << max_diff
+                    << ", max ratio: " << max_ratio;
+          if (std::abs(max_ratio) > 1e-3f) {
+            if (max_diff > 5e-4f) {
+              LOG(WARNING) << "din";
+              print_tensor(*param.x);
+              LOG(WARNING) << "basic result";
+              print_tensor(tout_basic);
+              LOG(WARNING) << "lite result";
+              print_tensor(*param.output);
+              Tensor tdiff;
+              tdiff.Resize(tout_basic.dims());
+              tdiff.set_precision(PRECISION(kFloat));
+              tensor_diff(tout_basic, *param.output, tdiff);
+              print_tensor(tdiff);
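+              // all diagnostics have been dumped above; abort with the full
+              // parameter combination so the failing case can be reproduced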
+              LOG(FATAL) << "test fp32 pool: input: " << dim_in
+                         << ", output: " << dim_out
+                         << ", kernel dim: " << ksize[0] << ", " << ksize[1]
+                         << ", pad: " << pads[0] << ", " << pads[1]
+                         << ", stride: " << strides[0] << ", " << strides[1]
+                         << ", global_pooling: "
+                         << (flag_global ? "global" : "false")
+                         << ", pooling_type: " << pooling_type
+                         << ", ceil_mode: " << (ceil_mode ? "true" : "false")
+                         << ", exclusive: " << (exclusive ? "true" : "false")
+                         << ", threads: " << th << ", power_mode: " << cls
+                         << " failed!!\n";
+            }
+          }
+        }
+        LOG(INFO) << "test fp32 pool: input: " << dim_in
+                  << ", output: " << dim_out << ", kernel dim: " << ksize[0]
+                  << ", " << ksize[1] << ", pad: " << pads[0] << ", " << pads[1]
+                  << ", stride: " << strides[0] << ", " << strides[1]
+                  << ", global_pooling: " << (flag_global ? "global" : "false")
+                  << ", pooling_type: " << pooling_type
+                  << ", ceil_mode: " << (ceil_mode ? "true" : "false")
+                  << ", exclusive: " << (exclusive ? "true" : "false")
+                  << ", threads: " << th << ", power_mode: " << cls
+                  << " succeeded!!\n";
+      }
+    }
+  }
+
+  delete param.x;
+  delete param.output;
+}
+#else
+void test_pool_fp32(const std::vector<DDim>& input_dims,
+                    const std::vector<int>& ksize,
+                    const std::vector<int>& strides,
+                    const std::vector<int>& pads,
+                    bool ceil_mode,
+                    bool flag_global,
+                    bool exclusive,
+                    bool adaptive,
+                    bool use_quantizer,
+                    std::string pooling_type,
+                    const std::vector<int>& thread_num,
+                    const std::vector<int>& power_mode) {}
+#endif  // LITE_WITH_ARM
+
+#if 1  /// random param pool
+TEST(TestPoolRand, test_pool_rand) {
+  if (FLAGS_basic_test) {
+    for (auto& cin : {1, 3, 8, 16}) {
+      for (auto& kw : {1, 2, 3}) {
+        for (auto& kh : {1, 2, 3}) {
+          for (auto& stride : {1, 2}) {
+            for (auto& pad : {0, 1, 2}) {
+              for (auto& flag_global : {false, true}) {
+                for (auto& exclusive : {false, true}) {
+                  for (auto& ceil_mode : {false, true}) {
+                    for (auto& pooling_type : {"max", "avg"}) {
+                      bool adaptive = false;
+                      bool use_quantizer = false;
+                      std::vector<DDim> dims;
+                      for (auto& batch : {1, 2}) {
+                        for (auto& h : {1, 2, 3, 4, 11, 19, 32, 28}) {
+                          dims.push_back(DDim({batch, cin, h, h}));
+                        }
+                      }
+                      test_pool_fp32(dims,
+                                     {kh, kw},
+                                     {stride, stride},
+                                     {pad, pad},
+                                     ceil_mode,
+                                     flag_global,
+                                     exclusive,
+                                     adaptive,
+                                     use_quantizer,
+                                     pooling_type,
+                                     {1, 2, 4},
+                                     {FLAGS_power_mode});
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+#endif  /// random param pool
+
+#if 1  /// custom
+TEST(TestPoolCustom, test_pool_fp32_custom_size) {
+  test_pool_fp32(
+      {DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})},
+      {FLAGS_kernel_h, FLAGS_kernel_w},
+      {FLAGS_stride_h, FLAGS_stride_w},
+      {FLAGS_pad_h, FLAGS_pad_w},
+      FLAGS_ceil_mode,
+      FLAGS_flag_global,
+      FLAGS_exclusive,
+      FLAGS_adaptive,
+      FLAGS_use_quantizer,
+      FLAGS_pooling_type,
+      {FLAGS_threads},
+      {FLAGS_power_mode});
+}
+#endif  // custom
-- 
GitLab
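For reference, a minimal sketch (not part of the patch) of how the harness above can be pointed directly at the rewritten 3x3, stride-2, pad-1 path; the input shape and flag values are illustrative only:

    TEST(TestPool3x3S2P1, test_pool_3x3s2p1_sketch) {
      // covers pooling3x3s2p1_max / pooling3x3s2p1_avg via kernel dispatch
      for (auto& pooling_type : {"max", "avg"}) {
        test_pool_fp32({DDim({1, 8, 17, 17})},  // batch, channel, h, w
                       {3, 3},                  // ksize
                       {2, 2},                  // strides
                       {1, 1},                  // paddings
                       false,                   // ceil_mode
                       false,                   // global pooling
                       true,                    // exclusive
                       false,                   // adaptive
                       false,                   // use_quantizer
                       pooling_type,
                       {FLAGS_threads},
                       {FLAGS_power_mode});
      }
    }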