diff --git a/paddle/fluid/lite/arm/math/pooling.cc b/paddle/fluid/lite/arm/math/pooling.cc index 7f00ffa89a0b24b4a0aa6f81c9fc2b885673888d..4866257b577267bc67a64dad7d5e2b2537ec29f9 100644 --- a/paddle/fluid/lite/arm/math/pooling.cc +++ b/paddle/fluid/lite/arm/math/pooling.cc @@ -22,7 +22,7 @@ namespace lite { namespace arm { namespace math { -void pooling_basic(const void* din, void* dout, int num, int chout, int hout, +void pooling_basic(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, const std::vector& ksize, const std::vector& strides, @@ -38,98 +38,51 @@ void pooling_basic(const void* din, void* dout, int num, int chout, int hout, int pad_w = paddings[1]; int size_channel_in = win * hin; int size_channel_out = wout * hout; - - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - if (global_pooling) { if (pooling_type == "max") { // Pooling_max for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; ++c) { - const float* data_in_channel = - data_in_batch + c * size_channel_in; // in address - data_out_batch[c] = data_in_channel[0]; + const float* din_ch = din_batch + c * size_channel_in; // in address + float tmp1 = din_ch[0]; for (int i = 0; i < size_channel_in; ++i) { - data_out_batch[c] = data_out_batch[c] > data_in_channel[i] - ? data_out_batch[c] - : data_in_channel[i]; + float tmp2 = din_ch[i]; + tmp1 = tmp1 > tmp2 ? tmp1 : tmp2; } + dout_batch[c] = tmp1; } } - } else if (pooling_type == "avg") { // Pooling_average_include_padding // Pooling_average_exclude_padding for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; ++c) { - const float* data_in_channel = - data_in_batch + c * size_channel_in; // in address + const float* din_ch = din_batch + c * size_channel_in; // in address float sum = 0.f; for (int i = 0; i < size_channel_in; ++i) { - sum += data_in_channel[i]; + sum += din_ch[i]; } - data_out_batch[c] = sum / size_channel_in; + dout_batch[c] = sum / size_channel_in; } } } else { - LOG(FATAL) << "not support"; + LOG(FATAL) << "unsupported pooling type: " << pooling_type; } - return; - } - - if (pooling_type == "max") { - // Pooling_max - for (int n = 0; n < num; ++n) { - float* data_out_channel = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; -#pragma omp parallel for - for (int q = 0; q < chout; q++) { - float* data_out_row = data_out_channel + q * size_channel_out; - const float* data_in_channel = data_in_batch + q * size_channel_in; - - for (int i = 0; i < hout; i++) { - for (int j = 0; j < wout; j++) { - int hstart = i * stride_h - pad_h; - int wstart = j * stride_w - pad_w; - int hend = std::min(hstart + kernel_h, hin + pad_h); - int wend = std::min(wstart + kernel_w, win + pad_w); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = std::min(hend, hin); - wend = std::min(wend, win); - - data_out_row[j] = data_in_channel[hstart * win + wstart]; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - data_out_row[j] = data_out_row[j] > data_in_channel[h * win + w] - ? data_out_row[j] - : data_in_channel[h * win + w]; - } - } - } - data_out_row += wout; - } - } - } - } else if (pooling_type == "avg") { - if (exclusive == false) { - // Pooling_average_include_padding + } else { + if (pooling_type == "max") { + // Pooling_max for (int n = 0; n < num; ++n) { - int pool_size = - kernel_w * - kernel_h; // (hend - hstart) * (wend - wstart); // problem - float* data_out_channel = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_ch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int q = 0; q < chout; q++) { - float* data_out_row = data_out_channel + q * size_channel_out; - const float* data_in_channel = data_in_batch + q * size_channel_in; + for (int c = 0; c < chout; c++) { + float* dout_row = dout_ch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; for (int i = 0; i < hout; i++) { for (int j = 0; j < wout; j++) { int hstart = i * stride_h - pad_h; @@ -140,196 +93,212 @@ void pooling_basic(const void* din, void* dout, int num, int chout, int hout, wstart = std::max(wstart, 0); hend = std::min(hend, hin); wend = std::min(wend, win); - - int bh = kernel_h; - int bw = kernel_w; - if (wend == win) { - bw = wstart + kernel_w >= win + pad_w ? win + pad_w - : wstart + kernel_w; - bw -= wstart; - } - if (hend == hin) { - bh = hstart + kernel_h >= hin + pad_h ? hin + pad_h - : hstart + kernel_h; - bh -= hstart; - } - pool_size = bh * bw; - - data_out_row[j] = data_in_channel[hstart * win + wstart]; - float sum = 0.f; + int pool_size = (hend - hstart) * (wend - wstart); + if (pool_size == 0) continue; + float tmp1 = din_ch[hstart * win + wstart]; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { - sum += data_in_channel[h * win + w]; + float tmp2 = din_ch[h * win + w]; + tmp1 = tmp1 > tmp2 ? tmp1 : tmp2; } } - data_out_row[j] = sum / pool_size; + dout_row[j] = tmp1; } - data_out_row += wout; + dout_row += wout; } } } - } else { // exclusive == true, Pooling_average_exclude_padding - for (int n = 0; n < num; ++n) { - float* data_out_channel = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + } else if (pooling_type == "avg") { + if (exclusive) { + // Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* dout_ch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int q = 0; q < chout; q++) { - float* data_out_row = data_out_channel + q * size_channel_out; - const float* data_in_channel = data_in_batch + q * size_channel_in; - for (int i = 0; i < hout; i++) { - for (int j = 0; j < wout; j++) { - int hstart = i * stride_h - pad_h; - int wstart = j * stride_w - pad_w; - int hend = std::min(hstart + kernel_h, hin + pad_h); - int wend = std::min(wstart + kernel_w, win + pad_w); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = std::min(hend, hin); - wend = std::min(wend, win); - - data_out_row[j] = data_in_channel[hstart * win + wstart]; - float sum = 0.f; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - sum += data_in_channel[h * win + w]; + for (int c = 0; c < chout; c++) { + float* dout_row = dout_ch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + int pool_size = (hend - hstart) * (wend - wstart); + if (pool_size == 0) continue; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += din_ch[h * win + w]; + } } + dout_row[j] = sum / pool_size; } - int pool_size = (hend - hstart) * (wend - wstart); - data_out_row[j] = sum / pool_size; + dout_row += wout; + } + } + } + } else { // Pooling_average_include_padding + for (int n = 0; n < num; ++n) { + float* dout_ch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* dout_row = dout_ch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + int pool_size = (hend - hstart) * (wend - wstart); + if (pool_size == 0) continue; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += din_ch[h * win + w]; + } + } + dout_row[j] = sum / (kernel_w * kernel_h); + } + dout_row += wout; } - data_out_row += wout; } } } + } else { + LOG(FATAL) << "unsupported pooling type: " << pooling_type; } - - } else { - LOG(FATAL) << "not support"; } } -void pooling_global(const void* din, void* dout, int num, int chout, int hout, - int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { +void pooling_global_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win) { int size_channel_in = win * hin; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); int cnt = size_channel_in / 8; - for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout; - const float* data_in_batch = data_in + n * chin * size_channel_in; - if (pooling_type == "max") { + float* dout_batch = dout + n * chout; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int c = 0; c < chout; ++c) { - const float* data_in_channel = data_in_batch + c * size_channel_in; - int i = 0; - float minval = std::numeric_limits::lowest(); - float32x4_t vmax = vdupq_n_f32(minval); + for (int c = 0; c < chout; ++c) { + const float* din_ch = din_batch + c * size_channel_in; + int i = 0; + float minval = std::numeric_limits::lowest(); + float32x4_t vmax = vdupq_n_f32(minval); #ifdef __aarch64__ - for (; i < cnt; i++) { - float32x4_t vdin1 = vld1q_f32(data_in_channel); - vmax = vmaxq_f32(vdin1, vmax); - float32x4_t vdin2 = vld1q_f32(data_in_channel + 4); - vmax = vmaxq_f32(vmax, vdin2); - data_in_channel += 8; - } + for (; i < cnt; i++) { + float32x4_t vdin1 = vld1q_f32(din_ch); + vmax = vmaxq_f32(vdin1, vmax); + float32x4_t vdin2 = vld1q_f32(din_ch + 4); + vmax = vmaxq_f32(vmax, vdin2); + din_ch += 8; + } #else - for (; i < cnt; i++) { - float32x4_t vdin1 = vld1q_f32(data_in_channel); - vmax = vmaxq_f32(vdin1, vmax); - float32x4_t vdin2 = vld1q_f32(data_in_channel + 4); - vmax = vmaxq_f32(vmax, vdin2); - data_in_channel += 8; - } + int cnt_num = cnt; + if (cnt_num > 0) { + asm volatile( + "max_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch\n" + "vmax.f32 %q[vmax], %q[vmax], q0 @max vmax,vmax,din_ch\n" + "vld1.f32 {d2-d3}, [%[din_ch]]! @load 2nd 4 data\n" + "vmax.f32 %q[vmax], %q[vmax], q1 @compare 2nd 4 datas\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne max_loop @bne cnt_num\n" + : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vmax] "+w"(vmax) + : + : "cc", "memory", "q0", "q1"); + } #endif // __aarch64__ - float32x2_t vmax_tmp = - vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax)); - float tmp1 = vget_lane_f32(vmax_tmp, 0); - float tmp2 = vget_lane_f32(vmax_tmp, 1); - float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2; - for (i = cnt * 8; i < size_channel_in; ++i) { - /* code */ - max_tmp = max_tmp > data_in_channel[0] ? max_tmp : data_in_channel[0]; - data_in_channel++; - } - data_out_batch[c] = max_tmp; + float32x2_t vmax_tmp = vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax)); + float tmp1 = vget_lane_f32(vmax_tmp, 0); + float tmp2 = vget_lane_f32(vmax_tmp, 1); + float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2; + for (i = cnt * 8; i < size_channel_in; ++i) { + /* code */ + max_tmp = max_tmp > din_ch[0] ? max_tmp : din_ch[0]; + din_ch++; } - } else { + dout_batch[c] = max_tmp; + } + } +} + +void pooling_global_avg(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win) { + int size_channel_in = win * hin; + int cnt = size_channel_in / 4; + for (int n = 0; n < num; ++n) { + float* dout_batch = dout + n * chout; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int c = 0; c < chout; c++) { - const float* data_in_channel = - data_in_batch + c * size_channel_in; // in address - int i = 0; - float32x4_t vsum = vdupq_n_f32(0.0f); + for (int c = 0; c < chout; c++) { + const float* din_ch = din_batch + c * size_channel_in; // in address + int i = 0; + float32x4_t vsum = vdupq_n_f32(0.0f); #ifdef __aarch64__ - for (; i < cnt; i++) { // - vsum = vaddq_f32(vld1q_f32(data_in_channel), vsum); - data_in_channel += 4; - } + for (; i < cnt; i++) { + vsum = vaddq_f32(vld1q_f32(din_ch), vsum); + din_ch += 4; + } #else - int num = cnt; - if (num > 0) { - asm volatile( - "add_loop: @main loop\n" - "vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, " - "data_in_channel\n" - "vadd.f32 %q[vsum], %q[vsum], q0 @add vmax, " - "vmax, data_in_channel\n" - "subs %[num], #1 @subs num, 1\n" - "bne add_loop @bne num\n" - : [data_in_channel] "+r"(data_in_channel), [num] "+r"(num), - [vsum] "+w"(vsum) - : - : "cc", "memory", "q0"); - } + int cnt_num = cnt; + if (cnt_num > 0) { + asm volatile( + "add_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch\n" + "vadd.f32 %q[vsum], %q[vsum], q0 @add vmax,vmax, din_ch\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne add_loop @bne num\n" + : [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vsum] "+w"(vsum) + : + : "cc", "memory", "q0"); + } #endif // __aarch64__ - float32x2_t vsum_tmp = - vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum)); - float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1); - for (i = cnt * 4; i < size_channel_in; i++) { - sum += data_in_channel[0]; - data_in_channel++; - } - data_out_batch[c] = sum / size_channel_in; + float32x2_t vsum_tmp = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum)); + float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1); + for (i = cnt * 4; i < size_channel_in; i++) { + sum += din_ch[0]; + din_ch++; } + dout_batch[c] = sum / size_channel_in; } } } -void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, - int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { +void pooling2x2s2_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win) { + int kernel = 2; + int stride = 2; + int padding = 0; int size_channel_out = wout * hout; int size_channel_in = win * hin; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - int w_even = (win >> 1) << 1; - // int w_remains = w_in - w_even; // should be 0 or 1 - int h_even = (hin >> 1) << 1; - // int h_remains = h_in - h_even; // should be 0 or 1 + int w_needed = (wout << 1); + int h_needed = (hout << 1); + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; int w_unroll_size = (w_even >> 3) << 3; - // int w_unroll_remian = w_even - w_unroll_size; + // int w_unroll_remain = w_even - w_unroll_size; int w_in_2 = win << 1; - float32x4_t vzero = vdupq_n_f32(0.f); - for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; const float* r1 = r0 + win; int h = 0; for (; h < h_even; h += 2) { @@ -351,50 +320,45 @@ void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2)); float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); #endif - vst1q_f32(&data_out_channel[w >> 1], dmax); + vst1q_f32(&dout_ch[w >> 1], dmax); } #else - w = w_unroll_size; - int num = w_unroll_size >> 3; + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; - float* dr_out = data_out_channel; - if (num > 0) { + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { asm volatile( - "s2_max_loop: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n" - "vld1.f32 {d4-d7}, [%[dr1]]! @load q1, dr1\n" - "vmax.f32 q0, q0, q2 @max q0, q0, " - "q2\n" - "vmax.f32 q1, q1, q3 @max q1, q1, " - "q2\n" - "vpmax.f32 d4, d0, d1 @max d4, d0, " - "d1\n" - "vpmax.f32 d5, d2, d3 @max d5, d2, " - "d3\n" - "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, " - "dr_out\n" - "subs %[num], #1 @subs num, 1\n" - "bne s2_max_loop @bne num\n" + "s2_max_loop: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n" + "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" + "vmax.f32 q1, q1, q3 @max q1,q1,q2\n" + "vpmax.f32 d4, d0, d1 @max d4,d0,d1\n" + "vpmax.f32 d5, d2, d3 @max d5,d2,d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne s2_max_loop @bne cnt_num\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [num] "+r"(num) + [cnt_num] "+r"(cnt_num) : : "cc", "memory", "q0", "q1", "q2", "q3"); } + w = w_unroll_size; #endif // __aarch64__ for (; w < w_even; w += 2) { - data_out_channel[w >> 1] = + dout_ch[w >> 1] = std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1])); } - for (; w < win; ++w) { // run 0 or 1 time - data_out_channel[w >> 1] = std::max(r0[w], r1[w]); + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = std::max(r0[w], r1[w]); } r0 += w_in_2; // << 1; r1 += w_in_2; // << 1; - data_out_channel += wout; + dout_ch += wout; } // process remain row (odd, last row) - for (; h < hin; h++) { // run 0 or 1 time + for (; h < h_limit; h++) { // run 0 or 1 time int w = 0; #ifdef __aarch64__ for (; w < w_unroll_size; w += 8) { @@ -409,72 +373,70 @@ void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01)); float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); #endif - float32x4_t dmax_cmp_zero = vmaxq_f32(dmax, vzero); - vst1q_f32(&data_out_channel[w >> 1], dmax_cmp_zero); + vst1q_f32(&dout_ch[w >> 1], dmax); } #else - w = w_unroll_size; - int num = w_unroll_size >> 3; + float* dr_out = dout_ch; const float* dr0 = r0; - float* dr_out = data_out_channel; - if (num > 0) { + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { asm volatile( - "s2_max_loop1: @main " - "loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n" - "vpmax.f32 d4, d0, d1 @max d4, d0, " - "d1\n" - "vpmax.f32 d5, d2, d3 @max d5, d2, " - "d3\n" - "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, " - "dr_out\n" - "subs %[num], #1 @subs num, 1\n" - "bne s2_max_loop1 @bne num\n" - : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [num] "+r"(num) + "s2_max_loop1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vpmax.f32 d4, d0, d1 @max d4,d0,d1\n" + "vpmax.f32 d5, d2, d3 @max d5,d2,d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne s2_max_loop1 @bne cnt_num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : : "cc", "memory", "q0", "q1", "q2"); } + w = w_unroll_size; #endif // __aarch64__ for (; w < w_even; w += 2) { - data_out_channel[w >> 1] = std::max(std::max(r0[w], r0[w + 1]), 0.f); + dout_ch[w >> 1] = std::max(r0[w], r0[w + 1]); } - for (; w < win; ++w) { // run 0 or 1 time - data_out_channel[w >> 1] = std::max(r0[w], 0.f); + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = r0[w]; } } } } } -void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, - int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { +void pooling2x2s2_avg(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + bool exclusive) { + int kernel = 2; + int stride = 2; + int padding = 0; int size_channel_out = wout * hout; int size_channel_in = win * hin; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - int w_even = (win >> 1) << 1; - // int w_remains = w_in - w_even; // should be 0 or 1 - int h_even = (hin >> 1) << 1; - // int h_remains = h_in - h_even; // should be 0 or 1 + int w_needed = (wout << 1); + int h_needed = (hout << 1); + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; int w_unroll_size = (w_even >> 3) << 3; - // int w_unroll_remian = w_even - w_unroll_size; + // int w_unroll_remain = w_even - w_unroll_size; int w_in_2 = win << 1; - float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4 - + const float coef = 1.f / 4.f; + const float coef_1 = exclusive ? 1.f : coef; + const float coef_2 = exclusive ? 1.f / 2.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_1 = vdupq_n_f32(coef_1); + float32x4_t vcoef_2 = vdupq_n_f32(coef_2); for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; const float* r1 = r0 + win; int h = 0; for (; h < h_even; h += 2) { @@ -497,55 +459,45 @@ void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, float32x4_t dsum = vcombine_f32(dsuml, dsumh); #endif float32x4_t res = vmulq_f32(dsum, vcoef); - vst1q_f32(&data_out_channel[w >> 1], res); + vst1q_f32(&dout_ch[w >> 1], res); } #else - w = w_unroll_size; - int num = w_unroll_size >> 3; + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; - float* dr_out = data_out_channel; - - if (num > 0) { + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { asm volatile( - "1: @ main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, " - "dr0\n" - "vld1.f32 {d4-d7}, [%[dr1]]! @ load q1, " - "dr1\n" - "vadd.f32 q0, q0, q2 @ add q0, q0, " - "q2\n" - "vadd.f32 q1, q1, q3 @ add q1, q1, " - "q2\n" - "vpadd.f32 d4, d0, d1 @ add d4, d0, " - "d1\n" - "vpadd.f32 d5, d2, d3 @ add d5, d2, " - "d3\n" - "vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, " - "vcoef\n" - "vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, " - "dr_out\n" - "subs %[num], #1 @ subs num, 1\n" - "bne 1b @ bne num\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n" + "vadd.f32 q0, q0, q2 @add q0,q0,q2\n" + "vadd.f32 q1, q1, q3 @add q1,q1,q2\n" + "vpadd.f32 d4, d0, d1 @add d4,d0,d1\n" + "vpadd.f32 d5, d2, d3 @add d5,d2,d3\n" + "vmul.f32 q2, q2, %q[vcoef] @mul q2,q2,vcoef\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne 1b @bne cnt_num\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [vcoef] "+w"(vcoef), [num] "+r"(num) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(num), "w"(vcoef) + [vcoef] "+w"(vcoef), [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "w"(vcoef) : "cc", "memory", "q0", "q1", "q2", "q3"); } + w = w_unroll_size; #endif // __aarch64__ for (; w < w_even; w += 2) { - data_out_channel[w >> 1] = - (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) / 4.f; + dout_ch[w >> 1] = (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) * coef; } - for (; w < win; ++w) { // run 0 or 1 time - data_out_channel[w >> 1] = (r0[w] + r1[w]) / 4.f; + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = (r0[w] + r1[w]) * coef_2; } r0 += w_in_2; // << 1; r1 += w_in_2; // << 1; - data_out_channel += wout; + dout_ch += wout; } // process remain row (odd, last row) - for (; h < hin; h++) { // run 0 or 1 time + for (; h < h_limit; h++) { // run 0 or 1 time int w = 0; #ifdef __aarch64__ for (; w < w_unroll_size; w += 8) { @@ -560,104 +512,74 @@ void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01)); float32x4_t dsum = vcombine_f32(dsuml, dsumh); #endif - float32x4_t res = vmulq_f32(dsum, vcoef); - vst1q_f32(&data_out_channel[w >> 1], res); + float32x4_t res = vmulq_f32(dsum, vcoef_2); + vst1q_f32(&dout_ch[w >> 1], res); } #else - w = w_unroll_size; - int num = w_unroll_size >> 3; + float* dr_out = dout_ch; const float* dr0 = r0; - float* dr_out = data_out_channel; - - if (num > 0) { + int cnt_num = w_unroll_size >> 3; + if (cnt_num > 0) { asm volatile( - "1: @ main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, " - "dr0\n" - "vpadd.f32 d4, d0, d1 @ add d4, d0, " - "d1\n" - "vpadd.f32 d5, d2, d3 @ add d5, d2, " - "d3\n" - "vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, " - "vcoef\n" - "vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, " - "dr_out\n" - "subs %[num], #1 @ subs num, 1\n" - "bne 1b @ bne num\n" - : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [vcoef] "+w"(vcoef), - [num] "+r"(num) - : "r"(dr0), "r"(dr_out), "r"(num), "w"(vcoef) + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n" + "vpadd.f32 d4, d0, d1 @add d4,d0,d1\n" + "vpadd.f32 d5, d2, d3 @add d5,d2,d3\n" + "vmul.f32 q2, q2, %q[vcoef_2] @mul q2,q2,vcoef_2\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "bne 1b @bne cnt_num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [vcoef_2] "+w"(vcoef_2), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr_out), "r"(cnt_num), "w"(vcoef_2) : "cc", "memory", "q0", "q1", "q2"); } + w = w_unroll_size; #endif // __aarch64__ for (; w < w_even; w += 2) { - data_out_channel[w >> 1] = (r0[w] + r0[w + 1]) / 4.f; + dout_ch[w >> 1] = (r0[w] + r0[w + 1]) * coef_2; } - for (; w < win; ++w) { // run 0 or 1 time - data_out_channel[w >> 1] = r0[w] / 4.f; + for (; w < w_limit; ++w) { // run 0 or 1 time + dout_ch[w >> 1] = r0[w] * coef_1; } } } } } -void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, - int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { - // no need to pad input tensor, pad_size is not used, default border is zero - // padded - int ch_in = chin; - int h_in = hin; - int w_in = win; - - int ch_out = chout; - int h_out = hout; - int w_out = wout; - - int size_channel_out = w_out * h_out; +void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win) { + int kernel = 3; + int stride = 1; + int padding = 1; + int size_channel_out = wout * hout; int size_channel_in = win * hin; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - - int w_even = (w_in >> 1) << 1; - // int w_remains = w_in - w_even; // should be 0 or 1 - int h_even = (h_in >> 1) << 1; - // int h_remains = h_in - h_even; // should be 0 or 1 - // int w_unroll_size = (w_even >> 3) << 3; - // int w_unroll_remian = w_even - w_unroll_size; - int w_in_2 = w_in << 1; - int w_unroll_size = (w_in - 2) >> 2; - int w_unroll_remian = w_in - 2 - w_unroll_size * 4; - float minval = std::numeric_limits::lowest(); - float32x4_t vzero = vdupq_n_f32(minval); // zero pad + int w_unroll_size = ((win - 2) >> 2) << 2; + int w_unroll_remain = win - 2 - w_unroll_size; + const float minval = std::numeric_limits::lowest(); for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * ch_out * size_channel_out; - const float* data_in_batch = data_in + n * ch_in * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int c = 0; c < ch_out; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; - const float* r1 = r0 + w_in; - const float* r2 = r1 + w_in; - int cnt_num = w_unroll_size; // w_in / 4 - float* dr_out = data_out_channel; + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4 + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; const float* dr2 = r2; int w = 0; int cnt = 1; // left - data_out_channel[0] = - std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); + dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); // first row with zero pad #ifdef __aarch64__ - for (; w <= w_in - 6; w += 4) { + for (; w < w_unroll_size; w += 4) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr1_1234 = vld1q_f32(&r1[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); @@ -679,49 +601,39 @@ void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); - vst1q_f32(&data_out_channel[cnt], vmax); + vst1q_f32(&dout_ch[cnt], vmax); cnt += 4; } #else dr_out = dr_out + 1; - if (cnt_num > 0) { asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" - "vmax.f32 q5, q0, q2 @max " - "r0_1234,r1_1234\n" - "vmax.f32 d12, d2, d6 @max " - "r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q5, q6, #1 @vext max_2345\n" - "vext.f32 q2, q5, q6, #2 @vext max_3456\n" - "vpmax.f32 d2, d10, d11 @pmax d4, " - "max_1234, max_1234\n" - "vpmax.f32 d3, d0, d1 @pmax d4, " - "max_2345, max_2345\n" - "vpmax.f32 d6, d4, d5 @pmax d6, " - "max_3456, max_3456\n" - "vmax.f32 d8, d2, d3 @max d2, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d9, d3, d6 @max d2, " - "vmax_23_45, vmax_34_56\n" - "sub %[dr0], #8 @sub w, 8\n" - "sub %[dr1], #8 @sub w, 8\n" + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" + "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" // swap - "vmov.f32 s0, s17 @mov \n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s0 @mov \n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "bne 1b @bne s1_max_loop\n" + "vmov.f32 s0, s17 @mov\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s0 @mov\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) @@ -729,34 +641,34 @@ void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, } #endif - // remian - w = w_unroll_size * 4; - for (int j = 0; j < w_unroll_remian; j++) { + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { float tmp_max = std::max(r0[j + w], r1[j + w]); tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); - data_out_channel[j + w + 1] = tmp_max; + dout_ch[j + w + 1] = tmp_max; } // right - float tmp = std::max(r0[w_in - 2], r1[w_in - 2]); - tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); - data_out_channel[w_out - 1] = tmp; + float tmp = std::max(r0[win - 2], r1[win - 2]); + tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); + dout_ch[wout - 1] = tmp; // r0 = r1; // r1 = r0 + w_in; // r2 = r1 + w_in; - data_out_channel += w_out; + dout_ch += wout; int h = 0; - for (; h < h_in - 2; h += 1) { + for (; h < hin - 2; h += 1) { // deal with left pad float maxr0 = std::max(r0[0], r0[1]); float maxr1 = std::max(r1[0], r1[1]); float maxr2 = std::max(r2[0], r2[1]); - data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); + dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); #ifdef __aarch64__ w = 0; cnt = 1; - for (; w <= w_in - 6; w += 4) { + for (; w < w_unroll_size; w += 4) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr1_1234 = vld1q_f32(&r1[w]); float32x4_t vr2_1234 = vld1q_f32(&r2[w]); @@ -782,64 +694,47 @@ void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); - vst1q_f32(&data_out_channel[cnt], vmax); + vst1q_f32(&dout_ch[cnt], vmax); cnt += 4; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = r0; dr1 = r1; dr2 = r2; - cnt_num = w_unroll_size; + cnt_num = w_unroll_size >> 2; if (cnt_num > 0) { asm volatile( - "1: @main " - "loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" - "vmax.f32 q7, q0, q2 @max " - "r0_1234,r1_1234\n" - "vmax.f32 d16, d2, d6 @max " - "r0_5678,r1_5678\n" - "vmax.f32 q3, q7, q4 @max " - "r0_1234,r1_1234\n" - "vmax.f32 d12, d16, d10 @max " - "r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q3, q6, #1 @vext max_2345\n" - "vext.f32 q2, q3, q6, #2 @vext max_3456\n" - "vpmax.f32 d2, d6, d7 @pmax d4, " - "max_1234, max_1234\n" - "vpmax.f32 d3, d0, d1 @pmax d4, " - "max_2345, max_2345\n" - "vpmax.f32 d6, d4, d5 @pmax d6, " - "max_3456, max_3456\n" - "vmax.f32 d8, d2, d3 @max d2, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d9, d3, d6 @max d2, " - "vmax_23_45, vmax_34_56\n" - "sub %[dr0], #8 @sub w, 8\n" - "sub %[dr1], #8 @sub w, 8\n" - "sub %[dr2], #8 @sub w, 8\n" + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" + "vmax.f32 q7, q0, q2 @max r0_1234,r1_1234\n" + "vmax.f32 d16, d2, d6 @max r0_5678,r1_5678\n" + "vmax.f32 q3, q7, q4 @max r0_1234,r1_1234\n" + "vmax.f32 d12, d16, d10 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d6, d7 @pmax d4,max_1234,max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" + "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "sub %[dr2], #8 @sub w,8\n" // swap - "vmov.f32 s0, s17 @mov \n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s0 @mov \n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne 1b @ bne " - "s1_max_loop\n" + "vmov.f32 s0, s17 @mov\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s0 @mov\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) @@ -847,36 +742,36 @@ void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, "q8"); } #endif - // remian - w = w_unroll_size * 4; - for (int j = 0; j < w_unroll_remian; j++) { + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { float tmp_max = std::max(r0[j + w], r1[j + w]); tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1])); tmp_max = std::max(tmp_max, r2[j + w + 2]); - data_out_channel[j + w + 1] = tmp_max; + dout_ch[j + w + 1] = tmp_max; } // right - tmp = std::max(r0[w_in - 2], r1[w_in - 2]); - tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); - tmp = std::max(tmp, std::max(r2[w_in - 2], r2[w_in - 1])); - data_out_channel[w_out - 1] = tmp; + tmp = std::max(r0[win - 2], r1[win - 2]); + tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); + tmp = std::max(tmp, std::max(r2[win - 2], r2[win - 1])); + dout_ch[wout - 1] = tmp; r0 = r1; r1 = r2; - r2 = r1 + w_in; - data_out_channel += w_out; + r2 = r1 + win; + dout_ch += wout; } // the last two line float maxr0 = std::max(r0[0], r0[1]); float maxr1 = std::max(r1[0], r1[1]); - data_out_channel[0] = std::max(maxr0, maxr1); + dout_ch[0] = std::max(maxr0, maxr1); #ifdef __aarch64__ w = 0; cnt = 1; - for (; w <= w_in - 6; w += 4) { + for (; w < w_unroll_size; w += 4) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr1_1234 = vld1q_f32(&r1[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); @@ -898,50 +793,41 @@ void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); - vst1q_f32(&data_out_channel[cnt], vmax); + vst1q_f32(&dout_ch[cnt], vmax); cnt += 4; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = r0; dr1 = r1; - cnt_num = w_unroll_size; + cnt_num = w_unroll_size >> 2; if (cnt_num > 0) { asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" - "vmax.f32 q5, q0, q2 @max " - "r0_1234,r1_1234\n" - "vmax.f32 d12, d2, d6 @max " - "r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q5, q6, #1 @vext max_2345\n" - "vext.f32 q2, q5, q6, #2 @vext max_3456\n" - "vpmax.f32 d2, d10, d11 @pmax d4, " - "max_1234, max_1234\n" - "vpmax.f32 d3, d0, d1 @pmax d4, " - "max_2345, max_2345\n" - "vpmax.f32 d6, d4, d5 @pmax d6, " - "max_3456, max_3456\n" - "vmax.f32 d8, d2, d3 @max d2, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d9, d3, d6 @max d2, " - "vmax_23_45, vmax_34_56\n" - "sub %[dr0], #8 @sub w, 8\n" - "sub %[dr1], #8 @sub w, 8\n" + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n" + "vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" // swap - "vmov.f32 s0, s17 @mov \n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s0 @mov \n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "bne 1b @bne s1_max_loop\n" + "vmov.f32 s0, s17 @mov\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s0 @mov\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) @@ -949,70 +835,61 @@ void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, } #endif // remian - w = w_unroll_size * 4; - for (int j = 0; j < w_unroll_remian; j++) { + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { float tmp_max = std::max(r0[j + w], r1[j + w]); tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); - data_out_channel[j + w + 1] = tmp_max; + dout_ch[j + w + 1] = tmp_max; } - tmp = std::max(r0[w_in - 2], r1[w_in - 2]); - tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); - data_out_channel[w_out - 1] = tmp; + tmp = std::max(r0[win - 2], r1[win - 2]); + tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1])); + dout_ch[wout - 1] = tmp; } } } -void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, +void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { - int w_in = win; - int h_in = hin; - int ch_in = chin; - - int w_out = wout; - int h_out = hout; - int ch_out = chout; - - int size_channel_out = w_out * h_out; - int size_channel_in = w_in * h_in; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - - int w_even = (w_in >> 1) << 1; - int h_even = (h_in >> 1) << 1; - int w_in_2 = w_in << 1; - int w_unroll_size = (w_in - 2) >> 2; - int w_unroll_remian = w_in - 2 - w_unroll_size * 4; - float32x4_t vzero = vdupq_n_f32(0.f); // zero pad - float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); // zero pad + bool exclusive) { + int kernel = 3; + int stride = 1; + int padding = 1; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + int w_unroll_size = ((win - 2) >> 2) << 2; + int w_unroll_remain = win - 2 - w_unroll_size; + const float coef = 1.f / 9.f; + const float coef_2 = exclusive ? 1.f / 2.f : coef; + const float coef_4 = exclusive ? 1.f / 4.f : coef; + const float coef_6 = exclusive ? 1.f / 6.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_2 = vdupq_n_f32(coef_2); + float32x4_t vcoef_4 = vdupq_n_f32(coef_4); + float32x4_t vcoef_6 = vdupq_n_f32(coef_6); for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * ch_out * size_channel_out; - const float* data_in_batch = data_in + n * ch_in * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int c = 0; c < ch_out; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; - const float* r1 = r0 + w_in; - const float* r2 = r1 + w_in; - int cnt_num = w_unroll_size; // w_in / 4 - float* dr_out = data_out_channel; + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4 + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; const float* dr2 = r2; int w = 0; int cnt = 1; // left - data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; + dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; // first row with zero pad #ifdef __aarch64__ - for (; w <= w_in - 6; w += 4) { + for (; w < w_unroll_size; w += 4) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr1_1234 = vld1q_f32(&r1[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); @@ -1024,74 +901,68 @@ void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); vsum = vaddq_f32(vsum, vsum_3456); - vsum = vmulq_f32(vsum, vcoef); - vst1q_f32(&data_out_channel[cnt], vsum); + vsum = vmulq_f32(vsum, vcoef_6); + vst1q_f32(&dout_ch[cnt], vsum); cnt += 4; } - #else dr_out = dr_out + 1; - if (cnt_num > 0) { asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" - "vadd.f32 q5, q0, q2 @max " - "r0_1234,r1_1234\n" - "vadd.f32 d12, d2, d6 @max " - "r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q5, q6, #1 @vext max_2345\n" - "vext.f32 q2, q5, q6, #2 @vext max_3456\n" - "vadd.f32 q1, q5, q0 @add 1234 + 2345\n" - "vadd.f32 q1, q1, q2 @add + 3456\n" - "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" - "sub %[dr0], #8 @sub w, 8\n" - "sub %[dr1], #8 @sub w, 8\n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "bne 1b @bne s1_max_loop\n" + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234+2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) + [cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6) : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); } #endif - // remian - w = w_unroll_size * 4; - for (int j = 0; j < w_unroll_remian; j++) { + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { float tmp_sum = r0[j + w] + r1[j + w]; tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); - data_out_channel[j + w + 1] = tmp_sum / 9.f; + dout_ch[j + w + 1] = tmp_sum * coef_6; } // right - float tmp = r0[w_in - 2] + r1[w_in - 2]; - tmp += (r0[w_in - 1] + r1[w_in - 1]); - data_out_channel[w_out - 1] = tmp / 9.f; + float tmp = r0[win - 2] + r1[win - 2]; + tmp += (r0[win - 1] + r1[win - 1]); + dout_ch[wout - 1] = tmp * coef_4; // r0 = r1; // r1 = r0 + w_in; // r2 = r1 + w_in; - data_out_channel += w_out; + dout_ch += wout; int h = 0; - for (; h < h_in - 2; h += 1) { + for (; h < hin - 2; h += 1) { // deal with left pad float maxr0 = r0[0] + r0[1]; float maxr1 = r1[0] + r1[1]; float maxr2 = r2[0] + r2[1]; - data_out_channel[0] = (maxr0 + maxr1 + maxr2) / 9.f; + dout_ch[0] = (maxr0 + maxr1 + maxr2) * coef_6; #ifdef __aarch64__ w = 0; cnt = 1; - for (; w <= w_in - 6; w += 4) { + for (; w < w_unroll_size; w += 4) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr1_1234 = vld1q_f32(&r1[w]); float32x4_t vr2_1234 = vld1q_f32(&r2[w]); @@ -1108,53 +979,41 @@ void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); vsum = vaddq_f32(vsum, vsum_3456); vsum = vmulq_f32(vsum, vcoef); - vst1q_f32(&data_out_channel[cnt], vsum); + vst1q_f32(&dout_ch[cnt], vsum); cnt += 4; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = r0; dr1 = r1; dr2 = r2; - cnt_num = w_unroll_size; + cnt_num = w_unroll_size >> 2; if (cnt_num > 0) { asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" - "vadd.f32 q7, q0, q2 @max " - "r0_1234,r1_1234\n" - "vadd.f32 d16, d2, d6 @max " - "r0_5678,r1_5678\n" - "vadd.f32 q3, q7, q4 @max " - "r0_1234,r1_1234\n" - "vadd.f32 d12, d16, d10 @max " - "r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q3, q6, #1 @vext max_2345\n" - "vext.f32 q2, q3, q6, #2 @vext max_3456\n" - "vadd.f32 q1, q3, q0 @add 1234 + " - "2345\n" - "vadd.f32 q1, q1, q2 @add + 3456\n" - "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" - "sub %[dr0], #8 @sub w, 8\n" - "sub %[dr1], #8 @sub w, 8\n" - "sub %[dr2], #8 @sub w, 8\n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne 1b @bne " - "s1_max_loop\n" + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7,dr1\n" + "vadd.f32 q7, q0, q2 @max r0_1234,r1_1234\n" + "vadd.f32 d16, d2, d6 @max r0_5678,r1_5678\n" + "vadd.f32 q3, q7, q4 @max r0_1234,r1_1234\n" + "vadd.f32 d12, d16, d10 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q3, q0 @add 1234+2345\n" + "vadd.f32 q1, q1, q2 @add+3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul*1/9.f\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "sub %[dr2], #8 @sub w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) @@ -1163,36 +1022,36 @@ void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, "q8"); } #endif - // remian - w = w_unroll_size * 4; - for (int j = 0; j < w_unroll_remian; j++) { + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { float tmp_sum = r0[j + w] + r1[j + w]; tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); tmp_sum += (r2[j + w + 1] + r2[j + w + 2]); tmp_sum += r2[j + w]; - data_out_channel[j + w + 1] = tmp_sum / 9.f; + dout_ch[j + w + 1] = tmp_sum * coef; } // right - tmp = r0[w_in - 2] + r1[w_in - 2]; - tmp += (r0[w_in - 1] + r1[w_in - 1]); - tmp += (r2[w_in - 2] + r2[w_in - 1]); - data_out_channel[w_out - 1] = tmp / 9.f; + tmp = r0[win - 2] + r1[win - 2]; + tmp += (r0[win - 1] + r1[win - 1]); + tmp += (r2[win - 2] + r2[win - 1]); + dout_ch[wout - 1] = tmp * coef_6; r0 = r1; r1 = r2; - r2 = r1 + w_in; - data_out_channel += w_out; + r2 = r1 + win; + dout_ch += wout; } - // the last two line + // last line float maxr0 = (r0[0] + r0[1]); float maxr1 = (r1[0] + r1[1]); - data_out_channel[0] = (maxr0 + maxr1) / 9.f; + dout_ch[0] = (maxr0 + maxr1) * coef_4; #ifdef __aarch64__ w = 0; cnt = 1; - for (; w <= w_in - 6; w += 4) { + for (; w < w_unroll_size; w += 4) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr1_1234 = vld1q_f32(&r1[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); @@ -1204,119 +1063,100 @@ void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); vsum = vaddq_f32(vsum, vsum_3456); - vsum = vmulq_f32(vsum, vcoef); - vst1q_f32(&data_out_channel[cnt], vsum); + vsum = vmulq_f32(vsum, vcoef_6); + vst1q_f32(&dout_ch[cnt], vsum); cnt += 4; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = r0; dr1 = r1; - cnt_num = w_unroll_size; + cnt_num = w_unroll_size >> 2; if (cnt_num > 0) { asm volatile( - "1: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" - "vadd.f32 q5, q0, q2 @max " - "r0_1234,r1_1234\n" - "vadd.f32 d12, d2, d6 @max " - "r0_5678,r1_5678\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q5, q6, #1 @vext max_2345\n" - "vext.f32 q2, q5, q6, #2 @vext max_3456\n" - "vadd.f32 q1, q5, q0 @add 1234 + 2345\n" - "vadd.f32 q1, q1, q2 @add + 3456\n" - "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" - "sub %[dr0], #8 @sub w, 8\n" - "sub %[dr1], #8 @sub w, 8\n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "bne 1b @bne s1_max_loop\n" + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234+2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n" + "sub %[dr0], #8 @sub w,8\n" + "sub %[dr1], #8 @sub w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s1_max_loop\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) + [cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6) : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); } #endif - // remian - w = w_unroll_size * 4; - for (int j = 0; j < w_unroll_remian; j++) { + // remain + w = w_unroll_size; + for (int j = 0; j < w_unroll_remain; j++) { float tmp_sum = r0[j + w] + r1[j + w]; tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); - data_out_channel[j + w + 1] = tmp_sum / 9.f; + dout_ch[j + w + 1] = tmp_sum * coef_6; } // right - tmp = r0[w_in - 2] + r1[w_in - 2]; - tmp += (r0[w_in - 1] + r1[w_in - 1]); - data_out_channel[w_out - 1] = tmp / 9.f; + tmp = r0[win - 2] + r1[win - 2]; + tmp += (r0[win - 1] + r1[win - 1]); + dout_ch[wout - 1] = tmp * coef_4; } } } -void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, - int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { +void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win) { + int kernel = 3; + int stride = 2; + int padding = 1; int size_channel_out = wout * hout; int size_channel_in = win * hin; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - - int kernel_h = ksize[0]; - int kernel_w = ksize[1]; - int stride_h = strides[0]; - int stride_w = strides[1]; - int pad_h = paddings[0]; - int pad_w = paddings[1]; - int pad_top = pad_h; - int pad_left = pad_w; - int w_needed = wout * 2 + 1; - int h_needed = hout * 2 + 1; - int pad_right = w_needed - win - pad_left; - int pad_bottom = h_needed - hin - pad_top; - int w_even = (win >> 1) << 1; - int h_even = (hin >> 1) << 1; + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; + int w_unroll_size = ((w_even - 1) >> 3) << 3; + int w_unroll_remain = w_even - 1 - w_unroll_size; + int w_remain = w_needed - w_limit - padding; + int h_remain = h_needed - h_limit - padding; int w_in_2 = win << 1; float minval = std::numeric_limits::lowest(); - float32x4_t vzero = vdupq_n_f32(minval); // zero pad - int cnt_col = (win - 1) / 8; - // remain - int remain = ((win - 1) % 8) / 2; - for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; const float* r1 = r0 + win; const float* r2 = r1 + win; - float* dr_out = data_out_channel; + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; const float* dr2 = r2; int w = 1; int cnt = 1; - int cnt_num = cnt_col; - int cnt_num1 = remain; - data_out_channel[0] = - std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); + dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); // first row with zero pad -#ifdef __aarch64__ - for (; w < win - 8; w += 8) { +#if __aarch64__ + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -1338,8 +1178,8 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&data_out_channel[cnt], vmax_123_345); - vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -1351,112 +1191,94 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, float32x2_t vmax2 = vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); vmax2 = vpmax_f32(vmax2, vmax2); - data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + dout_ch[cnt] = vget_lane_f32(vmax2, 0); cnt++; } #else dr0 = dr0 + 1; dr1 = dr1 + 1; dr_out = dr_out + 1; - if (cnt_num > 0 || cnt_num1 > 0) { + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vmax.f32 q6, q0, q3 @max " - "r0_1234,r1_1234\n" - "vmax.f32 q7, q1, q4 @max " - "r0_5678,r1_5678\n" - "vmax.f32 q8, q2, q5 @max " - "r0_9101112,r1_9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q7, q8, #1 @vext max_6789\n" - "vpmax.f32 d4, d12, d13 @pmax d4, " - "vmax_1234, vmax_1234\n" - "vpmax.f32 d6, d14, d15 @pmax d6, " - "vmax_5678, vmax_5678\n" - "vpmax.f32 d5, d0, d1 @pmax d5, " - "vmax_2345, vmax_2345\n" - "vpmax.f32 d7, d2, d3 @pmax d7, " - "vmax_6789, vmax_6789\n" - "vmax.f32 d8, d4, d5 @max d2, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d9, d6, d7 @max d2, " - "vmax_56_78, vmax_67_89\n" - "sub %[dr0], #16 @add w, 8\n" - "sub %[dr1], #16 @add w, 8\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "bne 1b @bne s3_max_loop\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp cnt_num, " - "0\n" - "ble 4f @ble exit\n" - "2: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vmov.f32 s7,s6 @movs7, s6\n" - "vmax.f32 q0, q0, q1 @max q0, q0, q1\n" - "vpmax.f32 d0, d0, d1 @pmax d0, d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0, d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "bne 2b @bne " - "s3_max_loop_1\n" - "4: @exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q6, q0, q3 @max r0_1234,r1_1234\n" + "vmax.f32 q7, q1, q4 @max r0_5678,r1_5678\n" + "vmax.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q7, q8, #1 @vext max_6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4,vmax_1234,vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6,vmax_5678,vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5,vmax_2345,vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7,vmax_6789,vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num, #1\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); } -// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); #endif // int w = w_even - 1; - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); float tmp = r0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { // only run 1 or 2 times tmp = std::max(tmp, std::max(r0[i], r1[i])); } - data_out_channel[w_even >> 1] = tmp; + dout_ch[w_even >> 1] = tmp; // cnt ++; } r0 = r1; r1 = r0 + win; r2 = r1 + win; - data_out_channel += wout; + dout_ch += wout; int h = 2; for (; h < h_even; h += 2) { // deal with left pad float maxr0 = std::max(r0[0], r0[1]); float maxr1 = std::max(r1[0], r1[1]); float maxr2 = std::max(r2[0], r2[1]); - data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); -#ifdef __aarch64__ + dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); +#if __aarch64__ w = 1; cnt = 1; - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -1484,8 +1306,8 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&data_out_channel[cnt], vmax_123_345); - vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -1500,140 +1322,108 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, float32x2_t vmax2 = vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); float32x2_t vmax = vpmax_f32(vmax2, vmax2); - data_out_channel[cnt] = vget_lane_f32(vmax, 0); + dout_ch[cnt] = vget_lane_f32(vmax, 0); cnt++; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = (r0 + 1); dr1 = (r1 + 1); dr2 = (r2 + 1); - cnt_num = cnt_col; - cnt_num1 = remain; - if (cnt_num > 0 || cnt_num1 > 0) { + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" - "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" - "vmax.f32 q11, q2, q5 @max q1,q1,q3\n" - "vmax.f32 q0, q9, q6 @max q0,q0,q2 " - "1234\n" - "vmax.f32 q3, q10, q7 @max q1,q1,q3 " - "5678\n" - "vmax.f32 q1, q11, q8 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q4, q0, q3, #1 @vext 2345\n" - "vext.f32 q2, q3, q1, #1 @vext 6789\n" - "vpmax.f32 d10, d0, d1 @pmax d10, " - "vmax_1234, vmax_1234\n" - "vpmax.f32 d12, d6, d7 @pmax d12, " - "vmax_5678, vmax_5678\n" - "vpmax.f32 d11, d8, d9 @pmax d11, " - "vmax_2345, vmax_2345\n" - "vpmax.f32 d13, d4, d5 @pmax d13, " - "vmax_6789, vmax_6789\n" - "vmax.f32 d0, d10, d11 @pmax d0, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d1, d12, d13 @pmax d1, " - "vmax_56_78, vmax_67_89\n" - "sub %[dr0], #16 @add w, 8\n" - "sub %[dr1], #16 @add w, 8\n" - "sub %[dr2], #16 @add w, 8\n" - "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "bne 1b @bne " - "s3_max_loop_mid\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit1\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 q11, q2, q5 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n" + "vmax.f32 q1, q11, q8 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax d10,vmax_1234,vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax d12,vmax_5678,vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax d11,vmax_2345,vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax d13,vmax_6789,vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax d0,vmax_12_34,vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax d1,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "sub %[dr2], #16 @add w,8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit1\n" "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " - "dr1\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vmov.f32 s7,s6 @movs7, s6\n" - "vmov.f32 s11,s10 @movs11, s10\n" - "vmax.f32 q0, q0, q1 @max q0, q0, " - "q1\n" - "vmax.f32 q0, q0, q2 @max q0, q0, " - "q2\n" - "vpmax.f32 d0, d0, d1 @pmax d0, " - "d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0, " - "d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "sub %[dr2], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs cnt_num, " - "#1\n" - "bne 2b @bne " - "s3_max_loop_mid_1\n" - "4: @exit\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmov.f32 s11,s10 @movs11,s10\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), - [cnt_num1] "+r"(cnt_num1) + [cnt_num_remain] "+r"(cnt_num_remain) : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), - "r"(cnt_num1) + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); float tmp = r0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { tmp = std::max(tmp, std::max(r0[i], r1[i])); tmp = std::max(tmp, r2[i]); } - data_out_channel[w_even >> 1] = tmp; + dout_ch[w_even >> 1] = tmp; // cnt ++; } r0 = r2; r1 = r0 + win; r2 = r1 + win; - data_out_channel += wout; + dout_ch += wout; } - if (pad_bottom) { + if (h_remain > 0) { // deal with bottom pad // first row with zero pad - int hstart = (h >> 1) * stride_h - pad_h; - int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); - + int hstart = (h >> 1) * stride - padding; + int hend = std::min(std::min(hstart + kernel, hin + padding), hin); if (hstart == hend - 1) { // only one lline - data_out_channel[0] = std::max(r0[0], r0[1]); -#ifdef __aarch64__ + dout_ch[0] = std::max(r0[0], r0[1]); +#if __aarch64__ w = 1; cnt = 1; - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vmax_1234 = vld1q_f32(&r0[w]); float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]); @@ -1649,8 +1439,8 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&data_out_channel[cnt], vmax_123_345); - vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -1658,93 +1448,73 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, vr0 = vsetq_lane_f32(minval, vr0, 3); float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0)); vmax = vpmax_f32(vmax, vmax); - data_out_channel[cnt] = vget_lane_f32(vmax, 0); + dout_ch[cnt] = vget_lane_f32(vmax, 0); cnt++; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = (r0 + 1); - cnt_num = cnt_col; - cnt_num1 = remain; - if (cnt_num > 0 || cnt_num1 > 0) { + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3, " - "dr0\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " - "dr0\n" - "vext.f32 q4, q0, q1, #1 @vext q4, q0, " - "q1, 1 2345\n" - "vext.f32 q5, q1, q2, #1 @vext q5, q0, " - "q1, 1 6789\n" - "vpmax.f32 d12, d0, d1 @pmax d12, " - "vmax_1234, vmax_1234\n" - "vpmax.f32 d14, d2, d3 @pmax d14, " - "vmax_5678, vmax_5678\n" - "vpmax.f32 d13, d8, d9 @pmax d13, " - "vmax_2345, vmax_2345\n" - "vpmax.f32 d15, d10, d11 @pmax d15, " - "vmax_6789, vmax_6789\n" - "vmax.f32 d0, d12, d13 @max d0, " - "vmax_12_34,vmax_23_45\n" - "vmax.f32 d1, d14, d15 @pmax d2, " - "vmax_56_78, vmax_67_89\n" - "sub %[dr0], #16 @add w, 6\n" - "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "bne 1b @bne " - "s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vpmax.f32 d0, d0, d1 @pmax d0, " - "d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0, " - "d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "sub %[dr0], #8 @add w, 2\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "bne 2b @bne " - "s3_max_loop_bot_1\n" - "4: @exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" + "vext.f32 q4, q0, q1, #1 @vmax_2345\n" + "vext.f32 q5, q1, q2, #1 @vmax_6789\n" + "vpmax.f32 d12, d0, d1 @vmax_12_34\n" + "vpmax.f32 d14, d2, d3 @vmax_56_78\n" + "vpmax.f32 d13, d8, d9 @vmax_23_45\n" + "vpmax.f32 d15, d10, d11 @vmax_67_89\n" + "vmax.f32 d0, d12, d13 @12_34,23_45\n" + "vmax.f32 d1, d14, d15 @56_78,67_89\n" + "sub %[dr0], #16 @add w,6\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,2\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); float tmp = r0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { tmp = std::max(tmp, r0[i]); } - data_out_channel[w_even >> 1] = tmp; + dout_ch[w_even >> 1] = tmp; } } else { // two lines - data_out_channel[0] = - std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); + dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); #ifdef __aarch64__ w = 1; cnt = 1; - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -1766,8 +1536,8 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&data_out_channel[cnt], vmax_123_345); - vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -1779,105 +1549,84 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, float32x2_t vmax2 = vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); vmax2 = vpmax_f32(vmax2, vmax2); - data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + dout_ch[cnt] = vget_lane_f32(vmax2, 0); cnt++; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = (r0 + 1); dr1 = (r1 + 1); - cnt_num = cnt_col; - cnt_num1 = remain; - if (cnt_num > 0 || cnt_num1 > 0) { + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " - "dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vmax.f32 q6, q0, q3 @max q0,q0,q2 " - "1234\n" - "vmax.f32 q7, q1, q4 @max q1,q1,q3 " - "5678\n" - "vmax.f32 q8, q2, q5 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, - // s6\n" - "vext.f32 q0, q6, q7, #1 @vext q0, " - "2345\n" - "vext.f32 q1, q7, q8, #1 @vext q1, " - "6789\n" - "vpmax.f32 d4, d12, d13 @pmax d4, " - "vmax_1234, vmax_1234\n" - "vpmax.f32 d6, d14, d15 @pmax d6, " - "vmax_5678, vmax_5678\n" - "vpmax.f32 d5, d0, d1 @pmax d5, " - "vmax_2345, vmax_2345\n" - "vpmax.f32 d7, d2, d3 @pmax d7, " - "vmax_6789, vmax_6789\n" - "vmax.f32 d8, d4, d5 @max d2, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d9, d6, d7 @max d2, " - "vmax_56_78, vmax_67_89\n" - "sub %[dr0], #16 @add w, 8\n" - "sub %[dr1], #16 @add w, 8\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "bne 1b @bne " - "s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vmov.f32 s7,s6 @movs7, s6\n" - "vmax.f32 q0, q0, q1 @max q0, q0, " - "q1\n" - "vpmax.f32 d0, d0, d1 @pmax d0, " - "d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0, " - "d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "bne 2b @bne " - "s3_max_loop_bot_1\n" - "4: @exit\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n" + "vmax.f32 q8, q2, q5 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0,2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1,6789\n" + "vpmax.f32 d4, d12, d13 @pmax " + "d4,vmax_1234,vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax " + "d6,vmax_5678,vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax " + "d5,vmax_2345,vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax " + "d7,vmax_6789,vmax_6789\n" + "vmax.f32 d8, d4, d5 @max " + "d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max " + "d2,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); float tmp = r0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { // only run 1 or 2 times tmp = std::max(tmp, std::max(r0[i], r1[i])); } - data_out_channel[w_even >> 1] = tmp; + dout_ch[w_even >> 1] = tmp; } } } @@ -1885,62 +1634,61 @@ void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, } } -void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, +void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { + bool exclusive) { + int kernel = 3; + int stride = 2; + int padding = 1; int size_channel_out = wout * hout; int size_channel_in = win * hin; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); - - int kernel_h = ksize[0]; - int kernel_w = ksize[1]; - int stride_h = strides[0]; - int stride_w = strides[1]; - int pad_h = paddings[0]; - int pad_w = paddings[1]; - int pad_top = pad_h; - int pad_left = pad_w; - int w_needed = wout * 2 + 1; - int h_needed = hout * 2 + 1; - int pad_right = w_needed - win - pad_left; - int pad_bottom = h_needed - hin - pad_top; - int w_even = (win >> 1) << 1; - int h_even = (hin >> 1) << 1; + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = (w_limit >> 1) << 1; + int h_even = (h_limit >> 1) << 1; + int w_unroll_size = ((w_even - 1) >> 3) << 3; + int w_unroll_remain = w_even - 1 - w_unroll_size; + int w_remain = w_needed - w_limit - padding; + int h_remain = h_needed - h_limit - padding; int w_in_2 = win << 1; - int w_unroll_size = (win - 1) / 8; - // remain - int w_unroll_remian = ((win - 1) % 8) / 2; - + const float coef = 1.f / 9.f; + const float coef_1 = exclusive ? 1.f : coef; + const float coef_2 = exclusive ? 1.f / 2.f : coef; + const float coef_3 = exclusive ? 1.f / 3.f : coef; + const float coef_4 = exclusive ? 1.f / 4.f : coef; + const float coef_6 = exclusive ? 1.f / 6.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_1 = vdupq_n_f32(coef_1); + float32x4_t vcoef_2 = vdupq_n_f32(coef_2); + float32x4_t vcoef_3 = vdupq_n_f32(coef_3); + float32x4_t vcoef_4 = vdupq_n_f32(coef_4); + float32x4_t vcoef_6 = vdupq_n_f32(coef_6); for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * chout * size_channel_out; - const float* data_in_batch = data_in + n * chin * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for for (int c = 0; c < chout; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; const float* r1 = r0 + win; const float* r2 = r1 + win; - int cnt_num = w_unroll_size; - int cnt_num1 = w_unroll_remian; - float* dr_out = data_out_channel; + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; const float* dr2 = r2; int w = 1; int cnt = 1; - float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); float32x4_t vzero = vdupq_n_f32(0.f); - data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; + dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; // first row with zero pad #ifdef __aarch64__ - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -1965,8 +1713,8 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); - vst1q_f32(&data_out_channel[cnt], vrst); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); + vst1q_f32(&dout_ch[cnt], vrst); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -1978,113 +1726,102 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, float32x2_t vsum2 = vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); vsum2 = vpadd_f32(vsum2, vsum2); - float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); - data_out_channel[cnt] = vget_lane_f32(vrst, 0); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); + dout_ch[cnt] = vget_lane_f32(vrst, 0); cnt++; } #else dr0 = dr0 + 1; dr1 = dr1 + 1; dr_out = dr_out + 1; - // printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); - if (cnt_num > 0 || cnt_num1 > 0) { + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vadd.f32 q6, q0, q3 @max " - "r0_1234,r1_1234\n" - "vadd.f32 q7, q1, q4 @max " - "r0_5678,r1_5678\n" - "vadd.f32 q8, q2, q5 @max " - "r0_9101112,r1_9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, 2345 \n" - "vadd.f32 q5, q7, q1 @add 5678, 4567 \n" - "vadd.f32 q4, q4, q2 @add 3456, sum1 \n" - "vadd.f32 q5, q5, q3 @add 6789, sum2 \n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s21 @mov \n" - "vmov.f32 s19, s23 @mov \n" - "vmul.f32 q4, q4, %q[vcoef] @mul \n" - "sub %[dr0], #16 @add w, 8\n" - "sub %[dr1], #16 @add w, 8\n" - "subs %[cnt_num], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" - "bne 1b @bne s3_max_loop\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp cnt_num, " - "0\n" - "ble 4f @ble exit\n" - "2: @main loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0, q0, q1\n" - "vpadd.f32 d0, d0, d1 @padd d0, d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0, d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "bne 2b @bne s3_max_loop_1\n" - "4: @exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q6, q0, q3 @max r0_1234,r1_1234\n" + "vadd.f32 q7, q1, q4 @max r0_5678,r1_5678\n" + "vadd.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @subs cnt_num,#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 2b @bne s3_max_loop_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), - [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); } -// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); #endif // int w = w_even - 1; - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); - float tmp = 0.f; // std::numeric_limits::min(); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; // std::numeric_limits::min(); + float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef; for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp += (r0[i] + r1[i]); + tmp1 += (r0[i] + r1[i]); } - data_out_channel[w_even >> 1] = tmp / 9.f; + dout_ch[w_even >> 1] = tmp1 * tmp2; // cnt ++; } r0 = r1; r1 = r0 + win; r2 = r1 + win; - data_out_channel += wout; + dout_ch += wout; int h = 2; for (; h < h_even; h += 2) { // deal with left pad float sum0 = r0[0] + r0[1]; float sum1 = r1[0] + r1[1]; float sum2 = r2[0] + r2[1]; - data_out_channel[0] = (sum0 + sum1 + sum2) / 9.f; + dout_ch[0] = (sum0 + sum1 + sum2) * coef_6; #ifdef __aarch64__ w = 1; cnt = 1; - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -2116,7 +1853,7 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); - vst1q_f32(&data_out_channel[cnt], vrst); + vst1q_f32(&dout_ch[cnt], vrst); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -2131,141 +1868,115 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, float32x2_t vsum2 = vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); float32x2_t vsum = vpadd_f32(vsum2, vsum2); - data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f; + dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef; cnt++; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = (r0 + 1); dr1 = (r1 + 1); dr2 = (r2 + 1); - cnt_num = w_unroll_size; - cnt_num1 = w_unroll_remian; - if (cnt_num > 0 || cnt_num1 > 0) { + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " "dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" - "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" - "vadd.f32 q11, q2, q5 @max q1,q1,q3\n" - "vadd.f32 q6, q9, q6 @max q0,q0,q2 " - "1234\n" - "vadd.f32 q7, q10, q7 @max q1,q1,q3 " - "5678\n" - "vadd.f32 q8, q11, q8 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, 2345 " - "\n" - "vadd.f32 q5, q7, q1 @add 5678, 4567 " - "\n" - "vadd.f32 q4, q4, q2 @add 3456, sum1 " - "\n" - "vadd.f32 q5, q5, q3 @add 6789, sum2 " - "\n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s21 @mov \n" - "vmov.f32 s19, s23 @mov \n" - "vmul.f32 q4, q4, %q[vcoef] @mul \n" - "sub %[dr0], #16 @add w, 8\n" - "sub %[dr1], #16 @add w, 8\n" - "sub %[dr2], #16 @add w, 8\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne 1b @bne s3_max_loop_mid\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit1\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 q11, q2, q5 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n" + "vadd.f32 q8, q11, q8 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef] @mul\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "sub %[dr2], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit1\n" "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " - "dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" - "vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0, q0, " - "q1\n" - "vadd.f32 q0, q0, q2 @add q0, q0, " - "q1\n" - "vpadd.f32 d0, d0, d1 @padd d0, " - "d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0, " - "d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "sub %[dr2], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "bne 2b @bne s3_max_loop_mid_1\n" - "4: @exit\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vadd.f32 q0, q0, q2 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), - [cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef), + [cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), - "r"(cnt_num1) + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); - float tmp = 0.f; + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef; for (int i = wstart; i < wend; i++) { - tmp += (r0[i] + r1[i] + r2[i]); + tmp1 += (r0[i] + r1[i] + r2[i]); } - data_out_channel[w_even >> 1] = tmp / 9.f; + dout_ch[w_even >> 1] = tmp1 * tmp2; // cnt ++; } r0 = r2; r1 = r0 + win; r2 = r1 + win; - data_out_channel += wout; + dout_ch += wout; } - if (pad_bottom) { + if (h_remain > 0) { // deal with bottom pad // first row with zero pad - int hstart = (h >> 1) * stride_h - pad_h; - int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); - - if (hstart == hend - 1) { // only one lline - data_out_channel[0] = (r0[0] + r0[1]) / 9.f; + int hstart = (h >> 1) * stride - padding; + int hend = std::min(std::min(hstart + kernel, hin + padding), hin); + if (hstart == hend - 1) { // only one line + dout_ch[0] = (r0[0] + r0[1]) * coef_2; #ifdef __aarch64__ w = 1; cnt = 1; - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vsum_1234 = vld1q_f32(&r0[w]); float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]); @@ -2284,8 +1995,8 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, vsum_123_345, 2); vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); - vst1q_f32(&data_out_channel[cnt], vrst); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_3); + vst1q_f32(&dout_ch[cnt], vrst); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -2293,97 +2004,79 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, vr0 = vsetq_lane_f32(0.f, vr0, 3); float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0)); vsum = vpadd_f32(vsum, vsum); - data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f; + dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef_3; cnt++; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = (r0 + 1); - cnt_num = w_unroll_size; - cnt_num1 = w_unroll_remian; - if (cnt_num > 0 || cnt_num1 > 0) { + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d12-d15}, [%[dr0]]! @load " - "d0-d3, dr0\n" - "vld1.f32 {d16-d17}, [%[dr0]]! @load " - "d0-d3, dr0\n" - "vext.f32 q0, q6, q7, #1 @vext " - "max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext " - "max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext " - "max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext " - "max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, " - "2345 \n" - "vadd.f32 q5, q7, q1 @add 5678, " - "4567 \n" - "vadd.f32 q4, q4, q2 @add 3456, " - "sum1 \n" - "vadd.f32 q5, q5, q3 @add 6789, " - "sum2 \n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s21 @mov \n" - "vmov.f32 s19, s23 @mov \n" - "vmul.f32 q4, q4, %q[vcoef] @mul \n" - "sub %[dr0], #16 @add w, 6\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vext.f32 q0, %q[vzero], q0, #3 @ ext " - "v0_0123\n" - "vpadd.f32 d0, d0, d1 @padd d0, " - "d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0, " - "d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 2\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d12-d15}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d16-d17}, [%[dr0]]! @load d0-d3,dr0\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_3] @mul\n" + "sub %[dr0], #16 @add w,6\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_3] @mul\n" + "sub %[dr0], #8 @add w,2\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), - [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_3] "+w"(vcoef_3), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); - float tmp = 0.f; + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (1.f * (wend - wstart)) : coef; for (int i = wstart; i < wend; i++) { - tmp += r0[i]; + tmp1 += r0[i]; } - data_out_channel[w_even >> 1] = tmp / 9.f; + dout_ch[w_even >> 1] = tmp1 * tmp2; } } else { // two lines - data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; + dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4; #ifdef __aarch64__ w = 1; cnt = 1; - for (; w < win - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -2408,8 +2101,8 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, vsum_123_345, 2); vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); - vst1q_f32(&data_out_channel[cnt], vrst); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); + vst1q_f32(&dout_ch[cnt], vrst); cnt += 4; } for (; w < w_even - 1; w += 2) { @@ -2421,112 +2114,85 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, float32x2_t vsum2 = vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); vsum2 = vpadd_f32(vsum2, vsum2); - float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); - data_out_channel[cnt] = vget_lane_f32(vrst, 0); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); + dout_ch[cnt] = vget_lane_f32(vrst, 0); cnt++; } #else - dr_out = data_out_channel + 1; + dr_out = dout_ch + 1; dr0 = (r0 + 1); dr1 = (r1 + 1); - cnt_num = w_unroll_size; - cnt_num1 = w_unroll_remian; - if (cnt_num > 0 || cnt_num1 > 0) { + cnt_num = w_unroll_size >> 3; + cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " - "dr0\n" - "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vmax.f32 q6, q0, q3 @max q0,q0,q2 " - "1234\n" - "vmax.f32 q7, q1, q4 @max q1,q1,q3 " - "5678\n" - "vmax.f32 q8, q2, q5 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, - // s6\n" - "vext.f32 q0, q6, q7, #1 @vext " - "max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext " - "max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext " - "max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext " - "max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, " - "2345 \n" - "vadd.f32 q5, q7, q1 @add 5678, " - "4567 \n" - "vadd.f32 q4, q4, q2 @add 3456, " - "sum1 \n" - "vadd.f32 q5, q5, q3 @add 6789, " - "sum2 \n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s21 @mov \n" - "vmov.f32 s19, s23 @mov \n" - "vmul.f32 q4, q4, %q[vcoef] @mul \n" - "sub %[dr0], #16 @add w, 8\n" - "sub %[dr1], #16 @add w, 8\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ ext " - "v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ ext " - "v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0, q0, " - "q1\n" - "vpadd.f32 d0, d0, d1 @padd d0, " - "d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0, " - "d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q6, q0, q3 @add q0,q0,q2 1234\n" + "vadd.f32 q7, q1, q4 @add q1,q1,q3 5678\n" + "vadd.f32 q8, q2, q5 @add q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" + "sub %[dr0], #16 @add w,8\n" + "sub %[dr1], #16 @add w,8\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), - [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), + [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); - float tmp = 0.f; + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef; for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp += (r0[i] + r1[i]); + tmp1 += (r0[i] + r1[i]); } - data_out_channel[w_even >> 1] = tmp / 9.f; + dout_ch[w_even >> 1] = tmp1 * tmp2; } } } @@ -2534,87 +2200,61 @@ void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, } } -void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, - int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { - int w_in = win; - int h_in = hin; - int ch_in = chin; - - int w_out = wout; - int h_out = hout; - int ch_out = chout; - - int kernel_h = ksize[0]; - int kernel_w = ksize[1]; - int stride_h = strides[0]; - int stride_w = strides[1]; - int pad_h = paddings[0]; - int pad_w = paddings[1]; - - int size_channel_out = w_out * h_out; - int size_channel_in = w_in * h_in; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); +void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win) { + int kernel = 3; + int stride = 2; + int padding = 0; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; - int pad_top = pad_h; - int pad_left = pad_w; - int w_needed = w_out * 2 + 1; - int h_needed = h_out * 2 + 1; - int pad_right = w_needed - w_in - pad_left; - int pad_bottom = h_needed - h_in - pad_top; - int w_even = ((w_in - 1) >> 1) << 1; - // int w_remains = w_in - w_even; // should be 0 or 1 - int h_even = ((h_in - 1) >> 1) << 1; - // int h_remains = h_in - h_even; // should be 0 or 1 - int w_unroll_size = w_in >> 3; - int w_unroll_remian = (w_in - w_unroll_size * 8 - 1) / 2; - int w_in_2 = w_in << 1; + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = ((w_limit - 1) >> 1) << 1; + int h_even = ((h_limit - 1) >> 1) << 1; + int w_unroll_size = (w_even >> 3) << 3; + int w_unroll_remain = w_even - w_unroll_size; + int w_remain = w_needed - w_limit; + int h_remain = h_needed - h_limit; + int w_in_2 = win << 1; float minval = std::numeric_limits::lowest(); - float32x4_t vzero = vdupq_n_f32(minval); // zero pad - // printf("minval: %.2f\n", minval); - for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * ch_out * size_channel_out; - const float* data_in_batch = data_in + n * ch_in * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int c = 0; c < ch_out; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; - const float* r1 = r0 + w_in; - const float* r2 = r1 + w_in; - int cnt_num = w_unroll_size; + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; // w = w_in - 8; - int cnt_num1 = w_unroll_remian; - float* dr_out = data_out_channel; + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; const float* dr2 = r2; int w = 0; int cnt = 0; - // data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], + // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], // r1[1])); // first row with zero pad // r0 = r1; // r1 = r0 + w_in; // r2 = r1 + w_in; - // data_out_channel += w_out; + // dout_channel += w_out; int h = 0; for (; h < h_even; h += 2) { // deal with left pad float maxr0 = std::max(r0[0], r0[1]); float maxr1 = std::max(r1[0], r1[1]); float maxr2 = std::max(r2[0], r2[1]); -// data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +// dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2); #ifdef __aarch64__ w = 0; cnt = 0; - for (; w < w_in - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -2642,11 +2282,11 @@ void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&data_out_channel[cnt], vmax_123_345); - vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); cnt += 4; } - for (; w < w_even - 1; w += 2) { + for (; w < w_even; w += 2) { float32x4_t vr0 = vld1q_f32(&r0[w]); float32x4_t vr1 = vld1q_f32(&r1[w]); float32x4_t vr2 = vld1q_f32(&r2[w]); @@ -2658,134 +2298,114 @@ void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, float32x2_t vmax2 = vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); float32x2_t vmax = vpmax_f32(vmax2, vmax2); - data_out_channel[cnt] = vget_lane_f32(vmax, 0); + dout_ch[cnt] = vget_lane_f32(vmax, 0); cnt++; } #else - dr_out = data_out_channel; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - dr2 = r2; // (r2 + 1); - cnt_num = w_unroll_size; - cnt_num1 = w_unroll_remian; - if (cnt_num > 0 || cnt_num1 > 0) { + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + dr2 = r2; // (r2 + 1); + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n" - "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" - "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" - "vmax.f32 d22, d4, d10 @max q1,q1,q3\n" - "vmax.f32 q0, q9, q6 @max q0,q0,q2 " - "1234\n" - "vmax.f32 q3, q10, q7 @max q1,q1,q3 " - "5678\n" - "vmax.f32 d2, d22, d16 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q4, q0, q3, #1 @vext 2345\n" - "vext.f32 q2, q3, q1, #1 @vext 6789\n" - "vpmax.f32 d10, d0, d1 @pmax d10, " - "vmax_1234, vmax_1234\n" - "vpmax.f32 d12, d6, d7 @pmax d12, " - "vmax_5678, vmax_5678\n" - "vpmax.f32 d11, d8, d9 @pmax d11, " - "vmax_2345, vmax_2345\n" - "vpmax.f32 d13, d4, d5 @pmax d13, " - "vmax_6789, vmax_6789\n" - "vmax.f32 d0, d10, d11 @pmax d0, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d1, d12, d13 @pmax d1, " - "vmax_56_78, vmax_67_89\n" - "sub %[dr0], #8 @add w, 8\n" - "sub %[dr1], #8 @add w, 8\n" - "sub %[dr2], #8 @add w, 8\n" - "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7,dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 d22, d4, d10 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n" + "vmax.f32 d2, d22, d16 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax " + "d10,vmax_1234,vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax " + "d12,vmax_5678,vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax " + "d11,vmax_2345,vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax " + "d13,vmax_6789,vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax " + "d0,vmax_12_34,vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax " + "d1,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "sub %[dr2], #8 @add w,8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @cnt_num--\n" "bne 1b @bne s3_max_loop_mid\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit1\n" - "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " - "dr1\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vmov.f32 s7,s6 @movs7, s6\n" - "vmov.f32 s11,s10 @movs11, s10\n" - "vmax.f32 q0, q0, q1 @max q0, q0, " - "q1\n" - "vmax.f32 q0, q0, q2 @max q0, q0, " - "q2\n" - "vpmax.f32 d0, d0, d1 @pmax d0, " - "d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0, " - "d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "sub %[dr2], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs cnt_num, " - "#1\n" - "bne 2b @bne s3_max_loop_mid_1\n" - "4: @exit\n" + "3: @loop\n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmov.f32 s11,s10 @movs11,s10\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vmax.f32 q0, q0, q2 @max q0,q0,q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), - [cnt_num1] "+r"(cnt_num1) + [cnt_num_remain] "+r"(cnt_num_remain) : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), - "r"(cnt_num1) + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); float tmp = r0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { tmp = std::max(tmp, std::max(r0[i], r1[i])); tmp = std::max(tmp, r2[i]); } - data_out_channel[w_even >> 1] = tmp; + dout_ch[w_even >> 1] = tmp; // cnt ++; } r0 = r2; - r1 = r0 + w_in; - r2 = r1 + w_in; - data_out_channel += w_out; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; } - if (pad_bottom) { + if (h_remain > 0) { // deal with bottom pad // first row with zero pad // int hstart = (h >> 1) * stride_h - pad_h; -// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in); -// data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], +// int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); +// dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], // r1[1])); #ifdef __aarch64__ w = 0; cnt = 0; - for (; w < w_in - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -2807,11 +2427,11 @@ void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); - vst1_f32(&data_out_channel[cnt], vmax_123_345); - vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + vst1_f32(&dout_ch[cnt], vmax_123_345); + vst1_f32(&dout_ch[cnt + 2], vmax_567_789); cnt += 4; } - for (; w < w_even - 1; w += 2) { + for (; w < w_even; w += 2) { float32x4_t vr0 = vld1q_f32(&r0[w]); float32x4_t vr1 = vld1q_f32(&r1[w]); vr0 = vsetq_lane_f32(minval, vr0, 3); @@ -2820,175 +2440,140 @@ void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, float32x2_t vmax2 = vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); vmax2 = vpmax_f32(vmax2, vmax2); - data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + dout_ch[cnt] = vget_lane_f32(vmax2, 0); cnt++; } #else - dr_out = data_out_channel; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - cnt_num = w_unroll_size; - cnt_num1 = w_unroll_remian; - if (cnt_num > 0 || cnt_num1 > 0) { + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 3f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" - "vmax.f32 q6, q0, q3 @max q0,q0,q2 " - "1234\n" - "vmax.f32 q7, q1, q4 @max q1,q1,q3 " - "5678\n" - "vmax.f32 d16, d4, d10 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext q0, 2345\n" - "vext.f32 q1, q7, q8, #1 @vext q1, 6789\n" - "vpmax.f32 d4, d12, d13 @pmax d4, " - "vmax_1234, vmax_1234\n" - "vpmax.f32 d6, d14, d15 @pmax d6, " - "vmax_5678, vmax_5678\n" - "vpmax.f32 d5, d0, d1 @pmax d5, " - "vmax_2345, vmax_2345\n" - "vpmax.f32 d7, d2, d3 @pmax d7, " - "vmax_6789, vmax_6789\n" - "vmax.f32 d8, d4, d5 @max d2, " - "vmax_12_34, vmax_23_45\n" - "vmax.f32 d9, d6, d7 @max d2, " - "vmax_56_78, vmax_67_89\n" - "sub %[dr0], #8 @add w, 8\n" - "sub %[dr1], #8 @add w, 8\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "bne 1b @bne s3_max_loop_bot\n" - "3: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 4f @ble exit\n" - "2: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vmov.f32 s3,s2 @movs3, s2\n" - "vmov.f32 s7,s6 @movs7, s6\n" - "vmax.f32 q0, q0, q1 @max q0, q0, " - "q1\n" - "vpmax.f32 d0, d0, d1 @pmax d0, " - "d0,d1\n" - "vpmax.f32 d0, d0, d0 @pmax d0, d0, " - "d0\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "bne 2b @bne s3_max_loop_bot_1\n" - "4: @exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n" + "vmax.f32 d16, d4, d10 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7,s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0,2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1,6789\n" + "vpmax.f32 d4, d12, d13 @pmax " + "d4,vmax_1234,vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax " + "d6,vmax_5678,vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax " + "d5,vmax_2345,vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax " + "d7,vmax_6789,vmax_6789\n" + "vmax.f32 d8, d4, d5 @max " + "d2,vmax_12_34,vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max " + "d2,vmax_56_78,vmax_67_89\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "subs %[cnt_num], #1 @subs cnt_num,#1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vmov.f32 s3,s2 @movs3,s2\n" + "vmov.f32 s7,s6 @movs7,s6\n" + "vmax.f32 q0, q0, q1 @max q0,q0,q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num_remain) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); float tmp = r0[wstart]; // std::numeric_limits::min(); for (int i = wstart; i < wend; i++) { // only run 1 or 2 times tmp = std::max(tmp, std::max(r0[i], r1[i])); } - data_out_channel[w_even >> 1] = tmp; + dout_ch[w_even >> 1] = tmp; } } } } } -void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, +void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type) { - int w_in = win; - int h_in = hin; - int ch_in = chin; - - int w_out = wout; - int h_out = hout; - int ch_out = chout; - - int kernel_h = ksize[0]; - int kernel_w = ksize[1]; - int stride_h = strides[0]; - int stride_w = strides[1]; - int pad_h = paddings[0]; - int pad_w = paddings[1]; - - int size_channel_out = w_out * h_out; - int size_channel_in = w_in * h_in; - float* data_out = static_cast(dout); - const float* data_in = static_cast(din); + bool exclusive) { + int kernel = 3; + int stride = 2; + int padding = 0; + int size_channel_out = wout * hout; + int size_channel_in = win * hin; - int pad_top = pad_h; - int pad_left = pad_w; - int w_needed = w_out * 2 + 1; - int h_needed = h_out * 2 + 1; - int pad_right = w_needed - w_in - pad_left; - int pad_bottom = h_needed - h_in - pad_top; - int w_even = ((w_in - 1) >> 1) << 1; - int h_even = ((h_in - 1) >> 1) << 1; - int w_in_2 = w_in << 1; - int w_unroll_size = w_in >> 3; - int w_unroll_remian = (w_even - w_unroll_size * 8 - 1) / 2; + int w_needed = (wout << 1) + 1; + int h_needed = (hout << 1) + 1; + int w_limit = w_needed > win ? win : w_needed; + int h_limit = h_needed > hin ? hin : h_needed; + int w_even = ((w_limit - 1) >> 1) << 1; + int h_even = ((h_limit - 1) >> 1) << 1; + int w_unroll_size = (w_even >> 3) << 3; + int w_unroll_remain = w_even - w_unroll_size; + int w_remain = w_needed - w_limit; + int h_remain = h_needed - h_limit; + int w_in_2 = win << 1; + const float coef = 1.f / 9.f; + const float coef_6 = exclusive ? 1.f / 6.f : coef; + float32x4_t vcoef = vdupq_n_f32(coef); + float32x4_t vcoef_6 = vdupq_n_f32(coef_6); for (int n = 0; n < num; ++n) { - float* data_out_batch = data_out + n * ch_out * size_channel_out; - const float* data_in_batch = data_in + n * ch_in * size_channel_in; + float* dout_batch = dout + n * chout * size_channel_out; + const float* din_batch = din + n * chin * size_channel_in; #pragma omp parallel for - for (int c = 0; c < ch_out; c++) { - float* data_out_channel = data_out_batch + c * size_channel_out; - const float* data_in_channel = data_in_batch + c * size_channel_in; - const float* r0 = data_in_channel; - const float* r1 = r0 + w_in; - const float* r2 = r1 + w_in; - int cnt_num = w_unroll_size; + for (int c = 0; c < chout; c++) { + float* dout_ch = dout_batch + c * size_channel_out; + const float* din_ch = din_batch + c * size_channel_in; + const float* r0 = din_ch; + const float* r1 = r0 + win; + const float* r2 = r1 + win; // w = w_in - 8; - int cnt_num1 = w_unroll_remian; - float* dr_out = data_out_channel; + float* dr_out = dout_ch; const float* dr0 = r0; const float* dr1 = r1; const float* dr2 = r2; - float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); float32x4_t vzero = vdupq_n_f32(0.f); int h = 0; for (; h < h_even; h += 2) { -// LOG(INFO) << "h: " << h<<", dr0:" << r0 <<", dr1: "< 0 || cnt_num1 > 0) { + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + dr2 = r2; // (r2 + 1); + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble loop3_ave_p0 @ble " - "exit\n" - "s3_ave_loop_mid_p0: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" - "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n" - "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" - "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" - "vadd.f32 d22, d4, d10 @max q1,q1,q3\n" - "vadd.f32 q6, q9, q6 @max q0,q0,q2 " - "1234\n" - "vadd.f32 q7, q10, q7 @max q1,q1,q3 " - "5678\n" - "vadd.f32 d16, d22, d16 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, 2345 " - "\n" - "vadd.f32 q5, q7, q1 @add 5678, 4567 " - "\n" - "vadd.f32 q4, q4, q2 @add 3456, sum1 " - "\n" - "vadd.f32 q5, q5, q3 @add 6789, sum2 " - "\n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s21 @mov \n" - "vmov.f32 s19, s23 @mov \n" - "vmul.f32 q4, q4, %q[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 8\n" - "sub %[dr1], #8 @add w, 8\n" - "sub %[dr2], #8 @add w, 8\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne s3_ave_loop_mid_p0 @bne " - "s3_max_loop_mid\n" - "loop3_ave_p0: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble exit1_ave_p0 @ble " - "exit1\n" - "s3_ave_loop_mid_1_p0: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " - "dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" - "vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0, q0, " - "q1\n" - "vadd.f32 q0, q0, q2 @add q0, q0, " - "q1\n" - "vpadd.f32 d0, d0, d1 @padd d0, " - "d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0, " - "d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "sub %[dr2], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs cnt_num, " - "#1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "bne s3_ave_loop_mid_1_p0 @bne " - "s3_max_loop_mid_1\n" - "exit1_ave_p0: @exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" + "ble loop3_ave_p0 @ble exit\n" + "s3_ave_loop_mid_p0: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, dr2\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr2\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 d22, d4, d10 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n" + "vadd.f32 d16, d22, d16 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef] @mul\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "sub %[dr2], #8 @add w,8\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne s3_ave_loop_mid_p0 @bne s3_max_loop_mid\n" + "loop3_ave_p0: @loop\n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n" + "ble exit1_ave_p0 @ble exit1\n" + "s3_ave_loop_mid_1_p0: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vadd.f32 q0, q0, q2 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "sub %[dr2], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne s3_ave_loop_mid_1_p0 @bne s3_max_loop_mid_1\n" + "exit1_ave_p0: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), - [cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef), + [cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), - "r"(cnt_num1) + "r"(cnt_num_remain) : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); - float tmp = 0.f; - int pool_size = 3 * (wend - wstart); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef; for (int i = wstart; i < wend; i++) { - tmp += (r0[i] + r1[i] + r2[i]); + tmp1 += (r0[i] + r1[i] + r2[i]); } - data_out_channel[w_even >> 1] = tmp / pool_size; + dout_ch[w_even >> 1] = tmp1 * tmp2; // cnt ++; } r0 = r2; - r1 = r0 + w_in; - r2 = r1 + w_in; - data_out_channel += w_out; + r1 = r0 + win; + r2 = r1 + win; + dout_ch += wout; } - if (pad_bottom) { + if (h_remain > 0) { // deal with bottom pad // first row with zero pad // int hstart = (h >> 1) * stride_h - pad_h; -// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in); -// data_out_channel[0] =(r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; -#if 1 // def __aarch64__ +// int hend = std::min(std::min(hstart + kernel_h, hin + padding_h), +// hin); data_out_channel[0] =(r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + +// r1[2]) / 9.f; +#ifdef __aarch64__ int w = 0; int cnt = 0; - vcoef = vdupq_n_f32(1.f / 6.f); - for (; w < w_in - 8; w += 8) { + for (; w < w_unroll_size; w += 8) { float32x4_t vr0_1234 = vld1q_f32(&r0[w]); float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); @@ -3196,11 +2754,11 @@ void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); - float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); - vst1q_f32(&data_out_channel[cnt], vrst); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6); + vst1q_f32(&dout_ch[cnt], vrst); cnt += 4; } - for (; w < w_even - 1; w += 2) { + for (; w < w_even; w += 2) { float32x4_t vr0 = vld1q_f32(&r0[w]); float32x4_t vr1 = vld1q_f32(&r1[w]); vr0 = vsetq_lane_f32(0.f, vr0, 3); @@ -3209,105 +2767,86 @@ void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, float32x2_t vsum2 = vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); vsum2 = vpadd_f32(vsum2, vsum2); - float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); - data_out_channel[cnt] = vget_lane_f32(vrst, 0); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6)); + dout_ch[cnt] = vget_lane_f32(vrst, 0); cnt++; } #else - dr_out = data_out_channel; // + 1; - dr0 = r0; // (r0 + 1); - dr1 = r1; // (r1 + 1); - cnt_num = w_unroll_size; - cnt_num1 = w_unroll_remian; - // LOG(INFO) << "dr0:" << dr0 <<", dr1: "< 0 || cnt_num1 > 0) { + dr_out = dout_ch; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + int cnt_num = w_unroll_size >> 3; + int cnt_num_remain = w_unroll_remain >> 1; + // LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " << + // cnt_num_remain; + if (cnt_num > 0 || cnt_num_remain > 0) { asm volatile( - "cmp %[cnt_num], #0 @cmp cnt_num, " - "0\n" - "ble 2f @ble exit\n" - "1: @main loop\n" - "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " - "dr0\n" - "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " - "dr1\n" - "vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n" - "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" - "vadd.f32 q6, q0, q3 @max q0,q0,q2 " - "1234\n" - "vadd.f32 q7, q1, q4 @max q1,q1,q3 " - "5678\n" - "vadd.f32 d16, d4, d10 @max q1,q1,q3 " - "9101112\n" - //"vmov.f32 s7,s6 @mov s7, s6\n" - "vext.f32 q0, q6, q7, #1 @vext max_2345\n" - "vext.f32 q1, q6, q7, #3 @vext max_4567\n" - "vext.f32 q2, q6, q7, #2 @vext max_3456\n" - "vext.f32 q3, q7, q8, #1 @vext max_6789\n" - "vadd.f32 q4, q6, q0 @add 1234, 2345 " - "\n" - "vadd.f32 q5, q7, q1 @add 5678, 4567 " - "\n" - "vadd.f32 q4, q4, q2 @add 3456, sum1 " - "\n" - "vadd.f32 q5, q5, q3 @add 6789, sum2 " - "\n" - "vmov.f32 s17, s18 @mov \n" - "vmov.f32 s18, s21 @mov \n" - "vmov.f32 s19, s23 @mov \n" - "vmul.f32 q4, q4, %q[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 8\n" - "sub %[dr1], #8 @add w, 8\n" - "subs %[cnt_num], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " - "dr_out\n" - "bne 1b @bne s3_max_loop_bot\n" - "2: @loop \n" - "cmp %[cnt_num1], #0 @cmp " - "cnt_num, 0\n" - "ble 3f @ble exit\n" - "4: @bot loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " - "dr0\n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " - "dr1\n" - "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" - "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" - "vadd.f32 q0, q0, q1 @add q0, q0, " - "q1\n" - "vpadd.f32 d0, d0, d1 @padd d0, " - "d0,d1\n" - "vpadd.f32 d0, d0, d0 @padd d0, d0, " - "d0\n" - "vmul.f32 d0, d0, %e[vcoef] @mul \n" - "sub %[dr0], #8 @add w, 6\n" - "sub %[dr1], #8 @add w, 6\n" - "subs %[cnt_num1], #1 @subs " - "cnt_num, #1\n" - "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " - "dr_out\n" - "bne 4b @bne s3_max_loop_bot_1\n" - "3: @exit\n" + "cmp %[cnt_num], #0 @cmp cnt_num,0\n" + "ble 2f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n" + "vadd.f32 q6, q0, q3 @max q0,q0,q2 1234\n" + "vadd.f32 q7, q1, q4 @max q1,q1,q3 5678\n" + "vadd.f32 d16, d4, d10 @max q1,q1,q3 9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234,2345\n" + "vadd.f32 q5, q7, q1 @add 5678,4567\n" + "vadd.f32 q4, q4, q2 @add 3456,sum1\n" + "vadd.f32 q5, q5, q3 @add 6789,sum2\n" + "vmov.f32 s17, s18 @mov\n" + "vmov.f32 s18, s21 @mov\n" + "vmov.f32 s19, s23 @mov\n" + "vmul.f32 q4, q4, %q[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,8\n" + "sub %[dr1], #8 @add w,8\n" + "subs %[cnt_num], #1 @cnt_num--\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "2: @loop\n" + "cmp %[cnt_num_remain], #0 @cmp cnt_num_remain, 0\n" + "ble 3f @ble exit\n" + "4: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0,q0,q1\n" + "vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n" + "vmul.f32 d0, d0, %e[vcoef_6] @mul\n" + "sub %[dr0], #8 @add w,6\n" + "sub %[dr1], #8 @add w,6\n" + "subs %[cnt_num_remain], #1 @cnt_num_remain--\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n" + "bne 4b @bne s3_max_loop_bot_1\n" + "3: @exit\n" : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), - [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) - : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + [cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain), + [vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num_remain) : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); } #endif - if (pad_right) { + if (w_remain > 0) { // deal with right pad - int wstart = (w_even >> 1) * stride_w - pad_w; - int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); - float tmp = 0.f; - int pool_size = 2 * (wend - wstart); + int wstart = (w_even >> 1) * stride - padding; + int wend = std::min(std::min(wstart + kernel, win + padding), win); + float tmp1 = 0.f; + float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef; for (int i = wstart; i < wend; i++) { // only run 1 or 2 times - tmp += (r0[i] + r1[i]); + tmp1 += (r0[i] + r1[i]); } - data_out_channel[w_even >> 1] = tmp / pool_size; + dout_ch[w_even >> 1] = tmp1 * tmp2; } } } diff --git a/paddle/fluid/lite/arm/math/pooling.h b/paddle/fluid/lite/arm/math/pooling.h index 36832187073c2d29a129a10fdd7984ba8d15db3d..b8ad0780dda9ae522c673ec8bd46cb5a9ed2adef 100644 --- a/paddle/fluid/lite/arm/math/pooling.h +++ b/paddle/fluid/lite/arm/math/pooling.h @@ -25,7 +25,7 @@ namespace arm { namespace math { // !pooling fp32 Op -void pooling_basic(const void* din, void* dout, int num, int chout, int hout, +void pooling_basic(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, const std::vector& ksize, const std::vector& strides, @@ -33,77 +33,39 @@ void pooling_basic(const void* din, void* dout, int num, int chout, int hout, bool exclusive, bool adaptive, bool ceil_mode, bool use_quantizer, const std::string& pooling_type); -void pooling_global(const void* din, void* dout, int num, int chout, int hout, - int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); +void pooling_global_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win); -void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, - int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); +void pooling_global_avg(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win); -void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, - int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); +void pooling2x2s2_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win); -void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, - int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); +void pooling2x2s2_avg(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + bool exclusive); -void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, - int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); +void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win); -void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, +void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); + bool exclusive); -void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, - int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); +void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win); -void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, +void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); + bool exclusive); + +void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win); -void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, +void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout, int hout, int wout, int chin, int hin, int win, - const std::vector& ksize, - const std::vector& strides, - const std::vector& paddings, bool global_pooling, - bool exclusive, bool adaptive, bool ceil_mode, - bool use_quantizer, const std::string& pooling_type); + bool exclusive); } // namespace math } // namespace arm diff --git a/paddle/fluid/lite/kernels/arm/pool_compute.cc b/paddle/fluid/lite/kernels/arm/pool_compute.cc index 3ee82ae6303f849a11d8685aae09b267bb991604..0b5eb6ac847b28be93e76de09cea8c2e31fcf9e2 100644 --- a/paddle/fluid/lite/kernels/arm/pool_compute.cc +++ b/paddle/fluid/lite/kernels/arm/pool_compute.cc @@ -48,120 +48,96 @@ void PoolCompute::Run() { bool use_quantizer = param.use_quantizer; std::string& data_format = param.data_format; - if (param.global_pooling) { + bool kps_equal = (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && + (paddings[0] == paddings[1]); + + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; ksize[i] = static_cast(in_dims[i + 2]); } - } - -#if 0 - for (int i = 0; i < in_dims.size(); ++i) { - LOG(INFO) << "in_dims[" << i << "]:" << in_dims[i]; - } - for (int i = 0; i < out_dims.size(); ++i) { - LOG(INFO) << "out_dims[" << i << "]:" << out_dims[i]; - } - for (int i = 0; i < ksize.size(); ++i) { - LOG(INFO) << "ksize[" << i << "]:" << ksize[i]; - } - for (int i = 0; i < strides.size(); ++i) { - LOG(INFO) << "strides[" << i << "]:" << strides[i]; - } - for (int i = 0; i < paddings.size(); ++i) { - LOG(INFO) << "paddings[" << i << "]:" << paddings[i]; - } - LOG(INFO) << "global_pooling:" << global_pooling; - LOG(INFO) << "exclusive:" << exclusive; - LOG(INFO) << "adaptive:" << adaptive; - LOG(INFO) << "ceil_mode:" << ceil_mode; - LOG(INFO) << "use_quantizer:" << use_quantizer; - LOG(INFO) << "data_format:" << data_format; - LOG(INFO) << "din:" << din; - LOG(INFO) << "dout:" << dout; -#endif - - // global - if (global_pooling == true) { - lite::arm::math::pooling_global( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && - strides[0] == strides[1]) { if (pooling_type == "max") { - lite::arm::math::pooling2x2s2_max( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); + lite::arm::math::pooling_global_max(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], in_dims[1], + in_dims[2], in_dims[3]); + VLOG(3) << "invoking pooling_global_max"; + return; } else if (pooling_type == "avg") { - lite::arm::math::pooling2x2s2_ave( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } - } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 1 && - strides[0] == strides[1] && paddings[0] == 1) { - if (pooling_type == "max") { - lite::arm::math::pooling3x3s1p1_max( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } else if (pooling_type == "avg") { - lite::arm::math::pooling3x3s1p1_ave( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } - } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && - strides[0] == strides[1] && paddings[0] == 0) { - if (pooling_type == "max") { - lite::arm::math::pooling3x3s2p0_max( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } else if (pooling_type == "avg") { - lite::arm::math::pooling3x3s2p0_ave( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } - } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && - strides[0] == strides[1] && paddings[0] == 1) { - if (pooling_type == "max") { - lite::arm::math::pooling3x3s2p1_max( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); - } else if (pooling_type == "avg") { - lite::arm::math::pooling3x3s2p1_ave( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); + lite::arm::math::pooling_global_avg(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], in_dims[1], + in_dims[2], in_dims[3]); + VLOG(3) << "invoking pooling_global_ave"; + return; } } else { - lite::arm::math::pooling_basic( - din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], - in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, - global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, - pooling_type); + if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling2x2s2_max(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], in_dims[1], + in_dims[2], in_dims[3]); + VLOG(3) << "invoking pooling2x2s2_max"; + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling2x2s2_avg(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], in_dims[1], + in_dims[2], in_dims[3], exclusive); + VLOG(3) << "invoking pooling2x2s2_avg"; + return; + } + } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && + kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s1p1_max(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3]); + VLOG(3) << "invokingpooling3x3s1p1_max"; + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s1p1_avg( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], exclusive); + VLOG(3) << "invoking pooling3x3s1p1_avg"; + return; + } + } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && + kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s2p0_max(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3]); + VLOG(3) << "pooling3x3s2p0_max"; + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s2p0_avg( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], exclusive); + VLOG(3) << "invoking pooling3x3s2p0_avg"; + return; + } + } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && + kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s2p1_max(din, dout, out_dims[0], out_dims[1], + out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3]); + VLOG(3) << "invoking pooling3x3s2p1_max"; + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s2p1_avg( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], exclusive); + VLOG(3) << "invoking pooling3x3s2p1_avg"; + return; + } + } } - return; + lite::arm::math::pooling_basic( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], in_dims[1], + in_dims[2], in_dims[3], ksize, strides, paddings, global_pooling, + exclusive, adaptive, ceil_mode, use_quantizer, pooling_type); + VLOG(3) << "invoking pooling_basic"; } -TargetType PoolCompute::target() const { return TARGET(kARM); } - -PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); } - } // namespace arm } // namespace kernels } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/pool_compute.h b/paddle/fluid/lite/kernels/arm/pool_compute.h index 3a8b0f99c5b8292ec845f00383c4751079db2c77..1cb4e6db1bbc366ebc38760e4c347e70d3439040 100644 --- a/paddle/fluid/lite/kernels/arm/pool_compute.h +++ b/paddle/fluid/lite/kernels/arm/pool_compute.h @@ -29,9 +29,6 @@ class PoolCompute : public KernelLite { void PrepareForRun() override; void Run() override; - TargetType target() const override; - PrecisionType precision() const override; - virtual ~PoolCompute() = default; }; diff --git a/paddle/fluid/lite/kernels/arm/pool_compute_test.cc b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc index 399a1ccac5043155de578cb41bf6d0f625772fa1..8371568d2f0e224740aa925fcfe6e2093659e044 100644 --- a/paddle/fluid/lite/kernels/arm/pool_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc @@ -101,94 +101,65 @@ void pool_compute_ref(const operators::PoolParam& param) { int pad_w = paddings[1]; if (global_pooling == true) { - ksize[0] = in_h; - ksize[1] = in_w; - for (int n = 0; n < in_n; ++n) { for (int c = 0; c < in_c; ++c) { - const float* src = src_ptr + n * in_c * in_h * in_w + c * in_h * in_w; + const float* src = src_ptr + n * size_in_n + c * size_in_c; float res = src[0]; if (pooling_type == "max") { - for (int i = 1; i < in_h * in_w; ++i) { + for (int i = 1; i < size_in_c; ++i) { float cur_val = src[i]; res = cur_val > res ? cur_val : res; } } else if (pooling_type == "avg") { - for (int i = 1; i < in_h * in_w; ++i) { + for (int i = 1; i < size_in_c; ++i) { float cur_val = src[i]; res += cur_val; } - res /= (in_h * in_w); + res /= size_in_c; } - dst_ptr[n * in_c * out_h * out_w + c] = res; + dst_ptr[n * size_out_n + c] = res; } } - return; - } - - for (int ind_n = 0; ind_n < in_n; ++ind_n) { - for (int ind_c = 0; ind_c < in_c; ++ind_c) { - for (int ind_h = 0; ind_h < out_h; ++ind_h) { - int sh = ind_h * stride_h; - int eh = sh + window_h; - sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; - eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; - - for (int ind_w = 0; ind_w < out_w; ++ind_w) { - int sw = ind_w * stride_w; - int ew = sw + window_w; - sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; - ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; - - float result = static_cast(0); - - int dst_ind = - ind_n * size_out_n + ind_c * size_out_c + ind_h * out_w + ind_w; - - for (int kh = sh; kh < eh; ++kh) { - for (int kw = sw; kw < ew; ++kw) { - int src_ind = - ind_n * size_in_n + ind_c * size_in_c + kh * in_w + kw; - - if (kh == sh && kw == sw) { - result = src_ptr[src_ind]; - } else { - if (pooling_type == "max") { - result = - result >= src_ptr[src_ind] ? result : src_ptr[src_ind]; - } - if (pooling_type == "avg" && exclusive == false) { - // Pooling_average_include_padding - result += src_ptr[src_ind]; - } - if (pooling_type == "avg" && exclusive == true) { - // Pooling_average_include_padding - result += src_ptr[src_ind]; + } else { + for (int n = 0; n < in_n; ++n) { + for (int c = 0; c < in_c; ++c) { + for (int h = 0; h < out_h; ++h) { + int sh = h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + for (int w = 0; w < out_w; ++w) { + int sw = w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + int pooling_size = (ew - sw) * (eh - sh); + if (pooling_size == 0) continue; + float res = 0.f; + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw; + if (kh == sh && kw == sw) { + res = src_ptr[src_idx]; + } else { + if (pooling_type == "max") { + res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx]; + } + if (pooling_type == "avg") { + res += src_ptr[src_idx]; + } } } } - } - if (pooling_type == "avg" && exclusive == false) { - // Pooling_average_include_padding - // result /= param.window_h * param.window_w; - // LOG(ERROR)<<"cpu"<= in_w + pad_w ? in_w + pad_w : sw + window_w; - bw -= sw; - } - if (eh == in_h) { - bh = sh + window_h >= in_h + pad_h ? in_h + pad_h : sh + window_h; - bh -= sh; + if (pooling_type == "avg") { + if (exclusive) { + res /= pooling_size; + } else { + res /= window_h * window_w; + } } - result /= bh * bw; - } - if (pooling_type == "avg" && exclusive == true) { - // Pooling_average_exclude_padding - result /= (ew - sw) * (eh - sh); + dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res; } - dst_ptr[dst_ind] = result; } } } @@ -209,92 +180,96 @@ TEST(pool_arm, compute) { lite::Tensor output; lite::Tensor output_ref; - for (auto pooling_type : {"avg", "max"}) { - for (auto global_pooling : {true}) { - // for (auto ksize: {3}) { // TODO(yuanshuai): ksize enable 2, 3 - for (auto stride : {1, 2}) { - for (auto pad : {0, 1}) { - for (auto n : {1, 3, 4, 11}) { - for (auto c : {1, 3, 11 /* ,1024 */}) { // speedup for ci - for (auto h : {2, 3, 4, 11}) { - for (auto w : {2, 3, 4, 11}) { - LOG(INFO) << "n:" << n << " c:" << c << " h:" << h - << " w:" << w // << " ksize:" << ksize - << " stride:" << stride << " pad:" << pad - << " pooling_type:" << pooling_type - << " global_pooling:" << global_pooling; - - // init x, output - x.Resize(DDim(std::vector({n, c, h, w}))); - auto* x_data = x.mutable_data(); - for (int i = 0; i < x.dims().production(); ++i) { - x_data[i] = i; - } - - // fill param - param.x = &x; - param.output = &output; - param.pooling_type = pooling_type; - // param.ksize = {ksize, ksize}; //TODO(yuanshuai): ksize - // enable - param.ksize = {h, w}; - param.global_pooling = global_pooling; - param.strides = {stride, stride}; - param.paddings = {pad, pad}; - param.exclusive = true; - param.adaptive = false; - param.ceil_mode = false; - param.use_quantizer = false; - - const std::vector& output_shape = - compute_output_shape(¶m); - output.Resize(DDim(output_shape)); - output_ref.Resize(DDim(output_shape)); - - // compute - pool.SetParam(param); - pool.Run(); - - // compute ref - param.output = &output_ref; - pool_compute_ref(param); - - // compare - auto* output_data = output.mutable_data(); - auto* output_ref_data = output_ref.mutable_data(); - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); - float tmp = output_data[i] - output_ref_data[i]; - tmp = tmp < 0 ? -tmp : tmp; - if (tmp > 1e-5) { - std::cout << "output_data[0]:" << output_data[0] - << std::endl; - std::cout << "output_ref_data[0]:" << output_ref_data[0] - << std::endl; - std::cout - << "x.dims().production():" << x.dims().production() - << std::endl; - for (int ii = 0; ii < x.dims().production(); ++ii) { - std::cout << x_data[ii] << " "; + // speedup for ci + for (auto pooling_type : {"max", "avg"}) { + for (auto ceil_mode : {true, false}) { + for (auto global_pooling : {true, false}) { + for (auto exclusive : {true, false}) { + for (auto ksize : {2, 3}) { + for (auto stride : {1, 2}) { + for (auto pad : {0, 1}) { + for (auto n : {1, 2}) { + for (auto c : {1, 3}) { +#if 1 + for (auto h : {2, 3, 4, 11}) { + for (auto w : {2, 3, 4, 11}) { +#else + for (int h = 2; h < 25; h++) { + for (int w = 2; w < 25; w++) { +#endif + VLOG(3) << "n:" << n << " c:" << c << " h:" << h + << " w:" << w << " ksize:" << ksize + << " stride:" << stride << " pad:" << pad + << " exclusive:" << exclusive + << " global_pooling:" << global_pooling + << " ceil_mode: " << ceil_mode + << " pooling_type:" << pooling_type; + + // init x, output + x.Resize(DDim(std::vector({n, c, h, w}))); + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); ++i) { + float sign = i % 3 == 0 ? -0.03 : 0.05f; + x_data[i] = sign * (i % 128); + } + + // fill param + param.x = &x; + param.output = &output; + param.pooling_type = pooling_type; + if (global_pooling) { + param.ksize = {h, w}; + } else { + param.ksize = {ksize, ksize}; + } + param.global_pooling = global_pooling; + param.strides = {stride, stride}; + param.paddings = {pad, pad}; + param.exclusive = exclusive; + param.ceil_mode = ceil_mode; + param.adaptive = false; + param.use_quantizer = false; + + const std::vector& output_shape = + compute_output_shape(¶m); + output.Resize(DDim(output_shape)); + output_ref.Resize(DDim(output_shape)); + + auto* output_data = output.mutable_data(); + auto* output_ref_data = + output_ref.mutable_data(); + for (int i = 0; i < output.dims().production(); ++i) { + output_data[i] = -2; + output_ref_data[i] = -2; + } + + // compute + pool.SetParam(param); + pool.Run(); + + // compute ref + param.output = &output_ref; + pool_compute_ref(param); + + // compare + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4); + } + VLOG(3) << "compare pass"; } - std::cout; - exit(0); } } - - VLOG(3) << "compare pass"; } } } } - } // pad - } // stride - //} // ksize TODO(yuanshuai): ksize enable - } // global_pooling - } // pooling_type + } + } + } + } } -TEST(pool, retrive_op) { +TEST(pool_arm, retrive_op) { auto pool = KernelRegistry::Global().Create( "pool2d"); ASSERT_FALSE(pool.empty());