From eb1e983c4de538af5cb001309e3d8c25cbb08af5 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 10 Dec 2018 03:06:50 +0800 Subject: [PATCH] Optimize pooling3x3 --- src/operators/math/pooling.cpp | 6 - src/operators/math/pooling.h | 53 +++---- src/operators/math/pooling3x3.cpp | 254 +++++++++++++++++++++++++----- 3 files changed, 240 insertions(+), 73 deletions(-) diff --git a/src/operators/math/pooling.cpp b/src/operators/math/pooling.cpp index b4aba52b9b..1270e6a898 100644 --- a/src/operators/math/pooling.cpp +++ b/src/operators/math/pooling.cpp @@ -60,18 +60,12 @@ void Pooling

::operator()(const framework::Tensor &input, wstart = std::max(wstart, 0); PoolingVal

val; - // std::cout << "output[" << ph * output_width + pw << "]:" - // << std::endl; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { val += input_ptr[h * input_width + w]; - // std::cout << "input[" << h << "][" << w << "] = " - // << input_ptr[h * input_width + w] << std::endl; } } output_ptr[ph * output_width + pw] = val.Value(); - // std::cout << "output[" << ph * output_width + pw << "] = " - // << val.Value() << std::endl; } } } diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h index 9fcfdf811f..8b60759003 100644 --- a/src/operators/math/pooling.h +++ b/src/operators/math/pooling.h @@ -34,45 +34,39 @@ template struct PoolingVal { float val; int count; - PoolingVal() { - val = -std::numeric_limits::max(); - count = 0; - } + PoolingVal() : count(0) { val = -std::numeric_limits::max(); } inline PoolingVal

&operator+=(const float &x) { val = std::max(val, x); - count += 1; + ++count; return *this; } - float Value() const { - if (count > 0) { - return val; - } - return 0.f; - } + inline float Value() { return (count > 0) ? val : 0.f; } }; template <> struct PoolingVal { float val; int count; - PoolingVal() { - val = 0.f; - count = 0; - } + PoolingVal() : val(0.f), count(0) {} inline PoolingVal &operator+=(const float &x) { val += x; - count += 1; + ++count; return *this; } - float Value() const { - if (count > 0) { - return val / count; - } - return 0.f; - } + inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; } }; #if defined(__ARM_NEON) || defined(__ARM_NEON__) +template +inline float32x4_t vPoolInitq_f32() { + return vdupq_n_f32(-std::numeric_limits::max()); +} + +template <> +inline float32x4_t vPoolInitq_f32() { + return vdupq_n_f32(0.f); +} + template inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { return vmaxq_f32(x1, x2); @@ -85,14 +79,15 @@ inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, } template -inline float32x4_t vPoolPostq_f32(const float32x4_t &x) { +inline float32x4_t vPoolPostq_f32(const float32x4_t &x, + const float32x4_t &post) { return x; } template <> -inline float32x4_t vPoolPostq_f32(const float32x4_t &x) { - float32x4_t avg = vdupq_n_f32(1.f / 9); - return vmulq_f32(avg, x); +inline float32x4_t vPoolPostq_f32(const float32x4_t &x, + const float32x4_t &post) { + return vmulq_f32(x, post); } #endif // __ARM_NEON__ @@ -107,13 +102,13 @@ inline float PoolPre(const float &x1, const float &x2) { } template -inline float PoolPost(const float &x) { +inline float PoolPost(const float &x, const float &post) { return x; } template <> -inline float PoolPost(const float &x) { - return 1.f / 9 * x; +inline float PoolPost(const float &x, const float &post) { + return x * post; } template diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp index 8f1defa6cb..34eb2ec3dd 100644 --- a/src/operators/math/pooling3x3.cpp +++ b/src/operators/math/pooling3x3.cpp @@ -38,6 +38,126 @@ namespace math { output_ptr[w] = val.Value(); \ } +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +template +struct Pooling3x3ValidColLoadInput { + inline void operator()(const float *input, const int input_w, + const int valid_cols, float32x4x2_t &x0, // NOLINT + float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT + float32x4x2_t &y0) { // NOLINT + float fake_input[3][8]; + if (valid_cols == 1) { + for (int i = 0; i < 8; ++i, input += input_w) { + fake_input[0][i] = input[0]; + } + } else if (valid_cols == 2) { + for (int i = 0; i < 8; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + } + } else { + for (int i = 0; i < 8; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + fake_input[2][i] = input[2]; + } + } + y0.val[0] = vPoolInitq_f32

(); + y0.val[1] = vPoolInitq_f32

(); + for (int i = 0; i < valid_cols; ++i) { + x0.val[0] = vld1q_f32(fake_input[i]); + x0.val[1] = vld1q_f32(fake_input[i] + 4); + x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); + y0.val[0] = vPoolPreq_f32

(x1.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x1.val[1], y0.val[1]); + y0.val[0] = vPoolPreq_f32

(x2.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x2.val[1], y0.val[1]); + } + } +}; + +template +struct Pooling3x3ValidColLoadInput { + inline void operator()(const float *input, const int input_w, + const int valid_cols, float32x4x2_t &x0, // NOLINT + float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT + float32x4x2_t &y0) { // NOLINT + float fake_input[3][13]; + if (valid_cols == 1) { + for (int i = 0; i < 13; ++i, input += input_w) { + fake_input[0][i] = input[0]; + } + } else if (valid_cols == 2) { + for (int i = 0; i < 13; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + } + } else { + for (int i = 0; i < 13; ++i, input += input_w) { + fake_input[0][i] = input[0]; + fake_input[1][i] = input[1]; + fake_input[2][i] = input[2]; + } + } + for (int i = 0; i < valid_cols; ++i) { + x0 = vld2q_f32(fake_input[i]); + x1 = vld2q_f32(fake_input[i] + 8); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); + } + } +}; + +template +struct Pooling3x3NormalRowLoadInput { + inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT + float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT + float32x4x2_t &y0) { // NOLINT + x0.val[0] = vld1q_f32(input); + x0.val[1] = vld1q_f32(input + 4); + x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); + y0.val[0] = vPoolPreq_f32

(x1.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x1.val[1], y0.val[1]); + y0.val[0] = vPoolPreq_f32

(x2.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x2.val[1], y0.val[1]); + } +}; + +template +struct Pooling3x3NormalRowLoadInput { + inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT + float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT + float32x4x2_t &y0) { // NOLINT + x0 = vld2q_f32(input); + x1 = vld2q_f32(input + 8); + x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); + x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); + } +}; +#endif // __ARM_NEON__ + template inline void Pooling3x3ValidCol(const float *input, const int h_output, const int h_output_end, const int w_output, @@ -48,7 +168,38 @@ inline void Pooling3x3ValidCol(const float *input, const int h_output, const int w_in_end = w_in_start + 3; const int w_start = w_in_start > 0 ? w_in_start : 0; const int w_end = w_in_end < input_w ? w_in_end : input_w; - for (int h = h_output; h < h_output_end; ++h) { + int remain_start = h_output; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int output_tiles = (h_output_end - h_output) / 6; + remain_start = h_output + output_tiles * 6; + int input_h_start = h_output * Stride - padding_h; + size_t input_offset = input_h_start * input_w + w_start; + size_t output_offset = h_output * output_w + w_output; + int valid_cols = w_end - w_start; + Pooling3x3ValidColLoadInput PoolingCompute; + float32x4x2_t x0, x1, x2, y0; + float32x4_t avg = vdupq_n_f32(1.f / (3 * valid_cols)); + for (int h = 0; h < output_tiles * 6; h += 6) { + float *output0 = output + output_offset; + float *output1 = output0 + output_w; + float *output2 = output1 + output_w; + float *output3 = output2 + output_w; + float *output4 = output3 + output_w; + float *output5 = output4 + output_w; + y0.val[0] = vPoolInitq_f32

(); + y0.val[1] = vPoolInitq_f32

(); + PoolingCompute(input + input_offset, input_w, valid_cols, x0, x1, x2, y0); + y0.val[0] = vPoolPostq_f32

(y0.val[0], avg); + y0.val[1] = vPoolPostq_f32

(y0.val[1], avg); + vst1q_lane_f32(output0, y0.val[0], 0); + vst1q_lane_f32(output1, y0.val[0], 1); + vst1q_lane_f32(output2, y0.val[0], 2); + vst1q_lane_f32(output3, y0.val[0], 3); + vst1q_lane_f32(output4, y0.val[1], 0); + vst1q_lane_f32(output5, y0.val[1], 1); + } +#endif + for (int h = remain_start; h < h_output_end; ++h) { PoolingVal

val; const int h_in_start = -padding_h + h * Stride; for (int i = 0; i < 3; ++i) { @@ -77,7 +228,28 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output, // border left POOLING3X3_NORMAL_BORDER(0, valid_w_start) // middle - for (int w = valid_w_start; w < valid_w_end; ++w) { + int remain_start = valid_w_start; +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + int output_tiles = (valid_w_end - valid_w_start) / 6; + remain_start = valid_w_start + output_tiles * 6; + Pooling3x3NormalRowLoadInput PoolingCompute; + float32x4x2_t x0, x1, x2, y0; + float32x4_t post = vdupq_n_f32(1.f / (3 * (h_end - h_start))); + for (int w = 0; w < output_tiles * 6; w += 6) { + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride - padding_w; + y0.val[0] = vPoolInitq_f32

(); + y0.val[1] = vPoolInitq_f32

(); + for (int h_in = h_start; h_in < h_end; ++h_in) { + PoolingCompute(input + h_in * input_w + input_w_offset, x0, x1, x2, y0); + } + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); + vst1q_f32(output_ptr + output_offset, y0.val[0]); + vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0.val[1])); + } +#endif // __ARM_NEON__ + for (int w = remain_start; w < valid_w_end; ++w) { PoolingVal

val; int input_start = -padding_w + w * Stride; for (int h_in = h_start; h_in < h_end; ++h_in) { @@ -112,6 +284,7 @@ struct Pooling3x3 { int valid_w_start = padding_w; int valid_w_end = output_w - valid_w_start; int valid_w = valid_w_end - valid_w_start; + float avg = 1.f / 9; #pragma omp parallel for for (int c = 0; c < output->dims()[1]; ++c) { @@ -157,6 +330,7 @@ struct Pooling3x3 { #if defined(__ARM_NEON__) || defined(__ARM_NEON) float32x4x2_t x0, x1, x2; float32x4x2_t y0, y1, y2; + float32x4_t post = vdupq_n_f32(1.f / 9); for (int loop = 0; loop < output_w_tiles; ++loop) { x0.val[0] = vld1q_f32(input_ptr0); x0.val[1] = vld1q_f32(input_ptr0 + 4); @@ -196,8 +370,8 @@ struct Pooling3x3 { y1.val[1] = vPoolPreq_f32

(y2.val[1], y1.val[1]); y0.val[0] = vPoolPreq_f32

(y2.val[0], y0.val[0]); y0.val[1] = vPoolPreq_f32

(y2.val[1], y0.val[1]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); - y0.val[1] = vPoolPostq_f32

(y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); vst1q_f32(output_ptr0, y0.val[0]); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); @@ -215,8 +389,8 @@ struct Pooling3x3 { y1.val[1] = vPoolPreq_f32

(y0.val[1], y1.val[1]); y2.val[0] = vPoolPreq_f32

(y0.val[0], y2.val[0]); y2.val[1] = vPoolPreq_f32

(y0.val[1], y2.val[1]); - y1.val[0] = vPoolPostq_f32

(y1.val[0]); - y1.val[1] = vPoolPostq_f32

(y1.val[1]); + y1.val[0] = vPoolPostq_f32

(y1.val[0], post); + y1.val[1] = vPoolPostq_f32

(y1.val[1], post); vst1q_f32(output_ptr1, y1.val[0]); vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1])); @@ -234,8 +408,8 @@ struct Pooling3x3 { y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); y2.val[0] = vPoolPreq_f32

(x0.val[0], y2.val[0]); y2.val[1] = vPoolPreq_f32

(x0.val[1], y2.val[1]); - y2.val[0] = vPoolPostq_f32

(y2.val[0]); - y2.val[1] = vPoolPostq_f32

(y2.val[1]); + y2.val[0] = vPoolPostq_f32

(y2.val[0], post); + y2.val[1] = vPoolPostq_f32

(y2.val[1], post); vst1q_f32(output_ptr2, y2.val[0]); vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1])); @@ -251,8 +425,8 @@ struct Pooling3x3 { x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); - y0.val[1] = vPoolPostq_f32

(y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); vst1q_f32(output_ptr3, y0.val[0]); vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1])); @@ -292,7 +466,7 @@ struct Pooling3x3 { y2.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32

(y2.val[0], y1.val[0]); y0.val[0] = vPoolPreq_f32

(y2.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); vst1q_f32(output_ptr0, y0.val[0]); x0.val[0] = vld1q_f32(input_ptr3); @@ -303,7 +477,7 @@ struct Pooling3x3 { y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32

(y0.val[0], y1.val[0]); y2.val[0] = vPoolPreq_f32

(y0.val[0], y2.val[0]); - y1.val[0] = vPoolPostq_f32

(y1.val[0]); + y1.val[0] = vPoolPostq_f32

(y1.val[0], post); vst1q_f32(output_ptr1, y1.val[0]); x0.val[0] = vld1q_f32(input_ptr4); @@ -314,7 +488,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); y2.val[0] = vPoolPreq_f32

(x0.val[0], y2.val[0]); - y2.val[0] = vPoolPostq_f32

(y2.val[0]); + y2.val[0] = vPoolPostq_f32

(y2.val[0], post); vst1q_f32(output_ptr2, y2.val[0]); x0.val[0] = vld1q_f32(input_ptr5); @@ -324,7 +498,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); vst1q_f32(output_ptr3, y0.val[0]); input_ptr0 += 4; @@ -358,10 +532,10 @@ struct Pooling3x3 { m1 = PoolPre

(PoolPre

(m1, m2), m3); m2 = PoolPre

(PoolPre

(m2, m3), m4); m3 = PoolPre

(PoolPre

(m3, m4), m5); - output_ptr0[r] = PoolPost

(m0); - output_ptr1[r] = PoolPost

(m1); - output_ptr2[r] = PoolPost

(m2); - output_ptr3[r] = PoolPost

(m3); + output_ptr0[r] = PoolPost

(m0, avg); + output_ptr1[r] = PoolPost

(m1, avg); + output_ptr2[r] = PoolPost

(m2, avg); + output_ptr3[r] = PoolPost

(m3, avg); } } // remain h @@ -374,6 +548,7 @@ struct Pooling3x3 { int remain = output_w_remain; #if defined(__ARM_NEON__) || defined(__ARM_NEON) float32x4x2_t x0, x1, x2, y0; + float32x4_t post = vdupq_n_f32(1.f / 9); for (int loop = 0; loop < output_w_tiles; ++loop) { x0.val[0] = vld1q_f32(input_ptr0); x0.val[1] = vld1q_f32(input_ptr0 + 4); @@ -411,8 +586,8 @@ struct Pooling3x3 { x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); - y0.val[1] = vPoolPostq_f32

(y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); vst1q_f32(output_ptr0, y0.val[0]); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); @@ -445,7 +620,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); vst1q_f32(output_ptr0, y0.val[0]); input_ptr0 += 4; @@ -464,7 +639,7 @@ struct Pooling3x3 { m2 = PoolPre

(m2, input_ptr2[r + 2]); m0 = PoolPre

(PoolPre

(m0, m1), m2); - output_ptr0[r] = PoolPost

(m0); + output_ptr0[r] = PoolPost

(m0, avg); } } } @@ -492,6 +667,7 @@ struct Pooling3x3 { int valid_w_start = (padding_w + 1) / 2; int valid_w_end = output_w - valid_w_start; int valid_w = valid_w_end - valid_w_start; + float avg = 1.f / 9; #pragma omp parallel for for (int c = 0; c < output->dims()[1]; ++c) { @@ -539,6 +715,7 @@ struct Pooling3x3 { #if defined(__ARM_NEON__) || defined(__ARM_NEON) float32x4x2_t x0, x1, x2; float32x4x2_t y0, y1, y2; + float32x4_t post = vdupq_n_f32(1.f / 9); for (int loop = 0; loop < output_w_tiles; ++loop) { x0 = vld2q_f32(input_ptr0); x1 = vld2q_f32(input_ptr0 + 8); @@ -570,8 +747,8 @@ struct Pooling3x3 { y1.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(y1.val[0], y0.val[0]); y0.val[1] = vPoolPreq_f32

(y1.val[1], y0.val[1]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); - y0.val[1] = vPoolPostq_f32

(y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); vst1q_f32(output_ptr0, y0.val[0]); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); @@ -596,8 +773,8 @@ struct Pooling3x3 { y0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y1.val[0] = vPoolPreq_f32

(y0.val[0], y1.val[0]); y1.val[1] = vPoolPreq_f32

(y0.val[1], y1.val[1]); - y1.val[0] = vPoolPostq_f32

(y1.val[0]); - y1.val[1] = vPoolPostq_f32

(y1.val[1]); + y1.val[0] = vPoolPostq_f32

(y1.val[0], post); + y1.val[1] = vPoolPostq_f32

(y1.val[1], post); vst1q_f32(output_ptr1, y1.val[0]); vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1])); @@ -622,8 +799,8 @@ struct Pooling3x3 { x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); - y0.val[1] = vPoolPostq_f32

(y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); vst1q_f32(output_ptr2, y0.val[0]); vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1])); @@ -659,7 +836,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); y1.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32

(y1.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); vst1q_f32(output_ptr0, y0.val[0]); x0 = vld2q_f32(input_ptr3); @@ -675,7 +852,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32

(y0.val[0], y1.val[0]); - y1.val[0] = vPoolPostq_f32

(y1.val[0]); + y1.val[0] = vPoolPostq_f32

(y1.val[0], post); vst1q_f32(output_ptr1, y1.val[0]); x0 = vld2q_f32(input_ptr5); @@ -691,7 +868,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); vst1q_f32(output_ptr2, y0.val[0]); input_ptr0 += 8; @@ -726,9 +903,9 @@ struct Pooling3x3 { m0 = PoolPre

(PoolPre

(m0, m1), m2); m1 = PoolPre

(PoolPre

(m2, m3), m4); m2 = PoolPre

(PoolPre

(m4, m5), m6); - output_ptr0[r] = PoolPost

(m0); - output_ptr1[r] = PoolPost

(m1); - output_ptr2[r] = PoolPost

(m2); + output_ptr0[r] = PoolPost

(m0, avg); + output_ptr1[r] = PoolPost

(m1, avg); + output_ptr2[r] = PoolPost

(m2, avg); } } // remain h @@ -742,6 +919,7 @@ struct Pooling3x3 { int remain = output_w_remain; #if defined(__ARM_NEON__) || defined(__ARM_NEON) float32x4x2_t x0, x1, x2, y0; + float32x4_t post = vdupq_n_f32(1.f / 9); for (int loop = 0; loop < output_w_tiles; ++loop) { x0 = vld2q_f32(input_ptr0); x1 = vld2q_f32(input_ptr0 + 8); @@ -773,8 +951,8 @@ struct Pooling3x3 { x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); - y0.val[1] = vPoolPostq_f32

(y0.val[1]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); vst1q_f32(output_ptr0, y0.val[0]); vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); @@ -804,7 +982,7 @@ struct Pooling3x3 { x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0]); + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); vst1q_f32(output_ptr0, y0.val[0]); input_ptr0 += 8; @@ -823,7 +1001,7 @@ struct Pooling3x3 { m2 = PoolPre

(m2, input_ptr2[2 * r + 2]); m0 = PoolPre

(PoolPre

(m0, m1), m2); - output_ptr0[r] = PoolPost

(m0); + output_ptr0[r] = PoolPost

(m0, avg); } } } -- GitLab