diff --git a/.gitignore b/.gitignore
index 369fa1cb919c82caec326d1429c8a2eba3b928d6..fa01346094773845ba6f11e174774d2f08e47f77 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,10 @@ paddle/fluid/operators/distributed/send_recv.proto
 *.vs
 build/
 build_doc/
+build.*
 *.user
+*.sh
+*.bkp
 .vscode
 .idea
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 036a5faf24f24a50361e16b5810bfc7051f07118..4ef4a4c351e4b701f481b5b23076ea3535fa7231 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,7 +43,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
   if(NOT DEFINED TARGET_ARCH_ABI)
     set(ARCH_ABI "arm64-v8a" CACHE STRING "Choose android platform")
   endif()
-
+
   include(cross_compiling/host)
   include(cross_compiling/armlinux)
   include(cross_compiling/android)
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 8d3864c6b3da5500bb9017437c3cd16f06494abb..9c955103ba70fc087a267eb748c8db9a3e6e8e40 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_desc.h"
-#include
 #include
 #include
 #include  // NOLINT
 #include
 #include
 #include
+#include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc
index f53f6105d1bf8abdce928ad8fb8fc36ac79935c6..0cc786c024f6d7447ec57bb4a539ddf8bcdb1c25 100644
--- a/paddle/fluid/lite/api/cxx_api_bin.cc
+++ b/paddle/fluid/lite/api/cxx_api_bin.cc
@@ -32,9 +32,9 @@ void Run(const char* model_dir) {
                   valid_places);
 
   auto* input_tensor = predictor.GetInput(0);
-  input_tensor->Resize(DDim(std::vector<int64_t>({100, 100})));
+  input_tensor->Resize(DDim(std::vector<int64_t>({3, 224, 224})));
   auto* data = input_tensor->mutable_data<float>();
-  for (int i = 0; i < 100 * 100; i++) {
+  for (int i = 0; i < 3 * 224 * 224; i++) {
     data[i] = i;
   }
 
@@ -65,6 +65,14 @@ USE_LITE_OP(feed);
 USE_LITE_OP(fetch);
 USE_LITE_OP(io_copy);
+USE_LITE_OP(conv2d);
+// USE_LITE_OP(batch_norm);
+USE_LITE_OP(relu);
+USE_LITE_OP(depthwise_conv2d);
+USE_LITE_OP(pool2d);
+USE_LITE_OP(elementwise_add);
+USE_LITE_OP(softmax);
+
 USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
 USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
 
@@ -72,7 +80,15 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
 USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
+
+USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
 USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
+
 // USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
 // USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
 #endif // LITE_WITH_ARM
diff --git a/paddle/fluid/lite/api/light_api.h b/paddle/fluid/lite/api/light_api.h
index a43755c87387e6af4d65f541cf1ba61828f3d2a5..474e5da78bd2cd201b17f9a223bd1a177861a532 100644
--- a/paddle/fluid/lite/api/light_api.h
+++ b/paddle/fluid/lite/api/light_api.h
@@ -72,8 +72,9 @@ class LightPredictor {
     // Create the kernels of the target places, and filter out the specific
     // kernel with the target alias.
-    for (auto& op : program.ops()) {
-      auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr);
+    for (auto& op : program.ops_) {
+      lite::pb::OpDesc desc(op->op_info()->desc());
+      auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
       std::string op_type, alias;
       Place place;
       KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
@@ -88,8 +89,8 @@ class LightPredictor {
       insts.emplace_back(op, std::move(*it));
     }
     program_.reset(new RuntimeProgram(std::move(insts)));
-    CHECK(program.exec_scope());
-    program_->set_exec_scope(program.exec_scope());
+    CHECK(program.exec_scope_);
+    program_->set_exec_scope(program.exec_scope_);
   }
 
  private:
diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt
index 8af2c33943f7e2abe7e539b04e3759e8e2d4da33..2a912e434ae60ab8be587d044541c4b8b464a435 100644
--- a/paddle/fluid/lite/arm/math/CMakeLists.txt
+++ b/paddle/fluid/lite/arm/math/CMakeLists.txt
@@ -6,4 +6,31 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
   return()
 endif()
 
-cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc elementwise.cc DEPS ${lite_kernel_deps} eigen3)
+# TODO(xxx): separate them
+cc_library(math_arm SRCS
+    funcs.cc
+    packed_sgemm.cc
+    softmax.cc
+    scale.cc
+    pooling.cc
+    elementwise.cc
+    sgemv.cc
+    type_trans.cpp
+    conv_impl.cc
+    conv_direct_3x3s1.cc
+    conv_direct_3x3s2.cc
+    conv_direct.cc
+    conv_depthwise_3x3_int7.cc
+    conv_depthwise_3x3_int8.cc
+    conv_depthwise_5x5s1_int8.cc
+    conv_depthwise_3x3p0.cc
+    conv_depthwise_3x3p1.cc
+    conv_depthwise_5x5s1.cc
+    conv_depthwise_5x5s2.cc
+    conv_depthwise.cc
+    conv_gemmlike.cc
+    conv_winograd_3x3.cc
+    conv_winograd.cc
+    split.cc
+    DEPS ${lite_kernel_deps} eigen3)
+
diff --git a/paddle/fluid/lite/arm/math/pooling.cc b/paddle/fluid/lite/arm/math/pooling.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fc916d0f37c14fa0fcbed1dc74dc8a0964bac05e
--- /dev/null
+++ b/paddle/fluid/lite/arm/math/pooling.cc
@@ -0,0 +1,3347 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
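Note on the new file below: pooling.cc adds a family of hand-tuned ARM pooling routines — a generic pooling_basic fallback, a global-pooling path (pooling_global), and specialized kernels for the common 2x2/stride-2, 3x3/stride-1/pad-1 and 3x3/stride-2/pad-1 configurations. The selection among them belongs to the ARM pool2d compute kernel, which is not part of this section. The following sketch only illustrates how the variants defined below fit together; pooling_dispatch is a hypothetical helper name, and the signatures are the ones shown in this file.

#include <string>
#include <vector>
#include "paddle/fluid/lite/arm/math/pooling.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Hypothetical dispatcher (not part of this patch): pick a specialized routine
// when the pooling configuration matches one, otherwise use pooling_basic.
void pooling_dispatch(const void* din, void* dout, int num, int chout, int hout,
                      int wout, int chin, int hin, int win,
                      const std::vector<int>& ksize,
                      const std::vector<int>& strides,
                      const std::vector<int>& paddings, bool global_pooling,
                      bool exclusive, bool adaptive, bool ceil_mode,
                      bool use_quantizer, const std::string& pooling_type) {
  using PoolFunc = void (*)(const void*, void*, int, int, int, int, int, int,
                            int, const std::vector<int>&,
                            const std::vector<int>&, const std::vector<int>&,
                            bool, bool, bool, bool, bool, const std::string&);
  const bool is_max = (pooling_type == "max");
  PoolFunc f = pooling_basic;  // generic fallback for any other shape
  if (global_pooling) {
    f = pooling_global;
  } else if (ksize[0] == 2 && ksize[1] == 2 && strides[0] == 2 &&
             strides[1] == 2 && paddings[0] == 0 && paddings[1] == 0) {
    f = is_max ? pooling2x2s2_max : pooling2x2s2_ave;
  } else if (ksize[0] == 3 && ksize[1] == 3 && strides[0] == 1 &&
             strides[1] == 1 && paddings[0] == 1 && paddings[1] == 1) {
    f = is_max ? pooling3x3s1p1_max : pooling3x3s1p1_ave;
  } else if (is_max && ksize[0] == 3 && ksize[1] == 3 && strides[0] == 2 &&
             strides[1] == 2 && paddings[0] == 1 && paddings[1] == 1) {
    f = pooling3x3s2p1_max;
  }
  f(din, dout, num, chout, hout, wout, chin, hin, win, ksize, strides,
    paddings, global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
    pooling_type);
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle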
+ +#include "paddle/fluid/lite/arm/math/pooling.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +void pooling_basic(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + // no need to pad input tensor, border is zero pad inside this function + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + int size_channel_in = win * hin; + int size_channel_out = wout * hout; + + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + if (global_pooling) { + if (pooling_type == "max") { // Pooling_max + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* data_in_channel = + data_in_batch + c * size_channel_in; // in address + data_out_batch[c] = data_in_channel[0]; + for (int i = 0; i < size_channel_in; ++i) { + data_out_batch[c] = data_out_batch[c] > data_in_channel[i] + ? data_out_batch[c] + : data_in_channel[i]; + } + } + } + + } else if (pooling_type == "avg") { + // Pooling_average_include_padding + // Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* data_in_channel = + data_in_batch + c * size_channel_in; // in address + float sum = 0.f; + for (int i = 0; i < size_channel_in; ++i) { + sum += data_in_channel[i]; + } + data_out_batch[c] = sum / size_channel_in; + } + } + } else { + LOG(FATAL) << "not support"; + } + return; + } + + if (pooling_type == "max") { + // Pooling_max + for (int n = 0; n < num; ++n) { + float* data_out_channel = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int q = 0; q < chout; q++) { + float* data_out_row = data_out_channel + q * size_channel_out; + const float* data_in_channel = data_in_batch + q * size_channel_in; + + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + + data_out_row[j] = data_in_channel[hstart * win + wstart]; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + data_out_row[j] = data_out_row[j] > data_in_channel[h * win + w] + ? 
data_out_row[j] + : data_in_channel[h * win + w]; + } + } + } + data_out_row += wout; + } + } + } + } else if (pooling_type == "avg") { + if (exclusive == false) { + // Pooling_average_include_padding + for (int n = 0; n < num; ++n) { + int pool_size = + kernel_w * + kernel_h; // (hend - hstart) * (wend - wstart); // problem + float* data_out_channel = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int q = 0; q < chout; q++) { + float* data_out_row = data_out_channel + q * size_channel_out; + const float* data_in_channel = data_in_batch + q * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + + int bh = kernel_h; + int bw = kernel_w; + if (wend == win) { + bw = wstart + kernel_w >= win + pad_w ? win + pad_w + : wstart + kernel_w; + bw -= wstart; + } + if (hend == hin) { + bh = hstart + kernel_h >= hin + pad_h ? hin + pad_h + : hstart + kernel_h; + bh -= hstart; + } + pool_size = bh * bw; + + data_out_row[j] = data_in_channel[hstart * win + wstart]; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += data_in_channel[h * win + w]; + } + } + data_out_row[j] = sum / pool_size; + } + data_out_row += wout; + } + } + } + } else { // exclusive == true, Pooling_average_exclude_padding + for (int n = 0; n < num; ++n) { + float* data_out_channel = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int q = 0; q < chout; q++) { + float* data_out_row = data_out_channel + q * size_channel_out; + const float* data_in_channel = data_in_batch + q * size_channel_in; + for (int i = 0; i < hout; i++) { + for (int j = 0; j < wout; j++) { + int hstart = i * stride_h - pad_h; + int wstart = j * stride_w - pad_w; + int hend = std::min(hstart + kernel_h, hin + pad_h); + int wend = std::min(wstart + kernel_w, win + pad_w); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, hin); + wend = std::min(wend, win); + + data_out_row[j] = data_in_channel[hstart * win + wstart]; + float sum = 0.f; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + sum += data_in_channel[h * win + w]; + } + } + int pool_size = (hend - hstart) * (wend - wstart); + data_out_row[j] = sum / pool_size; + } + data_out_row += wout; + } + } + } + } + + } else { + LOG(FATAL) << "not support"; + } +} + +void pooling_global(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int cnt = size_channel_in / 8; + +#if 0 + LOG(INFO) << "size_channel_in:" << size_channel_in; + LOG(INFO) << "cnt:" << cnt; + LOG(INFO) << "num:" << num; + LOG(INFO) << "chout:" << chout; + LOG(INFO) << "hout:" << hout; + LOG(INFO) << "wout:" << wout; + + 
LOG(INFO) << "chin:" << chin; + LOG(INFO) << "hin:" << hin; + LOG(INFO) << "win:" << win; + LOG(INFO) << "pooling_type " << pooling_type; +#endif + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout; + const float* data_in_batch = data_in + n * chin * size_channel_in; + if (pooling_type == "max") { +#pragma omp parallel for + for (int c = 0; c < chout; ++c) { + const float* data_in_channel = data_in_batch + c * size_channel_in; + int i = 0; + float minval = std::numeric_limits::lowest(); + float32x4_t vmax = vdupq_n_f32(minval); +#ifdef __aarch64__ + for (; i < cnt; i++) { + float32x4_t vdin1 = vld1q_f32(data_in_channel); + vmax = vmaxq_f32(vdin1, vmax); + float32x4_t vdin2 = vld1q_f32(data_in_channel + 4); + vmax = vmaxq_f32(vmax, vdin2); + data_in_channel += 8; + } +#else + int num = cnt; + if (num > 0) { + asm volatile( + "max_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, " + "data_in_channel\n" + "vmax.f32 %q[vmax], %q[vmax], q0 @max vmax, " + "vmax, data_in_channel\n" + "vld1.f32 {d2-d3}, [%[data_in_channel]]! @ load 2nd 4 " + "data" + "vmax.f32 %q[vmax], %q[vmax], q1 @ compare 2nd " + "4 datas\n" + "subs %[num], #1 @subs num, 1\n" + "bne max_loop @bne num\n" + : [data_in_channel] "+r"(data_in_channel), [num] "+r"(num), + [vmax] "+w"(vmax) + : + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ + float32x2_t vmax_tmp = + vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax)); + float tmp1 = vget_lane_f32(vmax_tmp, 0); + float tmp2 = vget_lane_f32(vmax_tmp, 1); + float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2; + for (i = cnt * 8; i < size_channel_in; ++i) { + /* code */ + max_tmp = max_tmp > data_in_channel[0] ? max_tmp : data_in_channel[0]; + data_in_channel++; + } + data_out_batch[c] = max_tmp; + } + } else { +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + const float* data_in_channel = + data_in_batch + c * size_channel_in; // in address + int i = 0; + float32x4_t vsum = vdupq_n_f32(0.0f); +#ifdef __aarch64__ + for (; i < cnt; i++) { // + vsum = vaddq_f32(vld1q_f32(data_in_channel), vsum); + data_in_channel += 4; + } +#else + int num = cnt; + if (num > 0) { + asm volatile( + "add_loop: @main loop\n" + "vld1.f32 {d0-d1}, [%[data_in_channel]]! 
@load q1, " + "data_in_channel\n" + "vadd.f32 %q[vsum], %q[vsum], q0 @add vmax, " + "vmax, data_in_channel\n" + "subs %[num], #1 @subs num, 1\n" + "bne add_loop @bne num\n" + : [data_in_channel] "+r"(data_in_channel), [num] "+r"(num), + [vsum] "+w"(vsum) + : + : "cc", "memory", "q0"); + } +#endif // __aarch64__ + float32x2_t vsum_tmp = + vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum)); + float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1); + for (i = cnt * 4; i < size_channel_in; i++) { + sum += data_in_channel[0]; + data_in_channel++; + } + data_out_batch[c] = sum / size_channel_in; + } + } + } +} + +void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (win >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = (hin >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remian = w_even - w_unroll_size; + int w_in_2 = win << 1; + float32x4_t vzero = vdupq_n_f32(0.f); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + int h = 0; + for (; h < h_even; h += 2) { + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); + float32x4_t dr10 = vld1q_f32(&r1[w]); + float32x4_t dr11 = vld1q_f32(&r1[w + 4]); + float32x4_t dmax1 = vmaxq_f32(dr00, dr10); + float32x4_t dmax2 = vmaxq_f32(dr01, dr11); +#ifdef __aarch64__ + float32x4_t dmax = vpmaxq_f32(dmax1, dmax2); +#else + float32x2_t dmaxl = + vpmax_f32(vget_low_f32(dmax1), vget_high_f32(dmax1)); + float32x2_t dmaxh = + vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2)); + float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); +#endif + vst1q_f32(&data_out_channel[w >> 1], dmax); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + const float* dr0 = r0; + const float* dr1 = r1; + float* dr_out = data_out_channel; + if (num > 0) { + asm volatile( + "s2_max_loop: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @load q1, dr1\n" + "vmax.f32 q0, q0, q2 @max q0, q0, " + "q2\n" + "vmax.f32 q1, q1, q3 @max q1, q1, " + "q2\n" + "vpmax.f32 d4, d0, d1 @max d4, d0, " + "d1\n" + "vpmax.f32 d5, d2, d3 @max d5, d2, " + "d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! 
@vst1 q2, " + "dr_out\n" + "subs %[num], #1 @subs num, 1\n" + "bne s2_max_loop @bne num\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [num] "+r"(num) + : + : "cc", "memory", "q0", "q1", "q2", "q3"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = + std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1])); + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = std::max(r0[w], r1[w]); + } + r0 += w_in_2; // << 1; + r1 += w_in_2; // << 1; + data_out_channel += wout; + } + // process remain row (odd, last row) + for (; h < hin; h++) { // run 0 or 1 time + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); +#ifdef __aarch64__ + float32x4_t dmax = vpmaxq_f32(dr00, dr01); +#else + float32x2_t dmaxl = + vpmax_f32(vget_low_f32(dr00), vget_high_f32(dr00)); + float32x2_t dmaxh = + vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01)); + float32x4_t dmax = vcombine_f32(dmaxl, dmaxh); +#endif + float32x4_t dmax_cmp_zero = vmaxq_f32(dmax, vzero); + vst1q_f32(&data_out_channel[w >> 1], dmax_cmp_zero); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + const float* dr0 = r0; + float* dr_out = data_out_channel; + if (num > 0) { + asm volatile( + "s2_max_loop1: @main " + "loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n" + "vpmax.f32 d4, d0, d1 @max d4, d0, " + "d1\n" + "vpmax.f32 d5, d2, d3 @max d5, d2, " + "d3\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, " + "dr_out\n" + "subs %[num], #1 @subs num, 1\n" + "bne s2_max_loop1 @bne num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [num] "+r"(num) + : + : "cc", "memory", "q0", "q1", "q2"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = std::max(std::max(r0[w], r0[w + 1]), 0.f); + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = std::max(r0[w], 0.f); + } + } + } + } +} + +void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (win >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = (hin >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remian = w_even - w_unroll_size; + int w_in_2 = win << 1; + float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4 + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + int h = 0; + for (; h < h_even; h += 2) { + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); + float32x4_t dr10 = vld1q_f32(&r1[w]); + float32x4_t dr11 = 
vld1q_f32(&r1[w + 4]); + float32x4_t dsum1 = vaddq_f32(dr00, dr10); + float32x4_t dsum2 = vaddq_f32(dr01, dr11); +#ifdef __aarch64__ + float32x4_t dsum = vpaddq_f32(dsum1, dsum2); +#else + float32x2_t dsuml = + vpadd_f32(vget_low_f32(dsum1), vget_high_f32(dsum1)); + float32x2_t dsumh = + vpadd_f32(vget_low_f32(dsum2), vget_high_f32(dsum2)); + float32x4_t dsum = vcombine_f32(dsuml, dsumh); +#endif + float32x4_t res = vmulq_f32(dsum, vcoef); + vst1q_f32(&data_out_channel[w >> 1], res); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + const float* dr0 = r0; + const float* dr1 = r1; + float* dr_out = data_out_channel; + + if (num > 0) { + asm volatile( + "1: @ main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, " + "dr0\n" + "vld1.f32 {d4-d7}, [%[dr1]]! @ load q1, " + "dr1\n" + "vadd.f32 q0, q0, q2 @ add q0, q0, " + "q2\n" + "vadd.f32 q1, q1, q3 @ add q1, q1, " + "q2\n" + "vpadd.f32 d4, d0, d1 @ add d4, d0, " + "d1\n" + "vpadd.f32 d5, d2, d3 @ add d5, d2, " + "d3\n" + "vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, " + "vcoef\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, " + "dr_out\n" + "subs %[num], #1 @ subs num, 1\n" + "bne 1b @ bne num\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [vcoef] "+w"(vcoef), [num] "+r"(num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(num), "w"(vcoef) + : "cc", "memory", "q0", "q1", "q2", "q3"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = + (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) / 4.f; + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = (r0[w] + r1[w]) / 4.f; + } + r0 += w_in_2; // << 1; + r1 += w_in_2; // << 1; + data_out_channel += wout; + } + // process remain row (odd, last row) + for (; h < hin; h++) { // run 0 or 1 time + int w = 0; +#ifdef __aarch64__ + for (; w < w_unroll_size; w += 8) { + float32x4_t dr00 = vld1q_f32(&r0[w]); + float32x4_t dr01 = vld1q_f32(&r0[w + 4]); +#ifdef __aarch64__ + float32x4_t dsum = vpaddq_f32(dr00, dr01); +#else + float32x2_t dsuml = + vpadd_f32(vget_low_f32(dr00), vget_high_f32(dr00)); + float32x2_t dsumh = + vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01)); + float32x4_t dsum = vcombine_f32(dsuml, dsumh); +#endif + float32x4_t res = vmulq_f32(dsum, vcoef); + vst1q_f32(&data_out_channel[w >> 1], res); + } +#else + w = w_unroll_size; + int num = w_unroll_size >> 3; + const float* dr0 = r0; + float* dr_out = data_out_channel; + + if (num > 0) { + asm volatile( + "1: @ main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, " + "dr0\n" + "vpadd.f32 d4, d0, d1 @ add d4, d0, " + "d1\n" + "vpadd.f32 d5, d2, d3 @ add d5, d2, " + "d3\n" + "vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, " + "vcoef\n" + "vst1.f32 {d4-d5}, [%[dr_out]]! 
@ vst1 q2, " + "dr_out\n" + "subs %[num], #1 @ subs num, 1\n" + "bne 1b @ bne num\n" + : [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [vcoef] "+w"(vcoef), + [num] "+r"(num) + : "r"(dr0), "r"(dr_out), "r"(num), "w"(vcoef) + : "cc", "memory", "q0", "q1", "q2"); + } +#endif // __aarch64__ + for (; w < w_even; w += 2) { + data_out_channel[w >> 1] = (r0[w] + r0[w + 1]) / 4.f; + } + for (; w < win; ++w) { // run 0 or 1 time + data_out_channel[w >> 1] = r0[w] / 4.f; + } + } + } + } +} + +void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + // no need to pad input tensor, pad_size is not used, default border is zero + // padded + int ch_in = chin; + int h_in = hin; + int w_in = win; + + int ch_out = chout; + int h_out = hout; + int w_out = wout; + + int size_channel_out = w_out * h_out; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (w_in >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = (h_in >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + // int w_unroll_size = (w_even >> 3) << 3; + // int w_unroll_remian = w_even - w_unroll_size; + int w_in_2 = w_in << 1; + int w_unroll_size = (w_in - 2) >> 2; + int w_unroll_remian = w_in - 2 - w_unroll_size * 4; + float minval = std::numeric_limits::lowest(); + float32x4_t vzero = vdupq_n_f32(minval); // zero pad + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; // w_in / 4 + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 1; + // left + data_out_channel[0] = + std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +// first row with zero pad +#ifdef __aarch64__ + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = 
vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&data_out_channel[cnt], vmax); + cnt += 4; + } + +#else + dr_out = dr_out + 1; + + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vmax.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4, " + "max_1234, max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4, " + "max_2345, max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6, " + "max_3456, max_3456\n" + "vmax.f32 d8, d2, d3 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2, " + "vmax_23_45, vmax_34_56\n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + // swap + "vmov.f32 s0, s17 @mov \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s0 @mov \n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } + +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + data_out_channel[j + w + 1] = tmp_max; + } + // right + float tmp = std::max(r0[w_in - 2], r1[w_in - 2]); + tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); + data_out_channel[w_out - 1] = tmp; + + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + data_out_channel += w_out; + int h = 0; + for (; h < h_in - 2; h += 1) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); + data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = 
vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&data_out_channel[cnt], vmax); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + dr2 = r2; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main " + "loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" + "vmax.f32 q7, q0, q2 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d16, d2, d6 @max " + "r0_5678,r1_5678\n" + "vmax.f32 q3, q7, q4 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d12, d16, d10 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d6, d7 @pmax d4, " + "max_1234, max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4, " + "max_2345, max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6, " + "max_3456, max_3456\n" + "vmax.f32 d8, d2, d3 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2, " + "vmax_23_45, vmax_34_56\n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "sub %[dr2], #8 @sub w, 8\n" + // swap + "vmov.f32 s0, s17 @mov \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s0 @mov \n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @ bne " + "s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1])); + tmp_max = std::max(tmp_max, r2[j + w + 2]); + data_out_channel[j + w + 1] = tmp_max; + } + // right + tmp = std::max(r0[w_in - 2], r1[w_in - 2]); + tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); + tmp = std::max(tmp, std::max(r2[w_in - 2], r2[w_in - 1])); + data_out_channel[w_out - 1] = tmp; + + r0 = r1; + r1 = r2; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + // the last two line + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + data_out_channel[0] = std::max(maxr0, maxr1); +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_34_56 = + vpmax_f32(vget_low_f32(vmax_3456), 
vget_high_f32(vmax_3456)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56); + float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0)); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2); + vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3); + vst1q_f32(&data_out_channel[cnt], vmax); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vmax.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vmax.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vpmax.f32 d2, d10, d11 @pmax d4, " + "max_1234, max_1234\n" + "vpmax.f32 d3, d0, d1 @pmax d4, " + "max_2345, max_2345\n" + "vpmax.f32 d6, d4, d5 @pmax d6, " + "max_3456, max_3456\n" + "vmax.f32 d8, d2, d3 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d3, d6 @max d2, " + "vmax_23_45, vmax_34_56\n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + // swap + "vmov.f32 s0, s17 @mov \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s0 @mov \n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_max = std::max(r0[j + w], r1[j + w]); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1])); + tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2])); + data_out_channel[j + w + 1] = tmp_max; + } + tmp = std::max(r0[w_in - 2], r1[w_in - 2]); + tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1])); + data_out_channel[w_out - 1] = tmp; + } + } +} + +void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + + int size_channel_out = w_out * h_out; + int size_channel_in = w_in * h_in; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int w_even = (w_in >> 1) << 1; + int h_even = (h_in >> 1) << 1; + int w_in_2 = w_in << 1; + int w_unroll_size = (w_in - 2) >> 2; + int w_unroll_remian = w_in - 2 - w_unroll_size * 4; + float32x4_t vzero = vdupq_n_f32(0.f); // zero pad + float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); // zero pad + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + 
float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; // w_in / 4 + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 1; + // left + data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +// first row with zero pad +#ifdef __aarch64__ + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&data_out_channel[cnt], vsum); + cnt += 4; + } + +#else + dr_out = dr_out + 1; + + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vadd.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234 + 2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } + +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + data_out_channel[j + w + 1] = tmp_sum / 9.f; + } + // right + float tmp = r0[w_in - 2] + r1[w_in - 2]; + tmp += (r0[w_in - 1] + r1[w_in - 1]); + data_out_channel[w_out - 1] = tmp / 9.f; + + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + data_out_channel += w_out; + int h = 0; + for (; h < h_in - 2; h += 1) { + // deal with left pad + float maxr0 = r0[0] + r0[1]; + float maxr1 = r1[0] + r1[1]; + float maxr2 = r2[0] + r2[1]; + data_out_channel[0] = (maxr0 + maxr1 + maxr2) / 9.f; +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&data_out_channel[cnt], vsum); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + dr2 = r2; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n" + "vadd.f32 q7, q0, q2 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d16, d2, d6 @max " + "r0_5678,r1_5678\n" + "vadd.f32 q3, q7, q4 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d12, d16, d10 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q3, q6, #1 @vext max_2345\n" + "vext.f32 q2, q3, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q3, q0 @add 1234 + " + "2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "sub %[dr2], #8 @sub w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0, " + "dr_out\n" + "bne 1b @bne " + "s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + tmp_sum += (r2[j + w + 1] + r2[j + w + 2]); + tmp_sum += r2[j + w]; + data_out_channel[j + w + 1] = tmp_sum / 9.f; + } + // right + tmp = r0[w_in - 2] + r1[w_in - 2]; + tmp += (r0[w_in - 1] + r1[w_in - 1]); + tmp += (r2[w_in - 2] + r2[w_in - 1]); + data_out_channel[w_out - 1] = tmp / 9.f; + + r0 = r1; + r1 = r2; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + // the last two line + float maxr0 = (r0[0] + r0[1]); + float maxr1 = (r1[0] + r1[1]); + data_out_channel[0] = (maxr0 + maxr1) / 9.f; +#ifdef __aarch64__ + w = 0; + cnt = 1; + for (; w <= w_in - 6; w += 4) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345); + vsum = vaddq_f32(vsum, vsum_3456); + vsum = vmulq_f32(vsum, vcoef); + vst1q_f32(&data_out_channel[cnt], vsum); + cnt += 4; + } +#else + dr_out = data_out_channel + 1; + dr0 = r0; + dr1 = r1; + cnt_num = w_unroll_size; + if (cnt_num > 0) { + asm volatile( + "1: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n" + "vadd.f32 q5, q0, q2 @max " + "r0_1234,r1_1234\n" + "vadd.f32 d12, d2, d6 @max " + "r0_5678,r1_5678\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q5, q6, #1 @vext max_2345\n" + "vext.f32 q2, q5, q6, #2 @vext max_3456\n" + "vadd.f32 q1, q5, q0 @add 1234 + 2345\n" + "vadd.f32 q1, q1, q2 @add + 3456\n" + "vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n" + "sub %[dr0], #8 @sub w, 8\n" + "sub %[dr1], #8 @sub w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! 
@vst1 d0, dr_out\n" + "bne 1b @bne s1_max_loop\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + } +#endif + // remian + w = w_unroll_size * 4; + for (int j = 0; j < w_unroll_remian; j++) { + float tmp_sum = r0[j + w] + r1[j + w]; + tmp_sum += (r0[j + w + 1] + r1[j + w + 1]); + tmp_sum += (r0[j + w + 2] + r1[j + w + 2]); + data_out_channel[j + w + 1] = tmp_sum / 9.f; + } + // right + tmp = r0[w_in - 2] + r1[w_in - 2]; + tmp += (r0[w_in - 1] + r1[w_in - 1]); + data_out_channel[w_out - 1] = tmp / 9.f; + } + } +} + +void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = wout * 2 + 1; + int h_needed = hout * 2 + 1; + int pad_right = w_needed - win - pad_left; + int pad_bottom = h_needed - hin - pad_top; + int w_even = (win >> 1) << 1; + int h_even = (hin >> 1) << 1; + int w_in_2 = win << 1; + float minval = std::numeric_limits::lowest(); + float32x4_t vzero = vdupq_n_f32(minval); // zero pad + int cnt_col = (win - 1) / 8; + // remain + int remain = ((win - 1) % 8) / 2; + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 1; + int cnt = 1; + int cnt_num = cnt_col; + int cnt_num1 = remain; + data_out_channel[0] = + std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +// first row with zero pad +#ifdef __aarch64__ + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + 
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr0 = dr0 + 1; + dr1 = dr1 + 1; + dr_out = dr_out + 1; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q6, q0, q3 @max " + "r0_1234,r1_1234\n" + "vmax.f32 q7, q1, q4 @max " + "r0_5678,r1_5678\n" + "vmax.f32 q8, q2, q5 @max " + "r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q7, q8, #1 @vext max_6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp cnt_num, " + "0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0, q0, q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0, d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, d0\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne " + "s3_max_loop_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9"); + } +// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); +#endif + // int w = w_even - 1; + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + data_out_channel[w_even >> 1] = tmp; + // cnt ++; + } + + r0 = r1; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + int h = 2; + for (; h < h_even; h += 2) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); + data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + dr2 = (r2 + 1); + cnt_num = cnt_col; + cnt_num1 = remain; + if (cnt_num > 0 || cnt_num1 > 0) { + asm 
volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 q11, q2, q5 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 " + "5678\n" + "vmax.f32 q1, q11, q8 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax d10, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax d12, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax d11, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax d13, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax d0, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax d1, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "sub %[dr2], #16 @add w, 8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d1, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne " + "s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmov.f32 s11,s10 @movs11, s10\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vmax.f32 q0, q0, q2 @max q0, q0, " + "q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "bne 2b @bne " + "s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(r0[i], r1[i])); + tmp = std::max(tmp, r2[i]); + } + data_out_channel[w_even >> 1] = tmp; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + } + + if (pad_bottom) { + // deal with bottom pad + // first row with zero pad + int hstart = (h >> 1) * stride_h - pad_h; + int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); + + if (hstart == hend - 1) { // only one lline + data_out_channel[0] = std::max(r0[0], r0[1]); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vmax_1234 = vld1q_f32(&r0[w]); + float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0)); + vmax = vpmax_f32(vmax, vmax); + data_out_channel[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + cnt_num = cnt_col; + cnt_num1 = remain; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vext.f32 q4, q0, q1, #1 @vext q4, q0, " + "q1, 1 2345\n" + "vext.f32 q5, q1, q2, #1 @vext q5, q0, " + "q1, 1 6789\n" + "vpmax.f32 d12, d0, d1 @pmax d12, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d14, d2, d3 @pmax d14, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d13, d8, d9 @pmax d13, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d15, d10, d11 @pmax d15, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d0, d12, d13 @max d0, " + "vmax_12_34,vmax_23_45\n" + "vmax.f32 d1, d14, d15 @pmax d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 6\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d1, [%[dr_out]]! 
@vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne " + "s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 2\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne " + "s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, r0[i]); + } + data_out_channel[w_even >> 1] = tmp; + } + } else { // two lines + data_out_channel[0] = + std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1])); +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + cnt_num = cnt_col; + cnt_num1 = remain; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! 
@load d4-d7, " + "dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vmax.f32 q8, q2, q5 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, + // s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0, " + "2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1, " + "6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne " + "s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne " + "s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + data_out_channel[w_even >> 1] = tmp; + } + } + } + } + } +} + +void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = wout * 2 + 1; + int h_needed = hout * 2 + 1; + int pad_right = w_needed - win - pad_left; + int pad_bottom = h_needed - hin - pad_top; + int w_even = (win >> 1) << 1; + int h_even = (hin >> 1) << 1; + int w_in_2 = win << 1; + int w_unroll_size = (win - 1) / 8; + // remain + int w_unroll_remian = ((win - 1) % 8) / 2; + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < 
chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + const float* r2 = r1 + win; + int cnt_num = w_unroll_size; + int cnt_num1 = w_unroll_remian; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 1; + int cnt = 1; + float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); + float32x4_t vzero = vdupq_n_f32(0.f); + data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +// first row with zero pad +#ifdef __aarch64__ + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); + data_out_channel[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr0 = dr0 + 1; + dr1 = dr1 + 1; + dr_out = dr_out + 1; + // printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, 0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! 
@load d4-d7, " + "dr1\n" + "vadd.f32 q6, q0, q3 @max " + "r0_1234,r1_1234\n" + "vadd.f32 q7, q1, q4 @max " + "r0_5678,r1_5678\n" + "vadd.f32 q8, q2, q5 @max " + "r0_9101112,r1_9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 \n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 \n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 \n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n" + "bne 1b @bne s3_max_loop\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp cnt_num, " + "0\n" + "ble 4f @ble exit\n" + "2: @main loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9"); + } +// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1); +#endif + // int w = w_even - 1; + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp += (r0[i] + r1[i]); + } + data_out_channel[w_even >> 1] = tmp / 9.f; + // cnt ++; + } + + r0 = r1; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + int h = 2; + for (; h < h_even; h += 2) { + // deal with left pad + float sum0 = r0[0] + r0[1]; + float sum1 = r1[0] + r1[1]; + float sum2 = r2[0] + r2[1]; + data_out_channel[0] = (sum0 + sum1 + sum2) / 9.f; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + vsum_1234 = vaddq_f32(vsum_1234, vr2_1234); + vsum_5678 = vaddq_f32(vsum_5678, vr2_5678); + vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112); + + float32x4_t 
vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + vr2 = vsetq_lane_f32(0.f, vr2, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + vsum1 = vaddq_f32(vsum1, vr2); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + float32x2_t vsum = vpadd_f32(vsum2, vsum2); + data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f; + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + dr2 = (r2 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 q11, q2, q5 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 " + "1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 " + "5678\n" + "vadd.f32 q8, q11, q8 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 " + "\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 " + "\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 " + "\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 " + "\n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "sub %[dr2], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! 
@load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vadd.f32 q0, q0, q2 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef), + [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; + for (int i = wstart; i < wend; i++) { + tmp += (r0[i] + r1[i] + r2[i]); + } + data_out_channel[w_even >> 1] = tmp / 9.f; + // cnt ++; + } + r0 = r2; + r1 = r0 + win; + r2 = r1 + win; + data_out_channel += wout; + } + + if (pad_bottom) { + // deal with bottom pad + // first row with zero pad + int hstart = (h >> 1) * stride_h - pad_h; + int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin); + + if (hstart == hend - 1) { // only one lline + data_out_channel[0] = (r0[0] + r0[1]) / 9.f; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vsum_1234 = vld1q_f32(&r0[w]); + float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]); + + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), + vsum_123_345, 1); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), + vsum_123_345, 2); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), + vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0)); + vsum = vpadd_f32(vsum, vsum); + data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f; + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d12-d15}, [%[dr0]]! @load " + "d0-d3, dr0\n" + "vld1.f32 {d16-d17}, [%[dr0]]! 
@load " + "d0-d3, dr0\n" + "vext.f32 q0, q6, q7, #1 @vext " + "max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext " + "max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext " + "max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext " + "max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, " + "2345 \n" + "vadd.f32 q5, q7, q1 @add 5678, " + "4567 \n" + "vadd.f32 q4, q4, q2 @add 3456, " + "sum1 \n" + "vadd.f32 q5, q5, q3 @add 6789, " + "sum2 \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 6\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext " + "v0_0123\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 2\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; + for (int i = wstart; i < wend; i++) { + tmp += r0[i]; + } + data_out_channel[w_even >> 1] = tmp / 9.f; + } + } else { // two lines + data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +#ifdef __aarch64__ + w = 1; + cnt = 1; + for (; w < win - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), + vsum_123_345, 1); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), + vsum_123_345, 2); + vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), + vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = 
vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); + data_out_channel[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr_out = data_out_channel + 1; + dr0 = (r0 + 1); + dr1 = (r1 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, " + "dr0\n" + "vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vmax.f32 q8, q2, q5 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, + // s6\n" + "vext.f32 q0, q6, q7, #1 @vext " + "max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext " + "max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext " + "max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext " + "max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, " + "2345 \n" + "vadd.f32 q5, q7, q1 @add 5678, " + "4567 \n" + "vadd.f32 q4, q4, q2 @add 3456, " + "sum1 \n" + "vadd.f32 q5, q5, q3 @add 6789, " + "sum2 \n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #16 @add w, 8\n" + "sub %[dr1], #16 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext " + "v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext " + "v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0], " + "dr_out\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win); + float tmp = 0.f; + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp += (r0[i] + r1[i]); + } + data_out_channel[w_even >> 1] = tmp / 9.f; + } + } + } + } + } +} + +void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int size_channel_out = w_out * h_out; + int size_channel_in = w_in * h_in; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = w_out * 2 + 1; + int h_needed = h_out * 2 + 1; + int pad_right = w_needed - w_in - pad_left; + int pad_bottom = h_needed - h_in - pad_top; + int w_even = ((w_in - 1) >> 1) << 1; + // int w_remains = w_in - w_even; // should be 0 or 1 + int h_even = ((h_in - 1) >> 1) << 1; + // int h_remains = h_in - h_even; // should be 0 or 1 + int w_unroll_size = w_in >> 3; + int w_unroll_remian = (w_in - w_unroll_size * 8 - 1) / 2; + int w_in_2 = w_in << 1; + float minval = std::numeric_limits::lowest(); + float32x4_t vzero = vdupq_n_f32(minval); // zero pad + // printf("minval: %.2f\n", minval); + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; + // w = w_in - 8; + int cnt_num1 = w_unroll_remian; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + int w = 0; + int cnt = 0; + // data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], + // r1[1])); + // first row with zero pad + // r0 = r1; + // r1 = r0 + w_in; + // r2 = r1 + w_in; + // data_out_channel += w_out; + int h = 0; + for (; h < h_even; h += 2) { + // deal with left pad + float maxr0 = std::max(r0[0], r0[1]); + float maxr1 = std::max(r1[0], r1[1]); + float maxr2 = std::max(r2[0], r2[1]); +// data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2); +#ifdef __aarch64__ + w = 0; + cnt = 0; + for (; w < w_in - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w 
+ 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vr2_1234 = vld1q_f32(&r2[w]); + float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]); + float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + float32x4_t vr2 = vld1q_f32(&r2[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax, 0); + cnt++; + } +#else + dr_out = data_out_channel; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + dr2 = r2; // (r2 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n" + "vmax.f32 q9, q0, q3 @max q0,q0,q2\n" + "vmax.f32 q10, q1, q4 @max q1,q1,q3\n" + "vmax.f32 d22, d4, d10 @max q1,q1,q3\n" + "vmax.f32 q0, q9, q6 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q3, q10, q7 @max q1,q1,q3 " + "5678\n" + "vmax.f32 d2, d22, d16 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q4, q0, q3, #1 @vext 2345\n" + "vext.f32 q2, q3, q1, #1 @vext 6789\n" + "vpmax.f32 d10, d0, d1 @pmax d10, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d12, d6, d7 @pmax d12, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d11, d8, d9 @pmax d11, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d13, d4, d5 @pmax d13, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d0, d10, d11 @pmax d0, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d1, d12, d13 @pmax d1, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "sub %[dr2], #8 @add w, 8\n" + "vst1.f32 d0, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d1, [%[dr_out]]! 
@vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne s3_max_loop_mid\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmov.f32 s11,s10 @movs11, s10\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vmax.f32 q0, q0, q2 @max q0, q0, " + "q2\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "bne 2b @bne s3_max_loop_mid_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(r0[i], r1[i])); + tmp = std::max(tmp, r2[i]); + } + data_out_channel[w_even >> 1] = tmp; + // cnt ++; + } + r0 = r2; + r1 = r0 + w_in; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + if (pad_bottom) { +// deal with bottom pad +// first row with zero pad +// int hstart = (h >> 1) * stride_h - pad_h; +// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in); +// data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], +// r1[1])); +#ifdef __aarch64__ + w = 0; + cnt = 0; + for (; w < w_in - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(&data_out_channel[cnt], vmax_123_345); + vst1_f32(&data_out_channel[cnt + 2], vmax_567_789); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + float32x2_t vmax2 = + 
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + vmax2 = vpmax_f32(vmax2, vmax2); + data_out_channel[cnt] = vget_lane_f32(vmax2, 0); + cnt++; + } +#else + dr_out = data_out_channel; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + if (cnt_num > 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 3f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vmax.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vmax.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vmax.f32 d16, d4, d10 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext q0, 2345\n" + "vext.f32 q1, q7, q8, #1 @vext q1, 6789\n" + "vpmax.f32 d4, d12, d13 @pmax d4, " + "vmax_1234, vmax_1234\n" + "vpmax.f32 d6, d14, d15 @pmax d6, " + "vmax_5678, vmax_5678\n" + "vpmax.f32 d5, d0, d1 @pmax d5, " + "vmax_2345, vmax_2345\n" + "vpmax.f32 d7, d2, d3 @pmax d7, " + "vmax_6789, vmax_6789\n" + "vmax.f32 d8, d4, d5 @max d2, " + "vmax_12_34, vmax_23_45\n" + "vmax.f32 d9, d6, d7 @max d2, " + "vmax_56_78, vmax_67_89\n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "bne 1b @bne s3_max_loop_bot\n" + "3: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 4f @ble exit\n" + "2: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vmov.f32 s3,s2 @movs3, s2\n" + "vmov.f32 s7,s6 @movs7, s6\n" + "vmax.f32 q0, q0, q1 @max q0, q0, " + "q1\n" + "vpmax.f32 d0, d0, d1 @pmax d0, " + "d0,d1\n" + "vpmax.f32 d0, d0, d0 @pmax d0, d0, " + "d0\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0], " + "dr_out\n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "bne 2b @bne s3_max_loop_bot_1\n" + "4: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = r0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp = std::max(tmp, std::max(r0[i], r1[i])); + } + data_out_channel[w_even >> 1] = tmp; + } + } + } + } +} + +void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type) { + int w_in = win; + int h_in = hin; + int ch_in = chin; + + int w_out = wout; + int h_out = hout; + int ch_out = chout; + + int kernel_h = ksize[0]; + int kernel_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + int size_channel_out = w_out * h_out; + int size_channel_in = w_in * h_in; + float* data_out = static_cast(dout); + const float* data_in = static_cast(din); + + int pad_top = pad_h; + int pad_left = pad_w; + int w_needed = w_out * 2 + 1; + int h_needed = h_out * 2 + 1; + int pad_right = w_needed - w_in - pad_left; + int pad_bottom = h_needed - h_in - pad_top; + int w_even = ((w_in - 1) >> 1) << 1; + int h_even = ((h_in - 1) >> 1) << 1; + int w_in_2 = w_in << 1; + int w_unroll_size = w_in >> 3; + int w_unroll_remian = (w_even - w_unroll_size * 8 - 1) / 2; + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * ch_out * size_channel_out; + const float* data_in_batch = data_in + n * ch_in * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < ch_out; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + w_in; + const float* r2 = r1 + w_in; + int cnt_num = w_unroll_size; + // w = w_in - 8; + int cnt_num1 = w_unroll_remian; + float* dr_out = data_out_channel; + const float* dr0 = r0; + const float* dr1 = r1; + const float* dr2 = r2; + + float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); + float32x4_t vzero = vdupq_n_f32(0.f); + + int h = 0; + for (; h < h_even; h += 2) { +// LOG(INFO) << "h: " << h<<", dr0:" << r0 <<", dr1: "< 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble loop3_ave_p0 @ble " + "exit\n" + "s3_ave_loop_mid_p0: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vld1.f32 {d16}, [%[dr2]]! 
@load d4-d7, dr1\n" + "vadd.f32 q9, q0, q3 @max q0,q0,q2\n" + "vadd.f32 q10, q1, q4 @max q1,q1,q3\n" + "vadd.f32 d22, d4, d10 @max q1,q1,q3\n" + "vadd.f32 q6, q9, q6 @max q0,q0,q2 " + "1234\n" + "vadd.f32 q7, q10, q7 @max q1,q1,q3 " + "5678\n" + "vadd.f32 d16, d22, d16 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 " + "\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 " + "\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 " + "\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 " + "\n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "sub %[dr2], #8 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne s3_ave_loop_mid_p0 @bne " + "s3_max_loop_mid\n" + "loop3_ave_p0: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble exit1_ave_p0 @ble " + "exit1\n" + "s3_ave_loop_mid_1_p0: @mid loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vadd.f32 q0, q0, q2 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "sub %[dr2], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs cnt_num, " + "#1\n" + "vst1.f32 d0[0], [%[dr_out]]! 
@vst d0[0], " + "dr_out\n" + "bne s3_ave_loop_mid_1_p0 @bne " + "s3_max_loop_mid_1\n" + "exit1_ave_p0: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num), + [cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef), + [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num), + "r"(cnt_num1) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", + "q10", "q11", "q12"); + } +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = 0.f; + int pool_size = 3 * (wend - wstart); + for (int i = wstart; i < wend; i++) { + tmp += (r0[i] + r1[i] + r2[i]); + } + data_out_channel[w_even >> 1] = tmp / pool_size; + // cnt ++; + } + r0 = r2; + r1 = r0 + w_in; + r2 = r1 + w_in; + data_out_channel += w_out; + } + + if (pad_bottom) { +// deal with bottom pad +// first row with zero pad +// int hstart = (h >> 1) * stride_h - pad_h; +// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in); +// data_out_channel[0] =(r0[0] + r0[1] + r1[0] + r1[1]) / 9.f; +#if 1 // def __aarch64__ + int w = 0; + int cnt = 0; + vcoef = vdupq_n_f32(1.f / 6.f); + for (; w < w_in - 8; w += 8) { + float32x4_t vr0_1234 = vld1q_f32(&r0[w]); + float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]); + float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]); + float32x4_t vr1_1234 = vld1q_f32(&r1[w]); + float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]); + float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]); + + float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234); + float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678); + float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112); + float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1); + float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2); + float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3); + float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1); + float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345); + vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456); + float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678); + vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2); + vsum_123_345 = + vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3); + float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef); + vst1q_f32(&data_out_channel[cnt], vrst); + cnt += 4; + } + for (; w < w_even - 1; w += 2) { + float32x4_t vr0 = vld1q_f32(&r0[w]); + float32x4_t vr1 = vld1q_f32(&r1[w]); + vr0 = vsetq_lane_f32(0.f, vr0, 3); + vr1 = vsetq_lane_f32(0.f, vr1, 3); + float32x4_t vsum1 = vaddq_f32(vr0, vr1); + float32x2_t vsum2 = + vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1)); + vsum2 = vpadd_f32(vsum2, vsum2); + float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef)); + data_out_channel[cnt] = vget_lane_f32(vrst, 0); + cnt++; + } +#else + dr_out = data_out_channel; // + 1; + dr0 = r0; // (r0 + 1); + dr1 = r1; // (r1 + 1); + cnt_num = w_unroll_size; + cnt_num1 = w_unroll_remian; + // LOG(INFO) << "dr0:" << dr0 <<", dr1: "< 0 || cnt_num1 > 0) { + asm volatile( + "cmp %[cnt_num], #0 @cmp cnt_num, " + "0\n" + "ble 2f @ble exit\n" + "1: @main loop\n" + "vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, " + "dr0\n" + "vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, " + "dr1\n" + "vld1.f32 {d4}, [%[dr0]]! 
@load d0-d3, dr0\n" + "vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n" + "vadd.f32 q6, q0, q3 @max q0,q0,q2 " + "1234\n" + "vadd.f32 q7, q1, q4 @max q1,q1,q3 " + "5678\n" + "vadd.f32 d16, d4, d10 @max q1,q1,q3 " + "9101112\n" + //"vmov.f32 s7,s6 @mov s7, s6\n" + "vext.f32 q0, q6, q7, #1 @vext max_2345\n" + "vext.f32 q1, q6, q7, #3 @vext max_4567\n" + "vext.f32 q2, q6, q7, #2 @vext max_3456\n" + "vext.f32 q3, q7, q8, #1 @vext max_6789\n" + "vadd.f32 q4, q6, q0 @add 1234, 2345 " + "\n" + "vadd.f32 q5, q7, q1 @add 5678, 4567 " + "\n" + "vadd.f32 q4, q4, q2 @add 3456, sum1 " + "\n" + "vadd.f32 q5, q5, q3 @add 6789, sum2 " + "\n" + "vmov.f32 s17, s18 @mov \n" + "vmov.f32 s18, s21 @mov \n" + "vmov.f32 s19, s23 @mov \n" + "vmul.f32 q4, q4, %q[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 8\n" + "sub %[dr1], #8 @add w, 8\n" + "subs %[cnt_num], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d8, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "vst1.f32 d9, [%[dr_out]]! @vst1 d0, " + "dr_out\n" + "bne 1b @bne s3_max_loop_bot\n" + "2: @loop \n" + "cmp %[cnt_num1], #0 @cmp " + "cnt_num, 0\n" + "ble 3f @ble exit\n" + "4: @bot loop\n" + "vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, " + "dr0\n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, " + "dr1\n" + "vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n" + "vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n" + "vadd.f32 q0, q0, q1 @add q0, q0, " + "q1\n" + "vpadd.f32 d0, d0, d1 @padd d0, " + "d0,d1\n" + "vpadd.f32 d0, d0, d0 @padd d0, d0, " + "d0\n" + "vmul.f32 d0, d0, %e[vcoef] @mul \n" + "sub %[dr0], #8 @add w, 6\n" + "sub %[dr1], #8 @add w, 6\n" + "subs %[cnt_num1], #1 @subs " + "cnt_num, #1\n" + "vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], " + "dr_out\n" + "bne 4b @bne s3_max_loop_bot_1\n" + "3: @exit\n" + : [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1), + [vcoef] "+w"(vcoef), [vzero] "+w"(vzero) + : "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9"); + } + +#endif + if (pad_right) { + // deal with right pad + int wstart = (w_even >> 1) * stride_w - pad_w; + int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in); + float tmp = 0.f; + int pool_size = 2 * (wend - wstart); + for (int i = wstart; i < wend; i++) { // only run 1 or 2 times + tmp += (r0[i] + r1[i]); + } + data_out_channel[w_even >> 1] = tmp / pool_size; + } + } + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/pooling.h b/paddle/fluid/lite/arm/math/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..36832187073c2d29a129a10fdd7984ba8d15db3d --- /dev/null +++ b/paddle/fluid/lite/arm/math/pooling.h @@ -0,0 +1,111 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
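+
+// Naming convention for the kernels declared in this header:
+// pooling{K}x{K}s{S}p{P}_{max|ave} is a KxK pooling window with stride S and
+// padding P, computing max or average pooling over float NCHW data.
+// All entry points share one signature: din/dout are the input/output
+// buffers, (num, chin, hin, win) and (chout, hout, wout) give the input and
+// output shapes, ksize/strides/paddings carry the {h, w} kernel, stride and
+// padding sizes, and the remaining arguments (global_pooling, exclusive,
+// adaptive, ceil_mode, use_quantizer, pooling_type) select the pooling
+// variant.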
+ +#pragma once + +#include +#include +#include +#include "paddle/fluid/lite/utils/cp_logging.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +// !pooling fp32 Op +void pooling_basic(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling_global(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout, + int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool use_quantizer, const std::string& pooling_type); + +void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout, + int hout, int wout, int chin, int hin, int win, + const std::vector& ksize, + const std::vector& strides, + const std::vector& paddings, bool global_pooling, + bool exclusive, bool adaptive, bool ceil_mode, + bool 
use_quantizer, const std::string& pooling_type); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/split.cc b/paddle/fluid/lite/arm/math/split.cc new file mode 100644 index 0000000000000000000000000000000000000000..6dd6de6242e806947dfc630fd8f2a4dd03c89335 --- /dev/null +++ b/paddle/fluid/lite/arm/math/split.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/split.h" +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void split_cpy(const float* din, float* dout, int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* din_ptr = din + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + } + if (remain > 0) { + const float* din_ptr = din + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *din_ptr; + dout_ptr++; + din_ptr++; + } + } +} + +template <> +void split(const float* din, std::vector* dout, + const int axis, const std::vector& in_strides) { + int input_offset = 0; + for (auto out : *dout) { + auto out_dim = out->dims(); + std::vector out_strides(out_dim.size()); + out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; + for (int i = out_dim.size() - 2; i >= 0; --i) { + out_strides[i] = out_strides[i + 1] * out_dim[i]; + } + + float* out_data = out->mutable_data(); + int before = out_strides[0] / out_strides[axis]; + int in_after = in_strides[axis]; + int out_after = out_strides[axis]; + + for (int i = 0; i < before; ++i) { + split_cpy(din + input_offset + i * in_after, out_data + i * out_after, + out_after); + } + input_offset += out_strides[axis]; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/split.h b/paddle/fluid/lite/arm/math/split.h new file mode 100644 index 0000000000000000000000000000000000000000..9b5651d81ffa75362fcc39db82157c56548917c0 --- /dev/null +++ b/paddle/fluid/lite/arm/math/split.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void split_cpy(const T* din, T* dout, int num); + +template +void split(const T* din, std::vector* dout, const int axis, + const std::vector& in_strides); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/type_trans.cpp b/paddle/fluid/lite/arm/math/type_trans.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a60cc80f8d164324cd397f07e800d8e32a74533b --- /dev/null +++ b/paddle/fluid/lite/arm/math/type_trans.cpp @@ -0,0 +1,588 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/saturate.h" +#include +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void int32_to_dtype(const int* din, dtype* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size); + +void fp32_to_int8(const float* din, signed char* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + int cnt = inner_size / 16; + int remain = inner_size & 15; + long long loop_size = outer_size * axis_size; + +#pragma omp parallel for + for (int j = 0; j < loop_size; ++j) { + float inv_scale = 1.f / scale[j % axis_size]; + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vscale = vdupq_n_f32(inv_scale); + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + const float* din_c = din + j * inner_size; + signed char* dout_c = dout + j * inner_size; + if (cnt > 0) { + int cnt_loop = cnt; + const float* din_ptr = din_c; + signed char* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" + "ldp q2, q3, [%[in]], #32 \n" + "0: \n" /* main loop */ + "fmul v4.4s, v0.4s, %[scale].4s \n" + "fmul v5.4s, v1.4s, %[scale].4s \n" + "fmul v6.4s, v2.4s, %[scale].4s \n" + "fmul v7.4s, v3.4s, %[scale].4s \n" + "ldp q0, q1, [%[in]], #32 \n" + "subs %[cnt], %[cnt], #1 \n" + "FCVTAS v8.4s, v4.4s \n" + "FCVTAS v9.4s, v5.4s \n" + "FCVTAS v10.4s, v6.4s \n" + "FCVTAS v11.4s, v7.4s \n" + "ldp q2, q3, [%[in]], #32 \n" + "sqxtn v4.4h, v8.4s \n" + "sqxtn2 v4.8h, v9.4s \n" + "sqxtn v5.4h, v10.4s \n" + "sqxtn2 v5.8h, v11.4s \n" + "sqxtn v8.8b, v4.8h \n" + "sqxtn2 v8.16b, v5.8h \n" + "str q8, [%[out]], #16 \n" + "bne 0b \n" + : [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop) + : [scale] "w" 
(vscale) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" + ); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "0: @ main loop\n" + "vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q5, q4, q4 @ set offset, 0.5\n" + "vand.i32 q6, q4, q4 @ set offset, 0.5\n" + "vand.i32 q7, q4, q4 @ set offset, 0.5\n" + "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" + "vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n" + "vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n" + "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" + "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" + "vbif.f32 q6, %q[vnoff], q10 @ get right offset\n" + "vbif.f32 q7, %q[vnoff], q11 @ get right offset\n" + "vmla.f32 q4, q0, %q[vscale] @ mul scale\n" + "vmla.f32 q5, q1, %q[vscale] @ mul scale\n" + "vmla.f32 q6, q2, %q[vscale] @ mul scale\n" + "vmla.f32 q7, q3, %q[vscale] @ mul scale\n" + "vcvt.s32.f32 q0, q4 @ cvt to int32\n" + "vcvt.s32.f32 q1, q5 @ cvt to int32\n" + "vcvt.s32.f32 q2, q6 @ cvt to int32\n" + "vcvt.s32.f32 q3, q7 @ cvt to int32\n" + "vqmovn.s32 d8, q0 @ cnt to int16\n" + "vqmovn.s32 d9, q1 @ cnt to int16\n" + "vqmovn.s32 d10, q2 @ cnt to int16\n" + "vqmovn.s32 d11, q3 @ cnt to int16\n" + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vqmovn.s16 d12, q4 @ cnt to int8\n" + "vqmovn.s16 d13, q5 @ cnt to int8\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "vst1.32 {d12-d13}, [%[dout]]! @ write to output\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ to main loop\n" + + :[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop) + :[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero) + :"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11" + ); +#endif + } + const float* din_r = din_c + 16 * cnt; + signed char* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = saturate_cast(roundf(inv_scale * din_r[i])); + } + } +} + +void fp32_to_int16(const float* din, int16_t* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + int cnt = inner_size / 8; + int remain = inner_size & 7; + long long loop_size = outer_size * axis_size; + +#pragma omp parallel for + for (int j = 0; j < loop_size; ++j) { + float inv_scale = 1.f / scale[j % axis_size]; + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vscale = vdupq_n_f32(inv_scale); + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + const float* din_c = din + j * inner_size; + int16_t* dout_c = dout + j * inner_size; + if (cnt > 0) { + int cnt_loop = cnt; + const float* din_ptr = din_c; + int16_t* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" + "0: \n" /* main loop */ + "fmul v4.4s, v0.4s, %[scale].4s \n" + "fmul v5.4s, v1.4s, %[scale].4s \n" + "ldp q0, q1, [%[in]], #32 \n" + "subs %[cnt], %[cnt], #1 \n" + "FCVTAS v8.4s, v4.4s \n" + "FCVTAS v9.4s, v5.4s \n" + "sqxtn v4.4h, v8.4s \n" + "sqxtn2 v4.8h, v9.4s \n" + "str q4, [%[out]], #16 \n" + "bne 0b \n" + : [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop) + : [scale] "w" (vscale) + : "v0", "v1", "v4", "v5", "v8", "v9" + ); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[din]]! 
@ load in0~in7\n" + "0: @ main loop\n" + "vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q5, q4, q4 @ set offset, 0.5\n" + "vand.i32 q6, q4, q4 @ set offset, 0.5\n" + "vand.i32 q7, q4, q4 @ set offset, 0.5\n" + "vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n" + "vbif.f32 q4, %q[vnoff], q8 @ get right offset\n" + "vbif.f32 q5, %q[vnoff], q9 @ get right offset\n" + "vmla.f32 q4, q0, %q[vscale] @ mul scale\n" + "vmla.f32 q5, q1, %q[vscale] @ mul scale\n" + "vcvt.s32.f32 q0, q4 @ cvt to int32\n" + "vcvt.s32.f32 q1, q5 @ cvt to int32\n" + "vqmovn.s32 d8, q0 @ cnt to int16\n" + "vqmovn.s32 d9, q1 @ cnt to int16\n" + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vst1.32 {d8-d9}, [%[dout]]! @ write to output\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 0b @ to main loop\n" + + :[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop) + :[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero) + :"q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9" + ); +#endif + } + const float* din_r = din_c + 8 * cnt; + int16_t* dout_r = dout_c + 8 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = saturate_cast(roundf(inv_scale * din_r[i])); + } + } +} + +void int8_to_fp32(const signed char* in, float* out, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + int cnt = inner_size / 16; + int remain = inner_size & 15; + long long loop_size = axis_size * outer_size; +#pragma omp parallel for + for (long long n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const signed char* din_c = in + n * inner_size; + float* dout_c = out + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + if (cnt > 0) { + int loop = cnt; + const signed char* din_ptr = din_c; + float* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/ + "0: \n" /* main loop */ + "sshll v2.8h, v0.8b, #0 \n" /* trans to int16*/ + "sshll v3.8h, v1.8b, #0 \n" /* trans to int16*/ + + "sshll v4.4s, v2.4h, #0 \n" /* trans to int32*/ + "sshll2 v5.4s, v2.8h, #0 \n" /* trans to int32*/ + "sshll v6.4s, v3.4h, #0 \n" /* trans to int32*/ + "sshll2 v7.4s, v3.8h, #0 \n" /* trans to int32*/ + + "ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/ + + "scvtf v8.4s, v4.4s \n" /* trans to fp32*/ + "scvtf v9.4s, v5.4s \n" /* trans to fp32*/ + "scvtf v10.4s, v6.4s \n" /* trans to fp32*/ + "scvtf v11.4s, v7.4s \n" /* trans to fp32*/ + + "subs %[loop], %[loop], #1 \n" + + "fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/ + + "stp q4, q5, [%[out]], #32 \n" /* write to memory*/ + "stp q6, q7, [%[out]], #32 \n" /* write to memory*/ + + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" + ); +#else + asm volatile( + "vld1.32 {d0-d1}, [%[in]]! 
@ load 16 int8\n" + "0: @ main loop\n" + "vmovl.s8 q2, d0 @ trans to int16\n" + "vmovl.s8 q3, d1 @ trans to int16\n" + "vmovl.s16 q4, d4 @ trans to int32\n" + "vmovl.s16 q5, d5 @ trans to int32\n" + "vmovl.s16 q6, d6 @ trans to int32\n" + "vmovl.s16 q7, d7 @ trans to int32\n" + "vcvt.f32.s32 q0, q4 @ trans to fp32\n" + "vcvt.f32.s32 q1, q5 @ trans to fp32\n" + "vcvt.f32.s32 q2, q6 @ trans to fp32\n" + "vcvt.f32.s32 q3, q7 @ trans to fp32\n" + "vmul.f32 q4, q0, %q[scale] @ mul with scale\n" + "vmul.f32 q5, q1, %q[scale] @ mul with scale\n" + "vmul.f32 q6, q2, %q[scale] @ mul with scale\n" + "vmul.f32 q7, q3, %q[scale] @ mul with scale\n" + + "vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n" + + "subs %[loop], #1 \n" + + "vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n" + "vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n" + + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" + ); +#endif //__aarch64__ + } + const signed char* din_r = din_c + 16 * cnt; + float* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = in_scale * din_r[i]; + } + } +} + +void int16_to_fp32(const short* in, float* out, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + int cnt = inner_size / 16; + int remain = inner_size & 15; + long long loop_size = axis_size * outer_size; +#pragma omp parallel for + for (long long n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const short* din_c = in + n * inner_size; + float* dout_c = out + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + if (cnt > 0) { + int loop = cnt; + const short* din_ptr = din_c; + float* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/ + "0: \n" /* main loop */ + "sshll v4.4s, v0.4h, #0 \n" /* trans to int32*/ + "sshll2 v5.4s, v0.8h, #0 \n" /* trans to int32*/ + "sshll v6.4s, v1.4h, #0 \n" /* trans to int32*/ + "sshll2 v7.4s, v1.8h, #0 \n" /* trans to int32*/ + + "ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/ + + "scvtf v8.4s, v4.4s \n" /* trans to fp32*/ + "scvtf v9.4s, v5.4s \n" /* trans to fp32*/ + "scvtf v10.4s, v6.4s \n" /* trans to fp32*/ + "scvtf v11.4s, v7.4s \n" /* trans to fp32*/ + + "subs %[loop], %[loop], #1 \n" + + "fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/ + "fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/ + + "stp q4, q5, [%[out]], #32 \n" /* write to memory*/ + "stp q6, q7, [%[out]], #32 \n" /* write to memory*/ + + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" + ); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[in]]! @ load 16 int16\n" + "0: @ main loop\n" + "vmovl.s16 q4, d0 @ trans to int32\n" + "vmovl.s16 q5, d1 @ trans to int32\n" + "vmovl.s16 q6, d2 @ trans to int32\n" + "vmovl.s16 q7, d3 @ trans to int32\n" + "vcvt.f32.s32 q0, q4 @ trans to fp32\n" + "vcvt.f32.s32 q1, q5 @ trans to fp32\n" + "vcvt.f32.s32 q2, q6 @ trans to fp32\n" + "vcvt.f32.s32 q3, q7 @ trans to fp32\n" + "vmul.f32 q4, q0, %q[scale] @ mul with scale\n" + "vmul.f32 q5, q1, %q[scale] @ mul with scale\n" + "vmul.f32 q6, q2, %q[scale] @ mul with scale\n" + "vmul.f32 q7, q3, %q[scale] @ mul with scale\n" + + "vld1.32 {d0-d3}, [%[in]]! 
@ load 16 int8\n" + + "subs %[loop], #1 \n" + + "vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n" + "vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n" + + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7" + ); +#endif //__aarch64__ + } + const short* din_r = din_c + 16 * cnt; + float* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = in_scale * din_r[i]; + } + } +} + +void int32_to_fp32(const int* din, float* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + long long loop_size = axis_size * outer_size; +#pragma omp parallel for + for (long long n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const int* din_c = din + n * inner_size; + float* dout_c = dout + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + if (cnt > 0) { + int loop = cnt; + const int* din_ptr = din_c; + float* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "ldp q0, q1, [%[in]], #32 \n" + "ldp q2, q3, [%[in]], #32 \n" + "0: \n" + "scvtf v4.4s, v0.4s \n" + "scvtf v5.4s, v1.4s \n" + "scvtf v6.4s, v2.4s \n" + "scvtf v7.4s, v3.4s \n" + "ldp q0, q1, [%[in]], #32 \n" + "fmul v8.4s, v4.4s, %[scale].4s \n" + "fmul v9.4s, v5.4s, %[scale].4s \n" + "fmul v10.4s, v6.4s, %[scale].4s \n" + "fmul v11.4s, v7.4s, %[scale].4s \n" + "ldp q2, q3, [%[in]], #32 \n" + "stp q8, q9, [%[out]], #32 \n" + "stp q10, q11, [%[out]], #32 \n" + "subs %[loop], %[loop], #1 \n" + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11" + ); +#else + asm volatile( + "vld1.s32 {d0-d3}, [%[in]]! \n" + "vld1.s32 {d4-d7}, [%[in]]! \n" + "0: \n" + "vcvt.f32.s32 q4, q0 \n" + "vcvt.f32.s32 q5, q1 \n" + "vcvt.f32.s32 q6, q2 \n" + "vcvt.f32.s32 q7, q3 \n" + "vld1.s32 {d0-d3}, [%[in]]! \n" + "vmul.f32 q8, q4, %q[scale] \n" + "vmul.f32 q9, q5, %q[scale] \n" + "vmul.f32 q10, q6, %q[scale] \n" + "vmul.f32 q11, q7, %q[scale] \n" + "vld1.s32 {d4-d7}, [%[in]]! \n" + "subs %[loop], #1 \n" + "vst1.f32 {d16-d19}, [%[out]]! \n" + "vst1.f32 {d20-d23}, [%[out]]! 
\n" + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11" + ); +#endif //__aarch64__ + } + const int* din_r = din_c + 16 * cnt; + float* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = in_scale * din_r[i]; + } + } +} + +void int32_to_int8(const int* din, signed char* dout, const float* scale, \ + int axis_size, long long outer_size, long long inner_size) { + int cnt = inner_size / 16; + int remain = inner_size & 15; + long long loop_size = outer_size * axis_size; +#pragma omp parallel for + for (long long n = 0; n < loop_size; ++n) { + float in_scale = scale[n % axis_size]; + const int* din_c = din + n * inner_size; + signed char* dout_c = dout + n * inner_size; + float32x4_t vscale = vdupq_n_f32(in_scale); + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t vpoff = vdupq_n_f32(0.5f); + float32x4_t vnoff = vdupq_n_f32(-0.5f); + if (cnt > 0) { + int loop = cnt; + const int* din_ptr = din_c; + signed char* dout_ptr = dout_c; +#ifdef __aarch64__ + asm volatile( + "0: \n" + "ld1 {v0.4s, v1.4s}, [%[in]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in]], #32 \n" + + "scvtf v4.4s, v0.4s \n" + "scvtf v5.4s, v1.4s \n" + "scvtf v6.4s, v2.4s \n" + "scvtf v7.4s, v3.4s \n" + + "fmul v0.4s, v4.4s, %[scale].4s \n" + "fmul v1.4s, v5.4s, %[scale].4s \n" + "fmul v2.4s, v6.4s, %[scale].4s \n" + "fmul v3.4s, v7.4s, %[scale].4s \n" + + "fcvtas v4.4s, v0.4s \n" + "fcvtas v5.4s, v1.4s \n" + "fcvtas v6.4s, v2.4s \n" + "fcvtas v7.4s, v3.4s \n" + + "sqxtn v0.4h, v4.4s \n" + "sqxtn2 v0.8h, v5.4s \n" + "sqxtn v1.4h, v6.4s \n" + "sqxtn2 v1.8h, v7.4s \n" + + "sqxtn v2.8b, v0.8h \n" + "sqxtn2 v2.16b, v1.8h \n" + + "st1 {v2.16b}, [%[out]], #16 \n" + "subs %[loop], %[loop], #1 \n" + "bne 0b \n" + :[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr) + :[scale] "w" (vscale) + :"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + ); +#else + asm volatile( + "vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "0: @ main loop\n" + "vcvt.f32.s32 q4, q0 @ cvt to float\n" + "vcvt.f32.s32 q5, q1 @ cvt to float\n" + "vcvt.f32.s32 q6, q2 @ cvt to float\n" + "vcvt.f32.s32 q7, q3 @ cvt to float\n" + "vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n" + "vand.i32 q1, q0, q0 @ set offset, 0.5\n" + "vand.i32 q2, q0, q0 @ set offset, 0.5\n" + "vand.i32 q3, q0, q0 @ set offset, 0.5\n" + "vcgt.f32 q8, q4, %q[vzero] @ get mask > 0, in0\n" + "vcgt.f32 q9, q5, %q[vzero] @ get mask > 0, in1\n" + "vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in2\n" + "vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in3\n" + "vbif.f32 q0, %q[vnoff], q8 @ get right offset\n" + "vbif.f32 q1, %q[vnoff], q9 @ get right offset\n" + "vbif.f32 q2, %q[vnoff], q10 @ get right offset\n" + "vbif.f32 q3, %q[vnoff], q11 @ get right offset\n" + "vmla.f32 q0, q4, %q[vscale] @ mul scale\n" + "vmla.f32 q1, q5, %q[vscale] @ mul scale\n" + "vmla.f32 q2, q6, %q[vscale] @ mul scale\n" + "vmla.f32 q3, q7, %q[vscale] @ mul scale\n" + "vcvt.s32.f32 q4, q0 @ cvt to int32\n" + "vcvt.s32.f32 q5, q1 @ cvt to int32\n" + "vcvt.s32.f32 q6, q2 @ cvt to int32\n" + "vcvt.s32.f32 q7, q3 @ cvt to int32\n" + "vqmovn.s32 d16, q4 @ cnt to int16\n" + "vqmovn.s32 d17, q5 @ cnt to int16\n" + "vqmovn.s32 d18, q6 @ cnt to int16\n" + "vqmovn.s32 d19, q7 @ cnt to int16\n" + "vld1.32 {d0-d3}, [%[din]]! 
@ load in0~in7\n" + "vqmovn.s16 d8, q8 @ cnt to int8\n" + "vqmovn.s16 d9, q9 @ cnt to int8\n" + "vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n" + "vst1.32 {d8-d9}, [%[dout]]! @ write to output\n" + "subs %[loop], #1 @ loop count -1\n" + "bne 0b @ to main loop\n" + :[loop] "+r" (loop), [din] "+r" (din_ptr), [dout] "+r" (dout_ptr) + :[vscale] "w" (vscale), [vzero] "w"(vzero), [vnoff] "w" (vnoff), [vpoff] "w" (vpoff) + :"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11" + ); +#endif //__aarch64__ + } + const int* din_r = din_c + 16 * cnt; + int8_t* dout_r = dout_c + 16 * cnt; + for (int i = 0; i < remain; ++i) { + dout_r[i] = saturate_cast(roundf(in_scale * din_r[i])); + } + } +} + +void int32_to_int32(const int* din, int* dout, const float* scale, \ + int axis_size, long long outer_size, long long inner_size) { + int size_all = outer_size * axis_size * inner_size; + memmove(dout, din, size_all*sizeof(int)); +} + +template <> +void int32_to_dtype(const int* din, float* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + return int32_to_fp32(din, dout, scale, axis_size, outer_size, inner_size); +} + +template <> +void int32_to_dtype(const int* din, signed char* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + return int32_to_int8(din, dout, scale, axis_size, outer_size, inner_size); +} + +template <> +void int32_to_dtype(const int* din, int* dout, const float* scale, + int axis_size, long long outer_size, long long inner_size) { + + return int32_to_int32(din, dout, scale, axis_size, outer_size, inner_size); +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h index 5948f6c4a854d9f678c316f351c017788c44c4a2..6b019abc19d4e0e0add32b23d3f39820b8b47588 100644 --- a/paddle/fluid/lite/core/memory.h +++ b/paddle/fluid/lite/core/memory.h @@ -65,6 +65,8 @@ class Buffer { TargetCopy(target_, data_, other.data_, nbytes); } + ~Buffer() { Free(); } + private: // memory it actually malloced. size_t space_{0}; diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc b/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b082060fe21731000394f6941e0803af7da74d6 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_tester.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
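+//
+// The tests below exercise PatternMatcher on a small hand-built SSA graph.
+// BuildGraph wires up the following topology (ops oN, variables vN):
+//
+//   o1 -> v1 -> o2 -> v2 -> o3 -> v4 -> o5
+//                     v2 -> o4
+//               o2 -> v3 -> o5
+//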
+ +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" + +#include + +namespace paddle { +namespace lite { +namespace mir { + +void BuildGraph(SSAGraph* g) { + g->mutable_nodes().emplace_back(); + Node& o1 = g->mutable_nodes().back(); + o1.AsStmt().op_type = "op1"; + g->mutable_nodes().emplace_back(); + Node& o2 = g->mutable_nodes().back(); + o2.AsStmt().op_type = "op2"; + g->mutable_nodes().emplace_back(); + Node& o3 = g->mutable_nodes().back(); + o3.AsStmt().op_type = "op3"; + g->mutable_nodes().emplace_back(); + Node& o4 = g->mutable_nodes().back(); + o4.AsStmt().op_type = "op4"; + g->mutable_nodes().emplace_back(); + Node& o5 = g->mutable_nodes().back(); + o5.AsStmt().op_type = "op5"; + g->mutable_nodes().emplace_back(); + Node& v1 = g->mutable_nodes().back(); + v1.AsArg("var1"); + g->mutable_nodes().emplace_back(); + Node& v2 = g->mutable_nodes().back(); + v2.AsArg("var2"); + g->mutable_nodes().emplace_back(); + Node& v3 = g->mutable_nodes().back(); + v3.AsArg("var3"); + g->mutable_nodes().emplace_back(); + Node& v4 = g->mutable_nodes().back(); + v4.AsArg("var4"); + + // o1->v1->o2 + o1.outlinks.push_back(&v1); + o2.inlinks.push_back(&v1); + v1.inlinks.push_back(&o1); + v1.outlinks.push_back(&o2); + // o2->v2->o3 + // o2->v2->o4 + o2.outlinks.push_back(&v2); + o3.inlinks.push_back(&v2); + o4.inlinks.push_back(&v2); + v2.inlinks.push_back(&o2); + v2.outlinks.push_back(&o3); + v2.outlinks.push_back(&o4); + // o2->v3->o5 + o2.outlinks.push_back(&v3); + o5.inlinks.push_back(&v3); + v3.inlinks.push_back(&o2); + v3.outlinks.push_back(&o5); + // o3-v4->o5 + o3.outlinks.push_back(&v4); + o5.inlinks.push_back(&v4); + v4.inlinks.push_back(&o3); + v4.outlinks.push_back(&o5); +} + +TEST(PMPattern, NewNode) { + PMPattern x; + auto* n = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(n); + ASSERT_EQ(x.nodes_.size(), 1UL); +} + +TEST(PMPattern, AddEdge) { + PMPattern x; + auto* a = x.NewNode([](const Node* x) { return true; }); + auto* b = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(a); + ASSERT_TRUE(b); + x.AddEdge(a, b); + ASSERT_EQ(x.nodes_.size(), 2UL); + ASSERT_EQ(x.edges_.size(), 1UL); + ASSERT_EQ(x.edges_.front().first, a); + ASSERT_EQ(x.edges_.front().second, b); + + ASSERT_EQ(x.nodes().size(), 2UL); + ASSERT_EQ(x.edges().size(), 1UL); + ASSERT_EQ(x.edges().front().first, a); + ASSERT_EQ(x.edges().front().second, b); +} + +TEST(PatternMatcher, MarkPMNodesInGraph) { + PatternMatcher x; + // mark o2, o3, v2 + + // The pattern is a graph: + // o2(a node named o2) -> v2(a node named v2) + // v2 -> o3(a node named o3) + auto* o2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->op_type == "op2"; + }); + auto* o3 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->op_type == "op3"; + }); + auto* v2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. 
+ return node && node->IsArg() && node->arg()->name == "var2"; + }); + + ASSERT_FALSE(o2->Tell(nullptr)); + ASSERT_FALSE(o3->Tell(nullptr)); + ASSERT_FALSE(v2->Tell(nullptr)); + + x.pattern_.AddEdge(o2, v2); + x.pattern_.AddEdge(v2, o3); + + ASSERT_EQ(x.pattern_.edges().size(), 2UL); + ASSERT_EQ(x.pattern_.edges()[0].first, o2); + ASSERT_EQ(x.pattern_.edges()[0].second, v2); + ASSERT_EQ(x.pattern_.edges()[1].first, v2); + ASSERT_EQ(x.pattern_.edges()[1].second, o3); + + SSAGraph graph; + BuildGraph(&graph); + + x.MarkPMNodesInGraph(&graph); + + ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL); + + auto subgraphs = x.DetectPatterns(); + ASSERT_EQ(subgraphs.size(), 1UL); +} + +TEST(PatternMatcher, MultiSubgraph) { + SSAGraph graph; + BuildGraph(&graph); + + PatternMatcher x; + + // The pattern is a graph: + // op -> var + auto* any_op = x.mutable_pattern()->NewNode( + [](const Node* node) { + return node->IsStmt() && (node->stmt()->op_type == "op2" || + node->stmt()->op_type == "op3"); + }, + "OP0"); + auto* any_var = + x.mutable_pattern() + ->NewNode([](const Node* node) { return node->IsArg(); }, "VAR") + ->AsIntermediate(); + auto* any_op1 = x.mutable_pattern()->NewNode( + [](const Node* node) { return node->IsStmt(); }, "OP1"); + + x.mutable_pattern()->AddEdge(any_op, any_var); + x.mutable_pattern()->AddEdge(any_var, any_op1); + + int count = 0; + PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, + SSAGraph* g) { + LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> " + << s.at(any_var)->arg()->name << " -> " + << s.at(any_op1)->stmt()->op_type; + count++; + }; + + x(&graph, handle); + + // 1. Detect op3 -> var4 -> op5 + // 2. Detect op2 -> var2 -> op3 + // 3. Detect op2 -> var2 -> op4 + // 4. Detect op2 -> var3 -> op5 + // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 + ASSERT_GE(count, 1); + ASSERT_LE(count, 2); +} + +TEST(PatternMatcher, IntermediateCheck) { + SSAGraph graph; + BuildGraph(&graph); + + // o2->v2->o3 + // o2->v2->o4 + // check o2+o3 fuse, should fail because v2 also link to o4. 
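+  // An AsIntermediate() pattern node may only feed consumers that are also
+  // part of the pattern; since var2 additionally feeds op4, the first match
+  // attempt yields zero subgraphs, and relaxing v2 to AsInput() lets the
+  // o2 -> var2 -> o3 subgraph match exactly once.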
+ PatternMatcher matcher; + auto* op2 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->op_type == "op2"; + }, + "op2"); + auto* op3 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->op_type == "op3"; + }, + "op3"); + auto* v2 = matcher.mutable_pattern() + ->NewNode( + [](const Node* x) { + return x && x->IsArg() && x->arg()->name == "var2"; + }, + "var2") + ->AsIntermediate(); + v2->LinksFrom({op2}).LinksTo({op3}); + + int count = 0; + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + EXPECT_EQ(count, 0); + + count = 0; + v2->AsInput(); + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + ASSERT_EQ(count, 1); +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index 49332262deb6552c5c9079ff93d691f464ed7028..8f5bd651689112d8ba1da95a6474a00bdd120fef 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -91,9 +91,9 @@ class KernelRegistry final { void Register(const std::string &name, typename KernelRegistryForTarget::creator_t &&creator) { - // VLOG(3) << "register for " << TargetToStr(Target) << ":" - //<< PrecisionToStr(Precision) << "//" - //<< GetKernelOffset(); + VLOG(3) << "register for " << TargetToStr(Target) << ":" + << PrecisionToStr(Precision) << "//" + << GetKernelOffset(); using kernel_registor_t = KernelRegistryForTarget; auto &varient = registries_[GetKernelOffset()]; @@ -153,6 +153,12 @@ class KernelRegistor : public lite::Registor { public: KernelRegistor(const std::string &op_type, const std::string &alias) : Registor([=] { +<<<<<<< HEAD +======= + VLOG(3) << "Register kernel " << op_type << " for " + << TargetToStr(target) << " " << PrecisionToStr(precision) + << " " << DataLayoutToStr(layout) << " alias " << alias; +>>>>>>> gitlab/develop KernelRegistry::Global().Register( op_type, [=]() -> std::unique_ptr { std::unique_ptr x(new KernelType); diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index d6980ff8898374a54393d0b3c2b9af995504e42a..27677e23a27366d052001a6828f12d1cfcc5decb 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -21,6 +21,7 @@ * looks the same. 
*/ +#include #include #include "paddle/fluid/lite/core/target_wrapper.h" diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index ff3cab02ee8b7e88783b8c6c18c496bf674c7cfd..1cf66b0d266b3edf0b0d271ceb5e375f01f652c3 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -9,12 +9,18 @@ cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) +lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) +lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) +lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) set(arm_kernels fc_compute_arm @@ -22,6 +28,11 @@ set(arm_kernels mul_compute_arm scale_compute_arm softmax_compute_arm - elementwise_add_compute_arm) + conv_compute_arm + elementwise_add_compute_arm + pool_compute_arm + split_compute_arm + ) set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") + diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b464a5df0b0c33e76d2a31db183a515fea7a015 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
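+//
+// ConvCompute::PrepareForRun below picks a concrete implementation from the
+// ARM math library based on the conv parameters (roughly summarised):
+//   - DepthwiseConv : groups == ic == oc, square kernel/stride/pad, no
+//                     dilation, 3x3 (pad 0/1, stride 1/2) or 5x5 cases
+//   - WinogradConv  : groups == 1, 3x3, stride 1, ic/oc >= 32, oh/ow > 16
+//   - DirectConv    : groups == 1, 3x3, stride 1 or 2 otherwise
+//   - GemmLikeConv  : fallback for everything else
+//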
+ +#include "paddle/fluid/lite/kernels/arm/conv_compute.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ConvCompute::PrepareForRun() { + auto& param = this->Param(); + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + auto& ctx = this->ctx_->template As(); + + int win = x_dims[3]; // nchw + int hin = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int ow = o_dims[3]; + int oh = o_dims[2]; + int oc = o_dims[1]; + int kh = w_dims[2]; // oihw + int kw = w_dims[3]; + int pad = param.paddings[0]; + int stride = param.strides[0]; + + const auto* i_data = param.x->data(); + const auto* w_data = param.filter->data(); + const auto* b_data = param.bias ? param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(); + + bool kps_equal = (param.paddings[0] == param.paddings[1]) && + (param.strides[0] == param.strides[1]) && (kw == kh); + bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1); + bool flag_dw_3x3 = + (kw == 3 && (pad == 0 || pad == 1) && (stride == 1 || stride == 2)); + bool flag_dw_5x5 = + (kw == 5 && stride == 1) || (kw == 5 && stride == 2 && pad == 2); + bool flag_dw = flag_dw_3x3 || flag_dw_5x5; + + // select conv impl + if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) { + // dw conv impl + impl_ = new lite::arm::math::DepthwiseConv; + VLOG(3) << "invoking dw conv"; + } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && + no_dilation) { + if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) { + // winograd conv impl + impl_ = new lite::arm::math::WinogradConv; + VLOG(3) << "invoking winograd conv"; + } else { + // direct conv impl + impl_ = new lite::arm::math::DirectConv; + VLOG(3) << "invoking direct conv"; + } + } else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal && + no_dilation) { + // direct conv impl + impl_ = new lite::arm::math::DirectConv; + VLOG(3) << "invoking direct conv"; + } else { + impl_ = new lite::arm::math::GemmLikeConv; + VLOG(3) << "invoking gemm like conv"; + } + CHECK(this->impl_->create(param, &ctx)); +} + +void ConvCompute::Run() { + auto& param = this->Param(); + CHECK(impl_); + impl_->run(param); + // if (this->act_ != nullptr) { + // this->act_->run(outputs, outputs, param.activation_param); + // } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ConvCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ConvCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.h b/paddle/fluid/lite/kernels/arm/conv_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..21fabf8c3e8f7983a891265135c39b96aaf42e8d --- /dev/null +++ 
b/paddle/fluid/lite/kernels/arm/conv_compute.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/operators/conv_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ConvCompute : public KernelLite { + public: + using param_t = operators::ConvParam; + + void PrepareForRun() override; + + void Run() override; + + ~ConvCompute() { + if (impl_ != nullptr) { + delete impl_; + } + } + + private: + lite::arm::math::ImplBase* impl_{ + nullptr}; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b95aa5ce4a3fd8bc1aa76c7ae3f66f13f60b4ea --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
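+//
+// conv_compute_ref below is a naive NCHW convolution used as the reference:
+// each output element accumulates input * filter over its receptive field,
+// with optional bias. The test sweeps shapes and sizes the output with
+//   oh = (ih + 2 * pad - (dilation * (ks - 1) + 1)) / stride + 1
+// before comparing ConvCompute results against the reference with EXPECT_NEAR.
+//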
+ +#include "paddle/fluid/lite/kernels/arm/conv_compute.h" +#include +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void conv_compute_ref(const operators::ConvParam& param) { + auto input = param.x; + auto filter = param.filter; + auto output = param.output; + DDim input_dims = param.x->dims(); + DDim filter_dims = param.filter->dims(); + DDim output_dims = param.output->dims(); + std::vector paddings = param.paddings; + std::vector strides = param.strides; + std::vector dilations = param.dilations; + int groups = param.groups; + + auto input_data = param.x->data(); + auto output_data = param.output->mutable_data(); + auto filter_data = param.filter->mutable_data(); + const float* bias_data = nullptr; + if (param.bias != nullptr) { + bias_data = param.bias->mutable_data(); + } + bool flag_bias = bias_data != nullptr; + bool flag_relu = false; // TODO(hong19860320) param.relu + + int num = input_dims[0]; + int chout = output_dims[1]; + int hout = output_dims[2]; + int wout = output_dims[3]; + + int chin = input_dims[1]; + int hin = input_dims[2]; + int win = input_dims[3]; + int out_c_group = chout / groups; + int in_c_group = chin / groups; + + int stride_h = strides[0]; + int stride_w = strides[1]; + int dilation_h = dilations[0]; + int dilation_w = dilations[1]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int kernel_h = filter_dims[2]; + int kernel_w = filter_dims[3]; + + for (int n = 0; n < num; ++n) { + for (int g = 0; g < groups; ++g) { + for (int oc = 0; oc < out_c_group; ++oc) { + for (int oh = 0; oh < hout; ++oh) { + for (int ow = 0; ow < wout; ++ow) { + int out_idx = n * groups * out_c_group * hout * wout + + g * out_c_group * hout * wout + oc * hout * wout + + oh * wout + ow; + output_data[out_idx] = + flag_bias ? static_cast(bias_data[g * out_c_group + oc]) + : 0.f; + for (int ic = 0; ic < in_c_group; ++ic) { + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int iw = ow * stride_w - padding_w + kw * (dilation_w); + int ih = oh * stride_h - padding_h + kh * (dilation_h); + if (iw < 0 || iw >= win) continue; + if (ih < 0 || ih >= hin) continue; + + int iidx = n * chin * hin * win + g * in_c_group * hin * win + + ic * hin * win + ih * win + iw; + int widx = + g * out_c_group * in_c_group * kernel_h * kernel_w + + oc * in_c_group * kernel_h * kernel_w + + ic * kernel_h * kernel_w + kh * kernel_w + kw; + + output_data[out_idx] += + (dtype)input_data[iidx] * (dtype)filter_data[widx]; + } + } + } + if (flag_relu) { + output_data[out_idx] = + output_data[out_idx] > 0.f ? 
output_data[out_idx] : 0.f; + } + } + } + } + } + } +} + +TEST(conv_arm, retrive_op) { + auto conv = KernelRegistry::Global().Create( + "conv2d"); + ASSERT_FALSE(conv.empty()); + ASSERT_TRUE(conv.front()); +} + +TEST(conv_arm, init) { + ConvCompute conv; + ASSERT_EQ(conv.precision(), PRECISION(kFloat)); + ASSERT_EQ(conv.target(), TARGET(kARM)); +} + +TEST(conv_arm, compute) { + DeviceInfo::Init(); + for (auto n : {1, 2}) { + for (auto ic : {6, 32 /*, 128*/}) { + for (auto oc : {6, 32 /*, 128*/}) { + for (auto ih : {9, 18 /*, 56 , 112, 224, 512*/}) { + for (auto iw : {9, 18 /*, 56, 112, 224, 512*/}) { + for (auto flag_bias : {false, true}) { + for (auto flag_relu : {false, true}) { + for (auto depthwise : {false, true}) { + for (auto dilation : {1, 2}) { + for (auto stride : {1, 2}) { + for (auto padding : {0, 1, 2}) { + for (auto ks : {1, 3, 5}) { + int group = 1; + if (depthwise) { // depthwise convolution ? + group = oc = ic; + } + // get input, filter and output shape + std::vector input_shape = {n, ic, ih, iw}; + std::vector filter_shape = {oc, ic / group, + ks, ks}; + const int dks = dilation * (ks - 1) + 1; + int oh = (ih + 2 * padding - dks) / stride + 1; + int ow = (iw + 2 * padding - dks) / stride + 1; + std::vector output_shape({n, oc, oh, ow}); + // resize input, filter and output + Tensor input; + Tensor filter; + Tensor bias; + Tensor output; + Tensor output_ref; + input.Resize(input_shape); + filter.Resize(filter_shape); + output.Resize(output_shape); + output_ref.Resize(output_shape); + VLOG(3) << "input: " << input.dims(); + VLOG(3) << "filter: " << filter.dims() + << " padding:" << padding + << " stride:" << stride + << " dilation:" << dilation; + VLOG(3) << "output: " << output.dims(); + auto* input_data = input.mutable_data(); + auto* filter_data = filter.mutable_data(); + auto* output_data = output.mutable_data(); + for (int i = 0; i < input.dims().production(); i++) { + input_data[i] = static_cast(i % 128); + } + for (int i = 0; i < filter.dims().production(); i++) { + filter_data[i] = + i * 0.001f / + static_cast(filter.dims().production()); + } + // prepare kernel params and run + ConvCompute conv; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + conv.SetContext(std::move(ctx)); + operators::ConvParam param; + param.x = &input; + param.filter = &filter; + param.output = &output; + param.bias = nullptr; + if (flag_bias) { + bias.Resize({oc}); + auto* bias_data = bias.mutable_data(); + for (int i = 0; i < bias.dims().production(); i++) { + bias_data[i] = static_cast(i); + } + param.bias = &bias; + } + // TODO(hong19860320) param.relu = flag_relu; + param.paddings = std::vector({padding, padding}); + param.strides = std::vector({stride, stride}); + param.dilations = + std::vector({dilation, dilation}); + param.groups = group; + conv.SetParam(param); + conv.Launch(); + // invoking ref implementation and compare results + param.output = &output_ref; + conv_compute_ref(param); + auto* output_ref_data = + output_ref.mutable_data(); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], + 1e-3); + } + } + } + } + } + } + } + } + } + } + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/pool_compute.cc b/paddle/fluid/lite/kernels/arm/pool_compute.cc new file mode 100644 index 
0000000000000000000000000000000000000000..6a7716fae6bfc3aa52dad7c8b8192191e986b6f3 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/pool_compute.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/pool_compute.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void PoolCompute::Run() { + auto& param = Param(); + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* din = param.x->data(); + float* dout = param.output->mutable_data(); + + std::vector& ksize = param.ksize; + std::vector& strides = param.strides; + std::vector& paddings = param.paddings; + + std::string& pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + bool ceil_mode = param.ceil_mode; + bool use_quantizer = param.use_quantizer; + std::string& data_format = param.data_format; + + if (param.global_pooling) { + for (size_t i = 0; i < ksize.size(); ++i) { + paddings[i] = 0; + ksize[i] = static_cast(in_dims[i + 2]); + } + } + +#if 0 + for (int i = 0; i < in_dims.size(); ++i) { + LOG(INFO) << "in_dims[" << i << "]:" << in_dims[i]; + } + for (int i = 0; i < out_dims.size(); ++i) { + LOG(INFO) << "out_dims[" << i << "]:" << out_dims[i]; + } + for (int i = 0; i < ksize.size(); ++i) { + LOG(INFO) << "ksize[" << i << "]:" << ksize[i]; + } + for (int i = 0; i < strides.size(); ++i) { + LOG(INFO) << "strides[" << i << "]:" << strides[i]; + } + for (int i = 0; i < paddings.size(); ++i) { + LOG(INFO) << "paddings[" << i << "]:" << paddings[i]; + } + LOG(INFO) << "global_pooling:" << global_pooling; + LOG(INFO) << "exclusive:" << exclusive; + LOG(INFO) << "adaptive:" << adaptive; + LOG(INFO) << "ceil_mode:" << ceil_mode; + LOG(INFO) << "use_quantizer:" << use_quantizer; + LOG(INFO) << "data_format:" << data_format; + LOG(INFO) << "din:" << din; + LOG(INFO) << "dout:" << dout; +#endif + + // global + if (global_pooling == true) { + lite::arm::math::pooling_global( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1]) { + if (pooling_type == "max") { + lite::arm::math::pooling2x2s2_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling2x2s2_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], 
out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 1 && + strides[0] == strides[1] && paddings[0] == 1) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s1p1_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s1p1_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1] && paddings[0] == 0) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s2p0_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s2p0_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 && + strides[0] == strides[1] && paddings[0] == 1) { + if (pooling_type == "max") { + lite::arm::math::pooling3x3s2p1_max( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } else if (pooling_type == "avg") { + lite::arm::math::pooling3x3s2p1_ave( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + } else { + lite::arm::math::pooling_basic( + din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3], + in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings, + global_pooling, exclusive, adaptive, ceil_mode, use_quantizer, + pooling_type); + } + return; +} + +TargetType PoolCompute::target() const { return TARGET(kARM); } + +PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); } + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pool, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::PoolCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/pool_compute.h b/paddle/fluid/lite/kernels/arm/pool_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..76dedbc3132405cd70d74e233619572f97dc07e0 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/pool_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/operators/pool_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class PoolCompute : public KernelLite { + public: + using param_t = operators::PoolParam; + + void Run() override; + + TargetType target() const override; + PrecisionType precision() const override; + + virtual ~PoolCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/pool_compute_test.cc b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..35873a9d2cc3fa922f48cc87e8e2c4191ac8ee60 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/pool_compute_test.cc @@ -0,0 +1,275 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
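+//
+// pool_compute_ref below is a plain sliding-window reference: for every
+// output element it scans the (padding-clipped) window and either keeps the
+// maximum or accumulates an average, where "exclusive" controls whether
+// padded positions count in the average divisor. The test compares it
+// against PoolCompute for the global-pooling configurations it generates.
+//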
+ +#include "paddle/fluid/lite/kernels/arm/pool_compute.h" +#include +#include +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void pool_compute_ref(const operators::PoolParam& param) { + auto& in_dims = param.x->dims(); + auto& out_dims = param.output->dims(); + + const float* src_ptr = param.x->data(); + float* dst_ptr = param.output->mutable_data(); + + std::vector ksize = param.ksize; + std::vector strides = param.strides; + std::vector paddings = param.paddings; + + std::string pooling_type = param.pooling_type; + bool global_pooling = param.global_pooling; + bool exclusive = param.exclusive; + bool adaptive = param.adaptive; + bool ceil_mode = param.ceil_mode; + bool use_quantizer = param.use_quantizer; + std::string data_format = param.data_format; + + int in_n = in_dims[0]; + int in_c = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int size_in_n = in_c * in_h * in_w; + int size_in_c = in_h * in_w; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + int size_out_n = in_c * out_h * out_w; + int size_out_c = out_h * out_w; + + int window_h = ksize[0]; + int window_w = ksize[1]; + int stride_h = strides[0]; + int stride_w = strides[1]; + int pad_h = paddings[0]; + int pad_w = paddings[1]; + + if (global_pooling == true) { + ksize[0] = in_h; + ksize[1] = in_w; + } + +#if 0 + for (int i = 0; i < ksize.size(); ++i) { + LOG(INFO) << "ksize[" << i << "]:" << ksize[i]; + } + for (int i = 0; i < strides.size(); ++i) { + LOG(INFO) << "strides[" << i << "]:" << strides[i]; + } + for (int i = 0; i < paddings.size(); ++i) { + LOG(INFO) << "paddings[" << i << "]:" << paddings[i]; + } + LOG(INFO) << "in nchw:" << in_n << ", " << in_c << ", " << in_h << ", " + << in_w; + LOG(INFO) << "size_in_n:" << size_in_n; + LOG(INFO) << "size_out_c:" << size_out_c; + LOG(INFO) << "out_h:" << out_h; + LOG(INFO) << "out_w:" << out_w; + LOG(INFO) << "size_out_n:" << size_out_n; + LOG(INFO) << "size_out_c:" << size_out_c; + LOG(INFO) << "window_h:" << window_h; + LOG(INFO) << "window_w:" << window_w; + LOG(INFO) << "stride_h:" << stride_h; + LOG(INFO) << "stride_w:" << stride_w; + LOG(INFO) << "pad_h:" << pad_h; + LOG(INFO) << "pad_w:" << pad_w; +#endif + + for (int ind_n = 0; ind_n < in_n; ++ind_n) { + for (int ind_c = 0; ind_c < in_c; ++ind_c) { + for (int ind_h = 0; ind_h < out_h; ++ind_h) { + int sh = ind_h * stride_h; + int eh = sh + window_h; + sh = (sh - pad_h) < 0 ? 0 : sh - pad_h; + eh = (eh - pad_h) > in_h ? in_h : eh - pad_h; + + for (int ind_w = 0; ind_w < out_w; ++ind_w) { + int sw = ind_w * stride_w; + int ew = sw + window_w; + sw = (sw - pad_w) < 0 ? 0 : sw - pad_w; + ew = (ew - pad_w) > in_w ? in_w : ew - pad_w; + + float result = static_cast(0); + + int dst_ind = + ind_n * size_out_n + ind_c * size_out_c + ind_h * out_w + ind_w; + + for (int kh = sh; kh < eh; ++kh) { + for (int kw = sw; kw < ew; ++kw) { + int src_ind = + ind_n * size_in_n + ind_c * size_in_c + kh * in_w + kw; + + if (kh == sh && kw == sw) { + result = src_ptr[src_ind]; + } else { + if (pooling_type == "max") { + result = + result >= src_ptr[src_ind] ? 
result : src_ptr[src_ind]; + } + if (pooling_type == "avg" && exclusive == false) { + // Pooling_average_include_padding + result += src_ptr[src_ind]; + } + if (pooling_type == "avg" && exclusive == true) { + // Pooling_average_include_padding + result += src_ptr[src_ind]; + } + } + } + } + if (pooling_type == "avg" && exclusive == false) { + // Pooling_average_include_padding + // result /= param.window_h * param.window_w; + // LOG(ERROR)<<"cpu"<= in_w + pad_w ? in_w + pad_w : sw + window_w; + bw -= sw; + } + if (eh == in_h) { + bh = sh + window_h >= in_h + pad_h ? in_h + pad_h : sh + window_h; + bh -= sh; + } + result /= bh * bw; + } + if (pooling_type == "avg" && exclusive == true) { + // Pooling_average_exclude_padding + result /= (ew - sw) * (eh - sh); + } + dst_ptr[dst_ind] = result; + } + } + } + } +} + +TEST(pool_arm, init) { + PoolCompute pool; + ASSERT_EQ(pool.precision(), PRECISION(kFloat)); + ASSERT_EQ(pool.target(), TARGET(kARM)); +} + +TEST(pool_arm, compute) { + PoolCompute pool; + operators::PoolParam param; + + lite::Tensor x; + lite::Tensor output; + lite::Tensor output_ref; + + for (auto pooling_type : {"avg", "max"}) { + for (auto global_pooling : {true}) { + for (auto stride : {2}) { + for (auto pad : {0}) { + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 11, 4, 1024}) { + for (auto h : {3, 1, 11, 4, 1}) { + for (auto w : {1, 3, 4, 12, 1}) { + VLOG(3) << "n:" << n << " c:" << c << " h:" << h << " w:" << w + << " stride:" << stride << " pad:" << pad + << " pooling_type:" << pooling_type + << " global_pooling:" << global_pooling; + + // init x, output + x.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, 1, 1}))); + output_ref.Resize(DDim(std::vector({n, c, 1, 1}))); + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); ++i) { + x_data[i] = i; + } + + // fill param + param.x = &x; + param.output = &output; + param.pooling_type = pooling_type; + param.ksize = {h, w}; + param.global_pooling = global_pooling; + param.strides = {stride, stride}; + param.paddings = {pad, pad}; + param.exclusive = true; + param.adaptive = false; + param.ceil_mode = false; + param.use_quantizer = false; + + // compute + pool.SetParam(param); + pool.Run(); + +#if 0 + LOG(INFO) << "n:" << n << " c:" << c << " h:" << h << " w:" << w + << " end"; + std::cout << "n:" << n << " c:" << c << " h:" << h << " w:" << w + << " end" << std::endl; + for (int i = 0; i < param.ksize.size(); ++i) { + std::cout << " ksize[" << i << "]:" << param.ksize[i]; + } + std::cout << "\n"; + for (int i = 0; i < param.strides.size(); ++i) { + std::cout << " strides[" << i << "]:" << param.strides[i]; + } + std::cout << "\n"; + for (int i = 0; i < param.paddings.size(); ++i) { + std::cout << " paddings[" << i << "]:" << param.paddings[i]; + } + std::cout << "\n"; +#endif + + // compute ref + // output_ref.Resize(output.dims()); + param.output = &output_ref; + pool_compute_ref(param); + VLOG(3) << "pool_compute_ref(param) end"; + + // compare + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], + 1); // 1e-5); + } + + VLOG(3) << "compare pass"; + } + } + } + } + } // pad + } // stride + } // global_pooling + } // pooling_type +} + +TEST(pool, retrive_op) { + auto pool = + KernelRegistry::Global().Create("pool"); + ASSERT_FALSE(pool.empty()); + ASSERT_TRUE(pool.front()); +} + +} // namespace arm +} // namespace 
kernels
+} // namespace lite
+} // namespace paddle
+
+USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/kernels/arm/split_compute.cc b/paddle/fluid/lite/kernels/arm/split_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9da69894592e146c9191eb9da38d8d481cf287a7
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/split_compute.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/kernels/arm/split_compute.h"
+#include <vector>
+#include "paddle/fluid/lite/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void SplitCompute::Run() {
+  auto& param = Param<operators::SplitParam>();
+  const float* din = param.x->data<float>();
+  auto* dout = param.output;
+  auto in_dim = param.x->dims();
+  std::vector<int> in_strides(in_dim.size());
+  in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
+  for (int i = in_dim.size() - 2; i >= 0; --i) {
+    in_strides[i] = in_strides[i + 1] * in_dim[i];
+  }
+  lite::arm::math::split(din, dout, param.axis, in_strides);
+}
+
+} // namespace arm
+} // namespace kernels
+} // namespace lite
+} // namespace paddle
+
+REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW,
+                     paddle::lite::kernels::arm::SplitCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/paddle/fluid/lite/kernels/arm/split_compute.h b/paddle/fluid/lite/kernels/arm/split_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..22701ba0fd9a77149939933c2e9fcc0c9295e3a1
--- /dev/null
+++ b/paddle/fluid/lite/kernels/arm/split_compute.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class SplitCompute : public KernelLite { + public: + void Run() override; + + virtual ~SplitCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/split_compute_test.cc b/paddle/fluid/lite/kernels/arm/split_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..808a1e2cdb7724042ffcd1324cf0dc2c5e28f2fc --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/split_compute_test.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/split_compute.h" +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void splite_resize_out(const lite::Tensor* din, + std::vector* dout, int axis, int num, + const std::vector& sections) { + for (auto out : *dout) delete out; + dout->clear(); + auto in_dims = din->dims(); + int outs_number; + if (num > 0) { + outs_number = num; + } else { + outs_number = sections.size(); + } + for (int i = 0; i < outs_number; i++) { + dout->push_back(new lite::Tensor); + } + + std::vector outs_dims; + outs_dims.reserve(outs_number); + + if (num > 0) { + int out_axis_dim = in_dims[axis] / num; + for (int i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = out_axis_dim; + outs_dims.push_back(dim); + } + } else if (sections.size() > 0) { + for (size_t i = 0; i < outs_number; ++i) { + auto dim = in_dims; + dim[axis] = sections[i]; + outs_dims.push_back(dim); + } + } + + for (int j = 0; j < outs_dims.size(); ++j) { + (*dout)[j]->Resize(outs_dims[j]); + } +} + +template +void split_compute_ref(const operators::SplitParam& param) { + const dtype* din = param.x->mutable_data(); + auto& dout = param.output; + auto in_dim = param.x->dims(); + int axis = param.axis; + std::vector in_strides(in_dim.size()); + in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1]; + for (int i = in_dim.size() - 2; i >= 0; --i) { + in_strides[i] = in_strides[i + 1] * in_dim[i]; + } + + int input_offset = 0; + for (auto out : *dout) { + auto out_dim = out->dims(); + std::vector out_strides(out_dim.size()); + out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1]; + for (int i = out_dim.size() - 2; i >= 0; --i) { + out_strides[i] = out_strides[i + 1] * out_dim[i]; + } + + dtype* out_data = out->mutable_data(); + int before = out_strides[0] / out_strides[axis]; + int in_after = in_strides[axis]; + int out_after = out_strides[axis]; + + for (int i = 0; i < before; ++i) { + std::memcpy(out_data + i * out_after, din + input_offset + i * in_after, + sizeof(dtype) * out_after); + } + input_offset += out_strides[axis]; + } +} + 
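+// Worked example of the stride bookkeeping used by split_compute_ref above
+// (assuming contiguous row-major NCHW storage; the shapes here are purely
+// illustrative). Splitting x with dims {2, 4, 3} along axis 1 into two
+// {2, 2, 3} outputs gives:
+//   in_strides  = {24, 12, 3}  (product of dims from axis i to the last axis)
+//   out_strides = {12,  6, 3}
+//   before = out_strides[0] / out_strides[axis] = 2
+//   in_after = in_strides[axis] = 12, out_after = out_strides[axis] = 6
+// Each output therefore copies `before` runs of `out_after` elements, stepping
+// through the input by `in_after`, while `input_offset` advances by
+// out_strides[axis] between outputs: the first output receives elements 0-5
+// and 12-17, the second receives 6-11 and 18-23.
+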
+TEST(split_arm, init) { + SplitCompute split; + ASSERT_EQ(split.precision(), PRECISION(kFloat)); + ASSERT_EQ(split.target(), TARGET(kARM)); +} + +TEST(split_arm, compute) { + SplitCompute split; + operators::SplitParam param; + + lite::Tensor x; + std::vector output; + std::vector output_ref; + + for (auto n : {1, 3, 4}) { + for (auto c : {1, 3, 4}) { + for (auto h : {1, 3, 4}) { + for (auto w : {1, 3, 4}) { + for (auto axis : {0, 1, 2, 3}) { + for (auto num : {0, 1, 2, 3}) { + for (auto sections : + {std::vector{1, 1, 1}, std::vector{2, 2}, + std::vector{1, 2}}) { + auto x_dim = DDim(std::vector({n, c, h, w})); + x.Resize(x_dim); + if ((num != 0 && x_dim[axis] % num != 0) || + (num == 0 && x_dim[axis] % sections.size() != 0)) + continue; + auto* x_data = x.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + splite_resize_out(&x, &output, axis, num, sections); + splite_resize_out(&x, &output_ref, axis, num, sections); + param.x = &x; + param.axis = axis; + param.num = num; + param.sections = §ions; + param.output = &output; + split.SetParam(param); + split.Run(); + param.output = &output_ref; + split_compute_ref(param); + for (int i = 0; i < output.size(); i++) { + float* output_data = output[i]->mutable_data(); + float* output_ref_data = output_ref[i]->mutable_data(); + for (int j = 0; j < output[i]->dims().production(); j++) { + EXPECT_NEAR(output_data[j], output_ref_data[j], 1e-5); + } + } + } + } + } + } + } + } + } +} + +TEST(split, retrive_op) { + auto split = + KernelRegistry::Global().Create("split"); + ASSERT_FALSE(split.empty()); + ASSERT_TRUE(split.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(split, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/use_kernels.h b/paddle/fluid/lite/kernels/arm/use_kernels.h index d856950f3a177d08cdc950c259abf3d1a194ee25..1f93a81aa94f09f8330aa385840adec559d7161d 100644 --- a/paddle/fluid/lite/kernels/arm/use_kernels.h +++ b/paddle/fluid/lite/kernels/arm/use_kernels.h @@ -19,5 +19,6 @@ USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index ed26f5fdb1f8cec9780c686cd2b73a6699170120..9a90666420e94bdb585feeac689d9227fc6a2104 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -1,5 +1,7 @@ set(op_DEPS ${tensor_lite} op_lite op_params_lite) +cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) +cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS}) cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS}) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) @@ -17,10 +19,11 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) -cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) -cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) +cc_library(split_op_lite SRCS split_op.cc 
DEPS ${op_DEPS}) set(ops_lite + conv_op_lite + pool_op_lite fc_op_lite relu_op_lite mul_op_lite @@ -36,14 +39,16 @@ set(ops_lite activation_ops_lite dropout_op_lite concat_op_lite - conv_op_lite - pool_op_lite + split_op_lite PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite X86_DEPS fc_compute_x86 - ARM_DEPS fc_compute_arm) + ARM_DEPS fc_compute_arm) +lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc + DEPS pool_op_lite memory_lite + ARM_DEPS pool_compute_arm) lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/conv_op.cc b/paddle/fluid/lite/operators/conv_op.cc index 63838efd6fe57150dd09ca8d2608ec81f056e3dc..948e2a0641c28ed03dc5dc6cb30d60c80cec129c 100644 --- a/paddle/fluid/lite/operators/conv_op.cc +++ b/paddle/fluid/lite/operators/conv_op.cc @@ -24,31 +24,49 @@ bool ConvOpLite::CheckShape() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.filter); - return true; -} + // bias is optional. -bool ConvOpLite::InferShape() const { - auto in_dims = param_.x->dims(); - auto filter_dims = param_.filter->dims(); - std::vector strides = param_.strides; - std::vector paddings = param_.paddings; - int groups = param_.groups; - std::vector dilations = param_.dilations; + const auto in_dims = param_.x->dims(); + const auto filter_dims = param_.filter->dims(); CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5); + CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); - CHECK_OR_FALSE(in_dims.size() - strides.size() == 2U); - CHECK_EQ_OR_FALSE(paddings.size(), strides.size()); - CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * groups); - CHECK_EQ_OR_FALSE(filter_dims[0] % groups, 0); + CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U); + CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size()); + + CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups); + CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0); + CHECK_EQ_OR_FALSE(filter_dims.size(), 4UL); + + return true; +} + +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + CHECK_GT_OR_FALSE(output_size, 0); + + return output_size; +} + +bool ConvOpLite::InferShape() const { + const auto in_dims = param_.x->dims(); + const auto filter_dims = param_.filter->dims(); std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], - dilations[i], paddings[i], - strides[i])); + for (size_t i = 0; i < param_.strides.size(); ++i) { + output_shape.push_back( + ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], param_.dilations[i], + param_.paddings[i], param_.strides[i])); } + + // Set output dims param_.output->Resize(lite::DDim(output_shape)); + + // share LoD + // param_.output->set_lod(param_.x->lod()); return true; } diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h index 3ab30eb787bd9574a10cc9198f4c08b744eb0c27..393b5dc2a8e5e9aa8d94784bc4f5a8d041414200 100644 --- a/paddle/fluid/lite/operators/conv_op.h +++ b/paddle/fluid/lite/operators/conv_op.h @@ -26,63 
+26,53 @@ namespace paddle { namespace lite { namespace operators { -inline int ConvOutputSize(int input_size, int filter_size, int dilation, - int padding, int stride) { - const int dkernel = dilation * (filter_size - 1) + 1; - int output_size = (input_size + 2 * padding - dkernel) / stride + 1; - CHECK_OR_FALSE(output_size > 0); - - return output_size; -} - -inline bool IsExpand(const std::vector& filter_dim, - const std::vector& strides, - const std::vector& paddings, - const std::vector& dilations) { - bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; - for (size_t j = 0; j < strides.size(); ++j) { - filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); - strides_1 = strides_1 && (strides[j] == 1); - padding_0 = padding_0 && (paddings[j] == 0); - dilation_1 = dilation_1 && (dilations[j] == 1); - } - return !(filter_1 && strides_1 && padding_0 && dilation_1); -} - class ConvOpLite : public OpLite { public: ConvOpLite() {} - explicit ConvOpLite(const std::string& type) : OpLite(type) {} + explicit ConvOpLite(const std::string &type) : OpLite(type) {} bool CheckShape() const override; bool InferShape() const override; - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { - auto X = op_desc.Input("Input").front(); - auto Filter = op_desc.Input("Filter").front(); - auto Bias = op_desc.Input("Bias").front(); - // auto ResidualData = op_desc.Input("ResidualData"); - auto Out = op_desc.Output("Output").front(); - - param_.x = scope->FindVar(X)->GetMutable(); - param_.filter = scope->FindVar(Filter)->GetMutable(); - param_.bias = scope->FindVar(Bias)->GetMutable(); - // param_.residualData = - // scope->FindVar(ResidualData)->GetMutable(); - param_.output = scope->FindVar(Out)->GetMutable(); - + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto input = op_desc.Input("Input").front(); + auto filter = op_desc.Input("Filter").front(); + auto out = op_desc.Output("Out").front(); + param_.x = scope->FindVar(input)->GetMutable(); + param_.filter = scope->FindVar(filter)->GetMutable(); + CHECK(scope->FindVar(out)); + param_.output = scope->FindVar(out)->GetMutable(); param_.strides = op_desc.GetAttr>("strides"); param_.paddings = op_desc.GetAttr>("paddings"); param_.groups = op_desc.GetAttr("groups"); param_.dilations = op_desc.GetAttr>("dilations"); - + // optional params + std::vector input_arg_names = op_desc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != + input_arg_names.end()) { + auto bias_var = scope->FindVar(op_desc.Input("Bias").front()); + if (bias_var != nullptr) { + param_.bias = + const_cast(&(bias_var->Get())); + } + } + if (std::find(input_arg_names.begin(), input_arg_names.end(), + "ResidualData") != input_arg_names.end()) { + auto residual_data_var = + scope->FindVar(op_desc.Input("ResidualData").front()); + if (residual_data_var != nullptr) { + param_.residualData = const_cast( + &(residual_data_var->Get())); + } + } return true; } + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "conv2d"; } private: diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 23b21cb276442d4e1da8b83557007a132c9de3fb..eee0d90dba2f3aad86a94983e0ac8fd67127b420 100644 --- 
a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -124,8 +124,8 @@ struct ConcatParam { struct ConvParam { lite::Tensor* x{}; lite::Tensor* filter{}; - lite::Tensor* bias{}; - lite::Tensor* residualData{}; + lite::Tensor* bias{nullptr}; + lite::Tensor* residualData{nullptr}; lite::Tensor* output{}; std::vector strides{1, 1}; std::vector paddings{0, 0}; @@ -174,6 +174,15 @@ struct DropoutParam { std::string dropout_implementation{"downgrade_in_infer"}; }; +// For Split op +struct SplitParam { + lite::Tensor* x{}; + std::vector* output{}; + int axis{-1}; + int num{0}; + std::vector* sections; +}; + /// ----------------------- element wise operators ---------------------- struct ElementwiseParam { const lite::Tensor* X{}; diff --git a/paddle/fluid/lite/operators/pool_op.cc b/paddle/fluid/lite/operators/pool_op.cc index 055f00f90a47766d5a76bcf01cae3f68e14d71e2..3faf2bf0fa4f3290921a6b40739d39a2f10b9c41 100644 --- a/paddle/fluid/lite/operators/pool_op.cc +++ b/paddle/fluid/lite/operators/pool_op.cc @@ -19,6 +19,27 @@ namespace paddle { namespace lite { namespace operators { +bool PoolOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + + const auto& x_dims = param_.x->dims(); + const auto& ksize = param_.ksize; + const auto& strides = param_.strides; + const auto& paddings = param_.paddings; + + // "Pooling intput should be 4-D or 5-D tensor." + CHECK_OR_FALSE(x_dims.size() == 4 || x_dims.size() == 5); + // Input size and pooling size should be consistent. + CHECK_OR_FALSE(x_dims.size() - ksize.size() == 2U); + // Strides size and pooling size should be the same. + CHECK_OR_FALSE(ksize.size() == strides.size()); + // Paddings size and pooling size should be the same. + CHECK_OR_FALSE(ksize.size() == paddings.size()); + + return true; +} + int PoolOutputSize(int input_size, int filter_size, int padding, int stride, bool ceil_mode) { int output_size; @@ -28,46 +49,35 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride, output_size = (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; } - CHECK_OR_FALSE(output_size > 0); return output_size; } -bool PoolOpLite::CheckShape() const { - CHECK_OR_FALSE(param_.x); - CHECK_OR_FALSE(param_.output); - return true; -} - bool PoolOpLite::InferShape() const { - const auto input_dims = param_.x->dims(); - CHECK_OR_FALSE(input_dims.size() == 4 || input_dims.size() == 5); - + const auto x_dims = param_.x->dims(); + std::vector& ksize = param_.ksize; if (param_.global_pooling) { - param_.ksize.resize(static_cast(input_dims.size()) - 2); - for (size_t i = 0; i < param_.ksize.size(); ++i) { + ksize.resize(static_cast(x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) { param_.paddings[i] = 0; - param_.ksize[i] = static_cast(input_dims[i + 2]); + ksize[i] = static_cast(x_dims[i + 2]); } } - CHECK_OR_FALSE(input_dims.size() - param_.ksize.size() == 2U); - CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.strides.size()); - CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.paddings.size()); - - std::vector output_shape({input_dims[0], input_dims[1]}); + std::vector output_shape({x_dims[0], x_dims[1]}); if (param_.adaptive) { output_shape.insert(output_shape.end(), param_.ksize.begin(), param_.ksize.end()); } else { for (size_t i = 0; i < param_.ksize.size(); ++i) { output_shape.push_back( - PoolOutputSize(input_dims[i + 2], param_.ksize[i], param_.paddings[i], + PoolOutputSize(x_dims[i + 2], param_.ksize[i], param_.paddings[i], 
param_.strides[i], param_.ceil_mode)); } } - // share LoD - // param_.output->set_lod(param_.input->lod()); param_.output->Resize(lite::DDim(output_shape)); + + // ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + // ctx->ShareLoD("X", "Out"); return true; } diff --git a/paddle/fluid/lite/operators/pool_op.h b/paddle/fluid/lite/operators/pool_op.h index 64c15ccf1db813c2a4d0465b86ed3c6d46091f73..2e9a02eec189599ba2fc23da8e7bcc9ebd0ea8b3 100644 --- a/paddle/fluid/lite/operators/pool_op.h +++ b/paddle/fluid/lite/operators/pool_op.h @@ -13,8 +13,10 @@ // limitations under the License. #pragma once + #include #include +#include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/scope.h" @@ -35,24 +37,32 @@ class PoolOpLite : public OpLite { bool InferShape() const override; - void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { - auto input = op_desc.Input("X").front(); + auto x = op_desc.Input("X").front(); auto out = op_desc.Output("Out").front(); - param_.x = scope->FindVar(input)->GetMutable(); - param_.output = scope->FindVar(out)->GetMutable(); + CHECK(scope->FindVar(x)); + CHECK(scope->FindVar(out)); + param_.x = scope->FindVar(x)->GetMutable(); + param_.output = scope->FindVar(out)->GetMutable(); + param_.pooling_type = op_desc.GetAttr("pooling_type"); param_.ksize = op_desc.GetAttr>("ksize"); + param_.global_pooling = op_desc.GetAttr("global_pooling"); param_.strides = op_desc.GetAttr>("strides"); param_.paddings = op_desc.GetAttr>("paddings"); - param_.ceil_mode = op_desc.GetAttr("ceil_mode"); + + param_.exclusive = op_desc.GetAttr("exclusive"); param_.adaptive = op_desc.GetAttr("adaptive"); - param_.global_pooling = op_desc.GetAttr("global_pooling"); + param_.ceil_mode = op_desc.GetAttr("ceil_mode"); + param_.use_quantizer = op_desc.GetAttr("use_quantizer"); + // param_.data_format = op_desc.GetAttr("data_format"); return true; } + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "pool"; } private: diff --git a/paddle/fluid/lite/operators/pool_op_test.cc b/paddle/fluid/lite/operators/pool_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf46a2ecbd8a465fa5a52bc099389ff3838a5840 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op_test.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/operators/pool_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(pool_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x->Resize(DDim(std::vector({1, 3, 224, 224}))); + output->Resize(DDim(std::vector{1, 3, 112, 112})); + + // set data + for (int i = 0; i < 1 * 3 * 224 * 224; i++) { + x->mutable_data()[i] = i; + } + for (int i = 0; i < 1 * 3 * 112 * 112; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("pool"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + + std::string pooling_type("max"); + desc.SetAttr("pooling_type", pooling_type); + // desc.SetAttr("ksize", static_cast>({2, 2})); + std::vector ksize{2, 2}; + desc.SetAttr("ksize", ksize); + + bool global_pooling{false}; + desc.SetAttr("global_pooling", global_pooling); + + std::vector strides{1, 1}; + desc.SetAttr("strides", strides); + + std::vector paddings{0, 0}; + desc.SetAttr("paddings", paddings); + + bool exclusive{true}; + desc.SetAttr("exclusive", exclusive); + + bool adaptive{false}; + desc.SetAttr("adaptive", adaptive); + + bool ceil_mode{false}; + desc.SetAttr("ceil_mode", ceil_mode); + + bool use_quantizer{false}; + desc.SetAttr("use_quantizer", use_quantizer); + + PoolOpLite pool("pool"); + pool.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}}); + pool.Attach(desc, &scope); + auto kernels = pool.CreateKernels({Place{TARGET(kARM), PRECISION(kFloat)}}); + LOG(INFO) << "kernels.size(): " << kernels.size(); + ASSERT_FALSE(kernels.empty()); +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def); +#endif diff --git a/paddle/fluid/lite/operators/split_op.cc b/paddle/fluid/lite/operators/split_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c788e9cf9546a8c058398d71fde7aa4295fe8fbc --- /dev/null +++ b/paddle/fluid/lite/operators/split_op.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/lite/operators/split_op.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SplitOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.output);
+  auto x_dims = param_.x->dims();
+  auto x_rank = x_dims.size();
+  CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
+                 param_.axis < static_cast<int>(x_rank));
+  return true;
+}
+
+bool SplitOp::InferShape() const {
+  // SplitParam holds the output tensor list and sections by pointer
+  // (see op_params.h), so dereference them here.
+  const auto &outs = *param_.output;
+  auto in_dims = param_.x->dims();
+  int axis = param_.axis;
+  int num = param_.num;
+  const auto &sections = *param_.sections;
+
+  const int outs_number = outs.size();
+  std::vector<DDim> outs_dims;
+  outs_dims.reserve(outs_number);
+
+  if (num > 0) {
+    int out_axis_dim = in_dims[axis] / num;
+    for (int i = 0; i < outs_number; ++i) {
+      auto dim = in_dims;
+      dim[axis] = out_axis_dim;
+      outs_dims.push_back(dim);
+    }
+  } else if (sections.size() > 0) {
+    for (int i = 0; i < outs_number; ++i) {
+      auto dim = in_dims;
+      dim[axis] = sections[i];
+      outs_dims.push_back(dim);
+    }
+  }
+
+  for (int j = 0; j < outs_dims.size(); ++j) {
+    outs[j]->Resize(outs_dims[j]);
+  }
+
+  return true;
+}
+
+bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  param_.axis = opdesc.GetAttr<int>("axis");
+  param_.num = opdesc.GetAttr<int>("num");
+  // sections and output are pointers in SplitParam; the storage they refer to
+  // must be provided by the caller before attaching.
+  CHECK(param_.sections);
+  CHECK(param_.output);
+  *param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
+  param_.x = const_cast<lite::Tensor *>(
+      &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
+  auto outs = opdesc.Output("Out");
+  for (auto var : outs) {
+    param_.output->push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
+  }
+  return true;
+}
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
+
+REGISTER_LITE_OP(split, paddle::lite::operators::SplitOp);
diff --git a/paddle/fluid/lite/operators/split_op.h b/paddle/fluid/lite/operators/split_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..177c44171e6e67214f820f04e801be6c01df01cc
--- /dev/null
+++ b/paddle/fluid/lite/operators/split_op.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SplitOp : public OpLite {
+ public:
+  SplitOp() {}
+  explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "split"; }
+
+ private:
+  mutable SplitParam param_;
+};
+
+} // namespace operators
+} // namespace lite
+} // namespace paddle
diff --git a/paddle/fluid/lite/utils/any.h b/paddle/fluid/lite/utils/any.h
index 466deae3de92ad5992695a505108e1e31b68a826..2a8c68063f0b17beb72b597d236f71e1a5c2bb79 100644
--- a/paddle/fluid/lite/utils/any.h
+++ b/paddle/fluid/lite/utils/any.h
@@ -34,7 +34,6 @@ class Any {
       CHECK(type_ == typeid(T).hash_code());
     } else {
       type_ = typeid(T).hash_code();
-      data_ = new T;
       deleter_ = [&] { delete static_cast<T*>(data_); };
     }
     data_ = new T;
@@ -55,10 +54,16 @@ class Any {
 
   bool valid() const { return data_; }
 
+  // ~Any() {
+  //   if (valid()) {
+  //     deleter_();
+  //   }
+  // }
+
  private:
   static size_t kInvalidType;
   size_t type_{kInvalidType};
-  void* data_{};
+  void* data_{nullptr};
   std::function<void()> deleter_;
 };
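The output shapes produced by the new conv and pool operators follow the ConvOutputSize and PoolOutputSize formulas defined above in conv_op.cc and pool_op.cc. A minimal standalone sketch of that arithmetic is given below; the shapes and values are illustrative only and are not taken from the tests in this patch.

// Standalone sketch of the output-size formulas from conv_op.cc / pool_op.cc.
#include <cassert>
#include <iostream>

int ConvOutputSize(int input, int filter, int dilation, int padding, int stride) {
  const int dkernel = dilation * (filter - 1) + 1;
  return (input + 2 * padding - dkernel) / stride + 1;
}

int PoolOutputSize(int input, int filter, int padding, int stride, bool ceil_mode) {
  if (!ceil_mode) return (input - filter + 2 * padding) / stride + 1;
  return (input - filter + 2 * padding + stride - 1) / stride + 1;
}

int main() {
  // 224x224 input, 3x3 conv, pad 1, stride 2, dilation 1 -> 112x112.
  assert(ConvOutputSize(224, 3, 1, 1, 2) == 112);
  // 224x224 input, 2x2 pool, pad 0, stride 2, floor rounding -> 112x112.
  assert(PoolOutputSize(224, 2, 0, 2, false) == 112);
  // ceil_mode keeps the trailing partial window: input 6, 3x3 window, stride 2
  // gives 3 outputs instead of 2.
  std::cout << PoolOutputSize(6, 3, 0, 2, false) << " "   // 2
            << PoolOutputSize(6, 3, 0, 2, true) << "\n";  // 3
  return 0;
}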