diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt index 76828cfcfc20b839e273872f654e62ac5434f727..14996d1428c221bcc09773c2481f081d254c10c0 100644 --- a/paddle/fluid/lite/arm/math/CMakeLists.txt +++ b/paddle/fluid/lite/arm/math/CMakeLists.txt @@ -9,6 +9,7 @@ cc_library(math_arm SRCS packed_sgemm.cc softmax.cc scale.cc + pooling.cc elementwise.cc sgemv.cc type_trans.cpp diff --git a/paddle/fluid/lite/arm/math/pooling.cc b/paddle/fluid/lite/arm/math/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..c74e235b6be1d8eb1a753408e182c234eeac0b6c --- /dev/null +++ b/paddle/fluid/lite/arm/math/pooling.cc @@ -0,0 +1,3347 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/pooling.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void pooling_basic(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + // no need to pad input tensor, border is zero pad inside this function + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + int size_channel_in = win * hin; + int size_channel_out = wout * hout; + + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + if (global_pooling) { + if (pooling_type == "max") { // Pooling_max + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* data_in_channel = + data_in_batch + c * size_channel_in; // in address + data_out_batch[c] = data_in_channel[0]; + for (int i = 0; i < size_channel_in; ++i) { + data_out_batch[c] = data_out_batch[c] > data_in_channel[i] + ? data_out_batch[c] + : data_in_channel[i]; + } + } + } + + } else if (pooling_type == "avg") { + // Pooling_average_include_padding + // Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* data_in_channel = + data_in_batch + c * size_channel_in; // in address + float sum = 0.f; + for (int i = 0; i < size_channel_in; ++i) { + sum += data_in_channel[i]; + } + data_out_batch[c] = sum / size_channel_in; + } + } + } else { + LOG(FATAL) << "not support"; + } + return; + } + + if (pooling_type == "max") { + // Pooling_max + for (int n = 0; n < num; ++n) { + float* data_out_channel = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int q = 0; q < chout; q++) { + float* data_out_row = data_out_channel + q * size_channel_out; + const float* data_in_channel = data_in_batch + q * size_channel_in; + + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + + data_out_row[j] = data_in_channel[hstart * win + wstart]; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + data_out_row[j] = data_out_row[j] > data_in_channel[h * win + w] + ? data_out_row[j] + : data_in_channel[h * win + w]; + } + } + } + data_out_row += wout; + } + } + } + } else if (pooling_type == "avg") { + if (exclusive == false) { + // Pooling_average_include_padding + for (int n = 0; n < num; ++n) { + int pool_size = + kernel_w * + kernel_h; // (hend - hstart) * (wend - wstart); // problem + float* data_out_channel = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int q = 0; q < chout; q++) { + float* data_out_row = data_out_channel + q * size_channel_out; + const float* data_in_channel = data_in_batch + q * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + + int bh = kernel_h; + int bw = kernel_w; + if (wend == win) { + bw = wstart + kernel_w >= win + pad_w ? win + pad_w + : wstart + kernel_w; + bw -= wstart; + } + if (hend == hin) { + bh = hstart + kernel_h >= hin + pad_h ? hin + pad_h + : hstart + kernel_h; + bh -= hstart; + } + pool_size = bh * bw; + + data_out_row[j] = data_in_channel[hstart * win + wstart]; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += data_in_channel[h * win + w]; + } + } + data_out_row[j] = sum / pool_size; + } + data_out_row += wout; + } + } + } + } else { // exclusive == true, Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* data_out_channel = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int q = 0; q < chout; q++) { + float* data_out_row = data_out_channel + q * size_channel_out; + const float* data_in_channel = data_in_batch + q * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + + data_out_row[j] = data_in_channel[hstart * win + wstart]; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += data_in_channel[h * win + w]; + } + } + int pool_size = (hend - hstart) * (wend - wstart); + data_out_row[j] = sum / pool_size; + } + data_out_row += wout; + } + } + } + } + + } else { + LOG(FATAL) << "not support"; + } +} + +void pooling_global(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int cnt = size_channel_in / 8; + +#if 0 + LOG(INFO) << "size_channel_in:" << size_channel_in; + LOG(INFO) << "cnt:" << cnt; + LOG(INFO) << "num:" << num; + LOG(INFO) << "chout:" << chout; + LOG(INFO) << "hout:" << hout; + LOG(INFO) << "wout:" << wout; + + LOG(INFO) << "chin:" << chin; + LOG(INFO) << "hin:" << hin; + LOG(INFO) << "win:" << win; + LOG(INFO) << "pooling_type " << pooling_type; +#endif + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout; + const float* data_in_batch = data_in + n * chin * size_channel_in; + if (pooling_type == "max") { +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* data_in_channel = data_in_batch + c * size_channel_in; + int i = 0; + float minval = std::numeric_limits::lowest(); + float32x4_t vmax = vdupq_n_f32(minval); +#ifdef __aarch64__ + for (; i < cnt; i++) { + float32x4_t vdin1 = vld1q_f32(data_in_channel); + vmax = vmaxq_f32(vdin1, vmax); + float32x4_t vdin2 = vld1q_f32(data_in_channel + 4); + vmax = vmaxq_f32(vmax, vdin2); + data_in_channel += 8; + } +#else + int num = cnt; + if (num > 0) { + asm volatile( + "max_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, " + "data_in_channel\n" + "vmax.f32 %q[vmax], %q[vmax], q0 @max vmax, " + "vmax, data_in_channel\n" + "vld1.f32 {d2-d3}, [%[data_in_channel]]! @ load 2nd 4 " + "data" + "vmax.f32 %q[vmax], %q[vmax], q1 @ compare 2nd " + "4 datas\n" + "subs %[num], #1 @subs num, 1\n" + "bne max_loop @bne num\n" + : [data_in_channel] "+r"(data_in_channel), [num] "+r"(num), + [vmax] "+w"(vmax) + : + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ + float32x2_t vmax_tmp = + vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax)); + float tmp1 = vget_lane_f32(vmax_tmp, 0); + float tmp2 = vget_lane_f32(vmax_tmp, 1); + float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2; + for (i = cnt * 8; i < size_channel_in; ++i) { + /* code */ + max_tmp = max_tmp > data_in_channel[0] ? max_tmp : data_in_channel[0]; + data_in_channel++; + } + data_out_batch[c] = max_tmp; + } + } else { +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + const float* data_in_channel = + data_in_batch + c * size_channel_in; // in address + int i = 0; + float32x4_t vsum = vdupq_n_f32(0.0f); +#ifdef __aarch64__ + for (; i < cnt; i++) { // + vsum = vaddq_f32(vld1q_f32(data_in_channel), vsum); + data_in_channel += 4; + } +#else + int num = cnt; + if (num > 0) { + asm volatile( + "add_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, " + "data_in_channel\n" + "vadd.f32 %q[vsum], %q[vsum], q0 @add vmax, " + "vmax, data_in_channel\n" + "subs %[num], #1 @subs num, 1\n" + "bne add_loop @bne num\n" + : [data_in_channel] "+r"(data_in_channel), [num] "+r"(num), + [vsum] "+w"(vsum) + : + : "cc", "memory", "q0"); + } +#endif // __aarch64__ + float32x2_t vsum_tmp = + vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum)); + float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1); + for (i = cnt * 4; i < size_channel_in; i++) { + sum += data_in_channel[0]; + data_in_channel++; + } + data_out_batch[c] = sum / size_channel_in; + } + } + } +} + +void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (win >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = (hin >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remian = w_even - w_unroll_size; + int w_in_2 = win << 1; + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + int h = 0; + for (; h < h_even; h += 2) { + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); + float32x4_t dr10 = vld1q_f32(&r1[w]); + float32x4_t dr11 = vld1q_f32(&r1[w + 4]); + float32x4_t dmax1 = vmaxq_f32(dr00, dr10); + float32x4_t dmax2 = vmaxq_f32(dr01, dr11); +#ifdef __aarch64__ + float32x4_t dmax = vpmaxq_f32(dmax1, dmax2); +#else + float32x2_t dmaxl = + vpmax_f32(vget_low_f32(dmax1), vget_high_f32(dmax1)); + float32x2_t dmaxh = + vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2)); + float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); +#endif + vst1q_f32(&data_out_channel[w >> 1], dmax); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + float* dr0 = reinterpret_cast(r0); + float* dr1 = reinterpret_cast(r1); + float* dr_out = data_out_channel; + if (num > 0) { + asm volatile( + "s2_max_loop: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @load q1, dr1\n" + "vmax.f32 q0, q0, q2 @max q0, q0, " + "q2\n" + "vmax.f32 q1, q1, q3 @max q1, q1, " + "q2\n" + "vpmax.f32 d4, d0, d1 @max d4, d0, " + "d1\n" + "vpmax.f32 d5, d2, d3 @max d5, d2, " + "d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, " + "dr_out\n" + "subs %[num], #1 @subs num, 1\n" + "bne s2_max_loop @bne num\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [num] "+r"(num) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = + std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1])); + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = std::max(r0[w], r1[w]); + } + r0 += w_in_2; // << 1; + r1 += w_in_2; // << 1; + data_out_channel += wout; + } + // process remain row (odd, last row) + for (; h < hin; h++) { // run 0 or 1 time + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); +#ifdef __aarch64__ + float32x4_t dmax = vpmaxq_f32(dr00, dr01); +#else + float32x2_t dmaxl = + vpmax_f32(vget_low_f32(dr00), vget_high_f32(dr00)); + float32x2_t dmaxh = + vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01)); + float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); +#endif + float32x4_t dmax_cmp_zero = vmaxq_f32(dmax, vzero); + vst1q_f32(&data_out_channel[w >> 1], dmax_cmp_zero); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + float* dr0 = reinterpret_cast(r0); + float* dr_out = data_out_channel; + if (num > 0) { + asm volatile( + "s2_max_loop1: @main " + "loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n" + "vpmax.f32 d4, d0, d1 @max d4, d0, " + "d1\n" + "vpmax.f32 d5, d2, d3 @max d5, d2, " + "d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, " + "dr_out\n" + "subs %[num], #1 @subs num, 1\n" + "bne s2_max_loop1 @bne num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [num] "+r"(num) + : + : "cc", "memory", "q0", "q1", "q2"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = std::max(std::max(r0[w], r0[w + 1]), 0.f); + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = std::max(r0[w], 0.f); + } + } + } + } +} + +void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (win >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = (hin >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remian = w_even - w_unroll_size; + int w_in_2 = win << 1; + float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4 + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + int h = 0; + for (; h < h_even; h += 2) { + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); + float32x4_t dr10 = vld1q_f32(&r1[w]); + float32x4_t dr11 = vld1q_f32(&r1[w + 4]); + float32x4_t dsum1 = vaddq_f32(dr00, dr10); + float32x4_t dsum2 = vaddq_f32(dr01, dr11); +#ifdef __aarch64__ + float32x4_t dsum = vpaddq_f32(dsum1, dsum2); +#else + float32x2_t dsuml = + vpadd_f32(vget_low_f32(dsum1), vget_high_f32(dsum1)); + float32x2_t dsumh = + vpadd_f32(vget_low_f32(dsum2), vget_high_f32(dsum2)); + float32x4_t dsum = vcombine_f32(dsuml, dsumh); +#endif + float32x4_t res = vmulq_f32(dsum, vcoef); + vst1q_f32(&data_out_channel[w >> 1], res); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + float* dr0 = reinterpret_cast(r0); + float* dr1 = reinterpret_cast(r1); + float* dr_out = data_out_channel; + + if (num > 0) { + asm volatile( + "1: @ main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, " + "dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @ load q1, " + "dr1\n" + "vadd.f32 q0, q0, q2 @ add q0, q0, " + "q2\n" + "vadd.f32 q1, q1, q3 @ add q1, q1, " + "q2\n" + "vpadd.f32 d4, d0, d1 @ add d4, d0, " + "d1\n" + "vpadd.f32 d5, d2, d3 @ add d5, d2, " + "d3\n" + "vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, " + "vcoef\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, " + "dr_out\n" + "subs %[num], #1 @ subs num, 1\n" + "bne 1b @ bne num\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [vcoef] "+w"(vcoef), [num] "+r"(num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(num), "w"(vcoef) + : "cc", "memory", "q0", "q1", "q2", "q3"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = + (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) / 4.f; + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = (r0[w] + r1[w]) / 4.f; + } + r0 += w_in_2; // << 1; + r1 += w_in_2; // << 1; + data_out_channel += wout; + } + // process remain row (odd, last row) + for (; h < hin; h++) { // run 0 or 1 time + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); +#ifdef __aarch64__ + float32x4_t dsum = vpaddq_f32(dr00, dr01); +#else + float32x2_t dsuml = + vpadd_f32(vget_low_f32(dr00), vget_high_f32(dr00)); + float32x2_t dsumh = + vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01)); + float32x4_t dsum = vcombine_f32(dsuml, dsumh); +#endif + float32x4_t res = vmulq_f32(dsum, vcoef); + vst1q_f32(&data_out_channel[w >> 1], res); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + float* dr0 = reinterpret_cast(r0); + float* dr_out = data_out_channel; + + if (num > 0) { + asm volatile( + "1: @ main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, " + "dr0\n" + "vpadd.f32 d4, d0, d1 @ add d4, d0, " + "d1\n" + "vpadd.f32 d5, d2, d3 @ add d5, d2, " + "d3\n" + "vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, " + "vcoef\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, " + "dr_out\n" + "subs %[num], #1 @ subs num, 1\n" + "bne 1b @ bne num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [vcoef] "+w"(vcoef), + [num] "+r"(num) + : "r"(dr0), "r"(dr_out), "r"(num), "w"(vcoef) + : "cc", "memory", "q0", "q1", "q2"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = (r0[w] + r0[w + 1]) / 4.f; + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = r0[w] / 4.f; + } + } + } + } +} + +void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + // no need to pad input tensor, pad_size is not used, default border is zero + // padded + int ch_in = chin; + int h_in = hin; + int w_in = win; + + int ch_out = chout; + int h_out = hout; + int w_out = wout; + + int size_channel_out = w_out * h_out; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (w_in >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = (h_in >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + // int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remian = w_even - w_unroll_size; + int w_in_2 = w_in << 1; + int w_unroll_size = (w_in - 2) >> 2; + int w_unroll_remian = w_in - 2 - w_unroll_size * 4; + float minval = std::numeric_limits::lowest(); + float32x4_t vzero = vdupq_n_f32(minval); // zero pad + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; // w_in / 4 + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 1; + // left + data_out_channel[0] = + std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +// first row with zero pad +#ifdef __aarch64__ + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&data_out_channel[cnt], vmax); + cnt += 4; + } + +#else + dr_out = dr_out + 1; + + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vmax.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4, " + "max_1234, max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4, " + "max_2345, max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6, " + "max_3456, max_3456\n" + "vmax.f32 d8, d2, d3 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2, " + "vmax_23_45, vmax_34_56\n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + // swap + "vmov.f32 s0, s17 @mov \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s0 @mov \n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } + +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + data_out_channel[j + w + 1] = tmp_max; + } + // right + float tmp = std::max(r0[w_in - 2], r1[w_in - 2]); + tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); + data_out_channel[w_out - 1] = tmp; + + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + data_out_channel += w_out; + int h = 0; + for (; h < h_in - 2; h += 1) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); + data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&data_out_channel[cnt], vmax); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + dr2 = r2; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main " + "loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" + "vmax.f32 q7, q0, q2 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d16, d2, d6 @max " + "r0_5678,r1_5678\n" + "vmax.f32 q3, q7, q4 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d12, d16, d10 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d6, d7 @pmax d4, " + "max_1234, max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4, " + "max_2345, max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6, " + "max_3456, max_3456\n" + "vmax.f32 d8, d2, d3 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2, " + "vmax_23_45, vmax_34_56\n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "sub %[dr2], #8 @sub w, 8\n" + // swap + "vmov.f32 s0, s17 @mov \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s0 @mov \n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @ bne " + "s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1])); + tmp_max = std::max(tmp_max, r2[j + w + 2]); + data_out_channel[j + w + 1] = tmp_max; + } + // right + tmp = std::max(r0[w_in - 2], r1[w_in - 2]); + tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); + tmp = std::max(tmp, std::max(r2[w_in - 2], r2[w_in - 1])); + data_out_channel[w_out - 1] = tmp; + + r0 = r1; + r1 = r2; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + // the last two line + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + data_out_channel[0] = std::max(maxr0, maxr1); +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&data_out_channel[cnt], vmax); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vmax.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4, " + "max_1234, max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4, " + "max_2345, max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6, " + "max_3456, max_3456\n" + "vmax.f32 d8, d2, d3 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2, " + "vmax_23_45, vmax_34_56\n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + // swap + "vmov.f32 s0, s17 @mov \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s0 @mov \n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + data_out_channel[j + w + 1] = tmp_max; + } + tmp = std::max(r0[w_in - 2], r1[w_in - 2]); + tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); + data_out_channel[w_out - 1] = tmp; + } + } +} + +void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + + int size_channel_out = w_out * h_out; + int size_channel_in = w_in * h_in; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (w_in >> 1) << 1; + int h_even = (h_in >> 1) << 1; + int w_in_2 = w_in << 1; + int w_unroll_size = (w_in - 2) >> 2; + int w_unroll_remian = w_in - 2 - w_unroll_size * 4; + float32x4_t vzero = vdupq_n_f32(0.f); // zero pad + float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); // zero pad + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; // w_in / 4 + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 1; + // left + data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +// first row with zero pad +#ifdef __aarch64__ + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&data_out_channel[cnt], vsum); + cnt += 4; + } + +#else + dr_out = dr_out + 1; + + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vadd.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234 + 2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } + +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + data_out_channel[j + w + 1] = tmp_sum / 9.f; + } + // right + float tmp = r0[w_in - 2] + r1[w_in - 2]; + tmp += (r0[w_in - 1] + r1[w_in - 1]); + data_out_channel[w_out - 1] = tmp / 9.f; + + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + data_out_channel += w_out; + int h = 0; + for (; h < h_in - 2; h += 1) { + // deal with left pad + float maxr0 = r0[0] + r0[1]; + float maxr1 = r1[0] + r1[1]; + float maxr2 = r2[0] + r2[1]; + data_out_channel[0] = (maxr0 + maxr1 + maxr2) / 9.f; +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&data_out_channel[cnt], vsum); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + dr2 = r2; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" + "vadd.f32 q7, q0, q2 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d16, d2, d6 @max " + "r0_5678,r1_5678\n" + "vadd.f32 q3, q7, q4 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d12, d16, d10 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q3, q0 @add 1234 + " + "2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "sub %[dr2], #8 @sub w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne " + "s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + tmp_sum += (r2[j + w + 1] + r2[j + w + 2]); + tmp_sum += r2[j + w]; + data_out_channel[j + w + 1] = tmp_sum / 9.f; + } + // right + tmp = r0[w_in - 2] + r1[w_in - 2]; + tmp += (r0[w_in - 1] + r1[w_in - 1]); + tmp += (r2[w_in - 2] + r2[w_in - 1]); + data_out_channel[w_out - 1] = tmp / 9.f; + + r0 = r1; + r1 = r2; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + // the last two line + float maxr0 = (r0[0] + r0[1]); + float maxr1 = (r1[0] + r1[1]); + data_out_channel[0] = (maxr0 + maxr1) / 9.f; +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&data_out_channel[cnt], vsum); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vadd.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234 + 2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + data_out_channel[j + w + 1] = tmp_sum / 9.f; + } + // right + tmp = r0[w_in - 2] + r1[w_in - 2]; + tmp += (r0[w_in - 1] + r1[w_in - 1]); + data_out_channel[w_out - 1] = tmp / 9.f; + } + } +} + +void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = wout * 2 + 1; + int h_needed = hout * 2 + 1; + int pad_right = w_needed - win - pad_left; + int pad_bottom = h_needed - hin - pad_top; + int w_even = (win >> 1) << 1; + int h_even = (hin >> 1) << 1; + int w_in_2 = win << 1; + float minval = std::numeric_limits::lowest(); + float32x4_t vzero = vdupq_n_f32(minval); // zero pad + int cnt_col = (win - 1) / 8; + // remain + int remain = ((win - 1) % 8) / 2; + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 1; + int cnt = 1; + int cnt_num = cnt_col; + int cnt_num1 = remain; + data_out_channel[0] = + std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +// first row with zero pad +#ifdef __aarch64__ + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr0 = dr0 + 1; + dr1 = dr1 + 1; + dr_out = dr_out + 1; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q6, q0, q3 @max " + "r0_1234,r1_1234\n" + "vmax.f32 q7, q1, q4 @max " + "r0_5678,r1_5678\n" + "vmax.f32 q8, q2, q5 @max " + "r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q7, q8, #1 @vext max_6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp cnt_num, " + "0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0, q0, q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0, d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne " + "s3_max_loop_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9"); + } +// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); +#endif + // int w = w_even - 1; + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + data_out_channel[w_even >> 1] = tmp; + // cnt ++; + } + + r0 = r1; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + int h = 2; + for (; h < h_even; h += 2) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); + data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + dr2 = (r2 + 1); + cnt_num = cnt_col; + cnt_num1 = remain; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 q11, q2, q5 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 " + "5678\n" + "vmax.f32 q1, q11, q8 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax d10, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax d12, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax d11, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax d13, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax d0, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax d1, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "sub %[dr2], #16 @add w, 8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne " + "s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmov.f32 s11,s10 @movs11, s10\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vmax.f32 q0, q0, q2 @max q0, q0, " + "q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "bne 2b @bne " + "s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(r0[i], r1[i])); + tmp = std::max(tmp, r2[i]); + } + data_out_channel[w_even >> 1] = tmp; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + } + + if (pad_bottom) { + // deal with bottom pad + // first row with zero pad + int hstart = (h >> 1) * stride_h - pad_h; + int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); + + if (hstart == hend - 1) { // only one lline + data_out_channel[0] = std::max(r0[0], r0[1]); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vmax_1234 = vld1q_f32(&r0[w]); + float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0)); + vmax = vpmax_f32(vmax, vmax); + data_out_channel[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + cnt_num = cnt_col; + cnt_num1 = remain; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vext.f32 q4, q0, q1, #1 @vext q4, q0, " + "q1, 1 2345\n" + "vext.f32 q5, q1, q2, #1 @vext q5, q0, " + "q1, 1 6789\n" + "vpmax.f32 d12, d0, d1 @pmax d12, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d14, d2, d3 @pmax d14, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d13, d8, d9 @pmax d13, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d15, d10, d11 @pmax d15, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d0, d12, d13 @max d0, " + "vmax_12_34,vmax_23_45\n" + "vmax.f32 d1, d14, d15 @pmax d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 6\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne " + "s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 2\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne " + "s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, r0[i]); + } + data_out_channel[w_even >> 1] = tmp; + } + } else { // two lines + data_out_channel[0] = + std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + cnt_num = cnt_col; + cnt_num1 = remain; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vmax.f32 q8, q2, q5 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, + // s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0, " + "2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1, " + "6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne " + "s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne " + "s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + data_out_channel[w_even >> 1] = tmp; + } + } + } + } + } +} + +void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = wout * 2 + 1; + int h_needed = hout * 2 + 1; + int pad_right = w_needed - win - pad_left; + int pad_bottom = h_needed - hin - pad_top; + int w_even = (win >> 1) << 1; + int h_even = (hin >> 1) << 1; + int w_in_2 = win << 1; + int w_unroll_size = (win - 1) / 8; + // remain + int w_unroll_remian = ((win - 1) % 8) / 2; + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size; + int cnt_num1 = w_unroll_remian; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 1; + int cnt = 1; + float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); + float32x4_t vzero = vdupq_n_f32(0.f); + data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +// first row with zero pad +#ifdef __aarch64__ + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); + data_out_channel[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr0 = dr0 + 1; + dr1 = dr1 + 1; + dr_out = dr_out + 1; + // printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vadd.f32 q6, q0, q3 @max " + "r0_1234,r1_1234\n" + "vadd.f32 q7, q1, q4 @max " + "r0_5678,r1_5678\n" + "vadd.f32 q8, q2, q5 @max " + "r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 \n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 \n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 \n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp cnt_num, " + "0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9"); + } +// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); +#endif + // int w = w_even - 1; + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp += (r0[i] + r1[i]); + } + data_out_channel[w_even >> 1] = tmp / 9.f; + // cnt ++; + } + + r0 = r1; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + int h = 2; + for (; h < h_even; h += 2) { + // deal with left pad + float sum0 = r0[0] + r0[1]; + float sum1 = r1[0] + r1[1]; + float sum2 = r2[0] + r2[1]; + data_out_channel[0] = (sum0 + sum1 + sum2) / 9.f; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + vr2 = vsetq_lane_f32(0.f, vr2, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + vsum1 = vaddq_f32(vsum1, vr2); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + float32x2_t vsum = vpadd_f32(vsum2, vsum2); + data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f; + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + dr2 = (r2 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 q11, q2, q5 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 " + "1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 " + "5678\n" + "vadd.f32 q8, q11, q8 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 " + "\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 " + "\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 " + "\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 " + "\n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "sub %[dr2], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vadd.f32 q0, q0, q2 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef), + [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; + for (int i = wstart; i < wend; i++) { + tmp += (r0[i] + r1[i] + r2[i]); + } + data_out_channel[w_even >> 1] = tmp / 9.f; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + } + + if (pad_bottom) { + // deal with bottom pad + // first row with zero pad + int hstart = (h >> 1) * stride_h - pad_h; + int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); + + if (hstart == hend - 1) { // only one lline + data_out_channel[0] = (r0[0] + r0[1]) / 9.f; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vsum_1234 = vld1q_f32(&r0[w]); + float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), + vsum_123_345, 1); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), + vsum_123_345, 2); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), + vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0)); + vsum = vpadd_f32(vsum, vsum); + data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f; + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d12-d15}, [%[dr0]]! @load " + "d0-d3, dr0\n" + "vld1.f32 {d16-d17}, [%[dr0]]! @load " + "d0-d3, dr0\n" + "vext.f32 q0, q6, q7, #1 @vext " + "max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext " + "max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext " + "max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext " + "max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, " + "2345 \n" + "vadd.f32 q5, q7, q1 @add 5678, " + "4567 \n" + "vadd.f32 q4, q4, q2 @add 3456, " + "sum1 \n" + "vadd.f32 q5, q5, q3 @add 6789, " + "sum2 \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 6\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext " + "v0_0123\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 2\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; + for (int i = wstart; i < wend; i++) { + tmp += r0[i]; + } + data_out_channel[w_even >> 1] = tmp / 9.f; + } + } else { // two lines + data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), + vsum_123_345, 1); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), + vsum_123_345, 2); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), + vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); + data_out_channel[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vmax.f32 q8, q2, q5 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, + // s6\n" + "vext.f32 q0, q6, q7, #1 @vext " + "max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext " + "max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext " + "max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext " + "max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, " + "2345 \n" + "vadd.f32 q5, q7, q1 @add 5678, " + "4567 \n" + "vadd.f32 q4, q4, q2 @add 3456, " + "sum1 \n" + "vadd.f32 q5, q5, q3 @add 6789, " + "sum2 \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext " + "v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext " + "v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp += (r0[i] + r1[i]); + } + data_out_channel[w_even >> 1] = tmp / 9.f; + } + } + } + } + } +} + +void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int size_channel_out = w_out * h_out; + int size_channel_in = w_in * h_in; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = w_out * 2 + 1; + int h_needed = h_out * 2 + 1; + int pad_right = w_needed - w_in - pad_left; + int pad_bottom = h_needed - h_in - pad_top; + int w_even = ((w_in - 1) >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = ((h_in - 1) >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + int w_unroll_size = w_in >> 3; + int w_unroll_remian = (w_in - w_unroll_size * 8 - 1) / 2; + int w_in_2 = w_in << 1; + float minval = std::numeric_limits::lowest(); + float32x4_t vzero = vdupq_n_f32(minval); // zero pad + // printf("minval: %.2f\n", minval); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; + // w = w_in - 8; + int cnt_num1 = w_unroll_remian; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 0; + // data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], + // r1[1])); + // first row with zero pad + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + // data_out_channel += w_out; + int h = 0; + for (; h < h_even; h += 2) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); +// data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 0; + cnt = 0; + for (; w < w_in - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = data_out_channel; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + dr2 = r2; // (r2 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 d22, d4, d10 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 " + "5678\n" + "vmax.f32 d2, d22, d16 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax d10, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax d12, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax d11, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax d13, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax d0, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax d1, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "sub %[dr2], #8 @add w, 8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmov.f32 s11,s10 @movs11, s10\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vmax.f32 q0, q0, q2 @max q0, q0, " + "q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(r0[i], r1[i])); + tmp = std::max(tmp, r2[i]); + } + data_out_channel[w_even >> 1] = tmp; + // cnt ++; + } + r0 = r2; + r1 = r0 + w_in; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + if (pad_bottom) { +// deal with bottom pad +// first row with zero pad +// int hstart = (h >> 1) * stride_h - pad_h; +// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in); +// data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], +// r1[1])); +#ifdef __aarch64__ + w = 0; + cnt = 0; + for (; w < w_in - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr_out = data_out_channel; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vmax.f32 d16, d4, d10 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0, 2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1, 6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + data_out_channel[w_even >> 1] = tmp; + } + } + } + } +} + +void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int size_channel_out = w_out * h_out; + int size_channel_in = w_in * h_in; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = w_out * 2 + 1; + int h_needed = h_out * 2 + 1; + int pad_right = w_needed - w_in - pad_left; + int pad_bottom = h_needed - h_in - pad_top; + int w_even = ((w_in - 1) >> 1) << 1; + int h_even = ((h_in - 1) >> 1) << 1; + int w_in_2 = w_in << 1; + int w_unroll_size = w_in >> 3; + int w_unroll_remian = (w_even - w_unroll_size * 8 - 1) / 2; + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; + // w = w_in - 8; + int cnt_num1 = w_unroll_remian; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + + float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); + float32x4_t vzero = vdupq_n_f32(0.f); + + int h = 0; + for (; h < h_even; h += 2) { +// LOG(INFO) << "h: " << h<<", dr0:" << r0 <<", dr1: "< 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble loop3_ave_p0 @ble " + "exit\n" + "s3_ave_loop_mid_p0: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 d22, d4, d10 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 " + "1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 " + "5678\n" + "vadd.f32 d16, d22, d16 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 " + "\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 " + "\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 " + "\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 " + "\n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "sub %[dr2], #8 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne s3_ave_loop_mid_p0 @bne " + "s3_max_loop_mid\n" + "loop3_ave_p0: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble exit1_ave_p0 @ble " + "exit1\n" + "s3_ave_loop_mid_1_p0: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vadd.f32 q0, q0, q2 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne s3_ave_loop_mid_1_p0 @bne " + "s3_max_loop_mid_1\n" + "exit1_ave_p0: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef), + [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = 0.f; + int pool_size = 3 * (wend - wstart); + for (int i = wstart; i < wend; i++) { + tmp += (r0[i] + r1[i] + r2[i]); + } + data_out_channel[w_even >> 1] = tmp / pool_size; + // cnt ++; + } + r0 = r2; + r1 = r0 + w_in; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + if (pad_bottom) { +// deal with bottom pad +// first row with zero pad +// int hstart = (h >> 1) * stride_h - pad_h; +// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in); +// data_out_channel[0] =(r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +#if 1 // def __aarch64__ + int w = 0; + int cnt = 0; + vcoef = vdupq_n_f32(1.f / 6.f); + for (; w < w_in - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); + data_out_channel[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr_out = data_out_channel; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + // LOG(INFO) << "dr0:" << dr0 <<", dr1: "< 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 2f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vadd.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vadd.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vadd.f32 d16, d4, d10 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 " + "\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 " + "\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 " + "\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 " + "\n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "2: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 3f @ble exit\n" + "4: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 4b @bne s3_max_loop_bot_1\n" + "3: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); + } + +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = 0.f; + int pool_size = 2 * (wend - wstart); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp += (r0[i] + r1[i]); + } + data_out_channel[w_even >> 1] = tmp / pool_size; + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/pooling.h b/paddle/fluid/lite/arm/math/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..36832187073c2d29a129a10fdd7984ba8d15db3d --- /dev/null +++ b/paddle/fluid/lite/arm/math/pooling.h @@ -0,0 +1,111 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +// !pooling fp32 Op +void pooling_basic(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling_global(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index f8346ab62aec86eb7b3f19e3a71d5f3c67b88259..c0fa480f0944515abec15001026768cd7e3abb46 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -11,12 +11,14 @@ cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) +lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) set(arm_kernels fc_compute_arm @@ -26,6 +28,7 @@ set(arm_kernels softmax_compute_arm conv_compute_arm elementwise_add_compute_arm + pool_compute_arm ) set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") diff --git a/paddle/fluid/lite/kernels/arm/pool_compute.cc b/paddle/fluid/lite/kernels/arm/pool_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a7716fae6bfc3aa52dad7c8b8192191e986b6f3 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/pool_compute.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/pool_compute.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void PoolCompute::Run() { + auto& param = Param(); + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* din = param.x->data(); + float* dout = param.output->mutable_data(); + + std::vector& ksize = param.ksize; + std::vector& strides = param.strides; + std::vector& paddings = param.paddings; + + std::string& pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + bool ceil_mode = param.ceil_mode; + bool use_quantizer = param.use_quantizer; + std::string& data_format = param.data_format; + + if (param.global_pooling) { + for (size_t i = 0; i < ksize.size(); ++i) { + paddings[i] = 0; + ksize[i] = static_cast(in_dims[i + 2]); + } + } + +#if 0 + for (int i = 0; i < in_dims.size(); ++i) { + LOG(INFO) << "in_dims[" << i << "]:" << in_dims[i]; + } + for (int i = 0; i < out_dims.size(); ++i) { + LOG(INFO) << "out_dims[" << i << "]:" << out_dims[i]; + } + for (int i = 0; i < ksize.size(); ++i) { + LOG(INFO) << "ksize[" << i << "]:" << ksize[i]; + } + for (int i = 0; i < strides.size(); ++i) { + LOG(INFO) << "strides[" << i << "]:" << strides[i]; + } + for (int i = 0; i < paddings.size(); ++i) { + LOG(INFO) << "paddings[" << i << "]:" << paddings[i]; + } + LOG(INFO) << "global_pooling:" << global_pooling; + LOG(INFO) << "exclusive:" << exclusive; + LOG(INFO) << "adaptive:" << adaptive; + LOG(INFO) << "ceil_mode:" << ceil_mode; + LOG(INFO) << "use_quantizer:" << use_quantizer; + LOG(INFO) << "data_format:" << data_format; + LOG(INFO) << "din:" << din; + LOG(INFO) << "dout:" << dout; +#endif + + // global + if (global_pooling == true) { + lite::arm::math::pooling_global( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1]) { + if (pooling_type == "max") { + lite::arm::math::pooling2x2s2_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling2x2s2_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 1 && + strides[0] == strides[1] && paddings[0] == 1) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s1p1_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s1p1_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1] && paddings[0] == 0) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s2p0_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s2p0_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1] && paddings[0] == 1) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s2p1_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s2p1_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else { + lite::arm::math::pooling_basic( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + return; +} + +TargetType PoolCompute::target() const { return TARGET(kARM); } + +PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); } + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pool, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::PoolCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/pool_compute.h b/paddle/fluid/lite/kernels/arm/pool_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..76dedbc3132405cd70d74e233619572f97dc07e0 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/pool_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/operators/pool_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class PoolCompute : public KernelLite { + public: + using param_t = operators::PoolParam; + + void Run() override; + + TargetType target() const override; + PrecisionType precision() const override; + + virtual ~PoolCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/pool_compute_test.cc b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f6e1fcbf2d652b82a35101220488a0d623f1811 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc @@ -0,0 +1,276 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/pool_compute.h" +#include +#include +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void pool_compute_ref(const operators::PoolParam& param) { + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* src_ptr = param.x->data(); + float* dst_ptr = param.output->mutable_data(); + + std::vector ksize = param.ksize; + std::vector strides = param.strides; + std::vector paddings = param.paddings; + + std::string pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + bool ceil_mode = param.ceil_mode; + bool use_quantizer = param.use_quantizer; + std::string data_format = param.data_format; + + int in_n = in_dims[0]; + int in_c = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int size_in_n = in_c * in_h * in_w; + int size_in_c = in_h * in_w; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + int size_out_n = in_c * out_h * out_w; + int size_out_c = out_h * out_w; + + int window_h = ksize[0]; + int window_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + if (global_pooling == true) { + ksize[0] = in_h; + ksize[1] = in_w; + } + +#if 0 + for (int i = 0; i < ksize.size(); ++i) { + LOG(INFO) << "ksize[" << i << "]:" << ksize[i]; + } + for (int i = 0; i < strides.size(); ++i) { + LOG(INFO) << "strides[" << i << "]:" << strides[i]; + } + for (int i = 0; i < paddings.size(); ++i) { + LOG(INFO) << "paddings[" << i << "]:" << paddings[i]; + } + LOG(INFO) << "in nchw:" << in_n << ", " << in_c << ", " << in_h << ", " + << in_w; + LOG(INFO) << "size_in_n:" << size_in_n; + LOG(INFO) << "size_out_c:" << size_out_c; + LOG(INFO) << "out_h:" << out_h; + LOG(INFO) << "out_w:" << out_w; + LOG(INFO) << "size_out_n:" << size_out_n; + LOG(INFO) << "size_out_c:" << size_out_c; + LOG(INFO) << "window_h:" << window_h; + LOG(INFO) << "window_w:" << window_w; + LOG(INFO) << "stride_h:" << stride_h; + LOG(INFO) << "stride_w:" << stride_w; + LOG(INFO) << "pad_h:" << pad_h; + LOG(INFO) << "pad_w:" << pad_w; +#endif + + for (int ind_n = 0; ind_n < in_n; ++ind_n) { + for (int ind_c = 0; ind_c < in_c; ++ind_c) { + for (int ind_h = 0; ind_h < out_h; ++ind_h) { + int sh = ind_h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + + for (int ind_w = 0; ind_w < out_w; ++ind_w) { + int sw = ind_w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + + float result = static_cast(0); + + int dst_ind = + ind_n * size_out_n + ind_c * size_out_c + ind_h * out_w + ind_w; + + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_ind = + ind_n * size_in_n + ind_c * size_in_c + kh * in_w + kw; + + if (kh == sh && kw == sw) { + result = src_ptr[src_ind]; + } else { + if (pooling_type == "max") { + result = + result >= src_ptr[src_ind] ? result : src_ptr[src_ind]; + } + if (pooling_type == "avg" && exclusive == false) { + // Pooling_average_include_padding + result += src_ptr[src_ind]; + } + if (pooling_type == "avg" && exclusive == true) { + // Pooling_average_include_padding + result += src_ptr[src_ind]; + } + } + } + } + if (pooling_type == "avg" && exclusive == false) { + // Pooling_average_include_padding + // result /= param.window_h * param.window_w; + // LOG(ERROR)<<"cpu"<= in_w + pad_w ? in_w + pad_w : sw + window_w; + bw -= sw; + } + if (eh == in_h) { + bh = sh + window_h >= in_h + pad_h ? in_h + pad_h : sh + window_h; + bh -= sh; + } + result /= bh * bw; + } + if (pooling_type == "avg" && exclusive == true) { + // Pooling_average_exclude_padding + result /= (ew - sw) * (eh - sh); + } + dst_ptr[dst_ind] = result; + } + } + } + } +} + +TEST(pool_arm, init) { + PoolCompute pool; + ASSERT_EQ(pool.precision(), PRECISION(kFloat)); + ASSERT_EQ(pool.target(), TARGET(kARM)); +} + +TEST(pool_arm, compute) { + PoolCompute pool; + operators::PoolParam param; + + lite::Tensor x; + lite::Tensor output; + lite::Tensor output_ref; + + for (auto pooling_type : {"avg", "max"}) { + for (auto global_pooling : {true}) { + for (auto stride : {2}) { + for (auto pad : {0}) { + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 11, 4, 1024}) { + for (auto h : {3, 1, 11, 4, 1}) { + for (auto w : {1, 3, 4, 12, 1}) { + LOG(INFO) << "n:" << n << " c:" << c << " h:" << h + << " w:" << w << " stride:" << stride + << " pad:" << pad + << " pooling_type:" << pooling_type + << " global_pooling:" << global_pooling; + + // init x, output + x.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, 1, 1}))); + output_ref.Resize(DDim(std::vector({n, c, 1, 1}))); + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); ++i) { + x_data[i] = i; + } + + // fill param + param.x = &x; + param.output = &output; + param.pooling_type = pooling_type; + param.ksize = {h, w}; + param.global_pooling = global_pooling; + param.strides = {stride, stride}; + param.paddings = {pad, pad}; + param.exclusive = true; + param.adaptive = false; + param.ceil_mode = false; + param.use_quantizer = false; + + // compute + pool.SetParam(param); + pool.Run(); + +#if 0 + LOG(INFO) << "n:" << n << " c:" << c << " h:" << h << " w:" << w + << " end"; + std::cout << "n:" << n << " c:" << c << " h:" << h << " w:" << w + << " end" << std::endl; + for (int i = 0; i < param.ksize.size(); ++i) { + std::cout << " ksize[" << i << "]:" << param.ksize[i]; + } + std::cout << "\n"; + for (int i = 0; i < param.strides.size(); ++i) { + std::cout << " strides[" << i << "]:" << param.strides[i]; + } + std::cout << "\n"; + for (int i = 0; i < param.paddings.size(); ++i) { + std::cout << " paddings[" << i << "]:" << param.paddings[i]; + } + std::cout << "\n"; +#endif + + // compute ref + // output_ref.Resize(output.dims()); + param.output = &output_ref; + pool_compute_ref(param); + LOG(INFO) << "pool_compute_ref(param) end"; + + // compare + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], + 1); // 1e-5); + } + + LOG(INFO) << "compare pass"; + } + } + } + } + } // pad + } // stride + } // global_pooling + } // pooling_type +} + +TEST(pool, retrive_op) { + auto pool = + KernelRegistry::Global().Create("pool"); + ASSERT_FALSE(pool.empty()); + ASSERT_TRUE(pool.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/use_kernels.h b/paddle/fluid/lite/kernels/arm/use_kernels.h index d856950f3a177d08cdc950c259abf3d1a194ee25..1f93a81aa94f09f8330aa385840adec559d7161d 100644 --- a/paddle/fluid/lite/kernels/arm/use_kernels.h +++ b/paddle/fluid/lite/kernels/arm/use_kernels.h @@ -19,5 +19,6 @@ USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index b65efae8db8e5e030343e383f647be5d69091fb6..628a60234864de982a425d1b3436074d034ce629 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -18,6 +18,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) +cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) set(ops_lite conv_op_lite @@ -46,3 +47,6 @@ lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite) +lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc + DEPS pool_op_lite memory_lite + ARM_DEPS pool_compute_arm) diff --git a/paddle/fluid/lite/operators/pool_op.cc b/paddle/fluid/lite/operators/pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4aaecb282008bcb762ce240fb0ddd39265c09a49 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/pool_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool PoolOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + + const auto& x_dims = param_.x->dims(); + const auto& ksize = param_.ksize; + const auto& strides = param_.strides; + const auto& paddings = param_.paddings; + + // "Pooling intput should be 4-D or 5-D tensor." + CHECK_OR_FALSE(x_dims.size() == 4 || x_dims.size() == 5); + // Input size and pooling size should be consistent. + CHECK_OR_FALSE(x_dims.size() - ksize.size() == 2U); + // Strides size and pooling size should be the same. + CHECK_OR_FALSE(ksize.size() == strides.size()); + // Paddings size and pooling size should be the same. + CHECK_OR_FALSE(ksize.size() == paddings.size()); + + return true; +} + +int PoolOutputSize(int input_size, int filter_size, int padding, int stride, + bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + 2 * padding) / stride + 1; + } else { + output_size = + (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } + return output_size; +} + +bool PoolOpLite::InferShape() const { + const auto x_dims = param_.x->dims(); + std::vector& ksize = param_.ksize; + if (param_.global_pooling) { + ksize.resize(static_cast(x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) { + param_.paddings[i] = 0; + ksize[i] = static_cast(x_dims[i + 2]); + } + } + + std::vector output_shape({x_dims[0], x_dims[1]}); + if (param_.adaptive) { + output_shape.insert(output_shape.end(), param_.ksize.begin(), + param_.ksize.end()); + } else { + for (size_t i = 0; i < param_.ksize.size(); ++i) { + output_shape.push_back( + PoolOutputSize(x_dims[i + 2], param_.ksize[i], param_.paddings[i], + param_.strides[i], param_.ceil_mode)); + } + } + param_.output->Resize(lite::DDim(output_shape)); + + // ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + // ctx->ShareLoD("X", "Out"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(pool, paddle::lite::operators::PoolOpLite); diff --git a/paddle/fluid/lite/operators/pool_op.h b/paddle/fluid/lite/operators/pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a6fc4ac9f3a0d64b833616c5e8742db78c4dbb58 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op.h @@ -0,0 +1,82 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class PoolOpLite : public OpLite { + public: + PoolOpLite() {} + + explicit PoolOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + /* + bool Run() override { + CHECK(kernel_); + kernel_->Run(); + return true; + } + */ + + // TODO(Superjomn) replace framework::OpDesc with a lite one. + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto x = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + + CHECK(scope->FindVar(x)); + CHECK(scope->FindVar(out)); + param_.x = scope->FindVar(x)->GetMutable(); + param_.output = scope->FindVar(out)->GetMutable(); + + param_.pooling_type = op_desc.GetAttr("pooling_type"); + param_.ksize = op_desc.GetAttr>("ksize"); + param_.global_pooling = op_desc.GetAttr("global_pooling"); + param_.strides = op_desc.GetAttr>("strides"); + param_.paddings = op_desc.GetAttr>("paddings"); + + param_.exclusive = op_desc.GetAttr("exclusive"); + param_.adaptive = op_desc.GetAttr("adaptive"); + param_.ceil_mode = op_desc.GetAttr("ceil_mode"); + param_.use_quantizer = op_desc.GetAttr("use_quantizer"); + // param_.data_format = op_desc.GetAttr("data_format"); + return true; + } + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "pool"; } + + private: + mutable PoolParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/pool_op_test.cc b/paddle/fluid/lite/operators/pool_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf46a2ecbd8a465fa5a52bc099389ff3838a5840 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op_test.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/pool_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(pool_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x->Resize(DDim(std::vector({1, 3, 224, 224}))); + output->Resize(DDim(std::vector{1, 3, 112, 112})); + + // set data + for (int i = 0; i < 1 * 3 * 224 * 224; i++) { + x->mutable_data()[i] = i; + } + for (int i = 0; i < 1 * 3 * 112 * 112; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("pool"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + + std::string pooling_type("max"); + desc.SetAttr("pooling_type", pooling_type); + // desc.SetAttr("ksize", static_cast>({2, 2})); + std::vector ksize{2, 2}; + desc.SetAttr("ksize", ksize); + + bool global_pooling{false}; + desc.SetAttr("global_pooling", global_pooling); + + std::vector strides{1, 1}; + desc.SetAttr("strides", strides); + + std::vector paddings{0, 0}; + desc.SetAttr("paddings", paddings); + + bool exclusive{true}; + desc.SetAttr("exclusive", exclusive); + + bool adaptive{false}; + desc.SetAttr("adaptive", adaptive); + + bool ceil_mode{false}; + desc.SetAttr("ceil_mode", ceil_mode); + + bool use_quantizer{false}; + desc.SetAttr("use_quantizer", use_quantizer); + + PoolOpLite pool("pool"); + pool.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}}); + pool.Attach(desc, &scope); + auto kernels = pool.CreateKernels({Place{TARGET(kARM), PRECISION(kFloat)}}); + LOG(INFO) << "kernels.size(): " << kernels.size(); + ASSERT_FALSE(kernels.empty()); +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def); +#endif