提交 8642fb8a 编写于 作者: T tensor-tang

Merge remote-tracking branch 'gitlab/develop' into incubate/lite

...@@ -10,7 +10,10 @@ paddle/fluid/operators/distributed/send_recv.proto ...@@ -10,7 +10,10 @@ paddle/fluid/operators/distributed/send_recv.proto
*.vs *.vs
build/ build/
build_doc/ build_doc/
build.*
*.user *.user
*.sh
*.bkp
.vscode .vscode
.idea .idea
......
...@@ -43,7 +43,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) ...@@ -43,7 +43,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
if(NOT DEFINED TARGET_ARCH_ABI) if(NOT DEFINED TARGET_ARCH_ABI)
set(ARCH_ABI "arm64-v8a" CACHE STRING "Choose android platform") set(ARCH_ABI "arm64-v8a" CACHE STRING "Choose android platform")
endif() endif()
include(cross_compiling/host) include(cross_compiling/host)
include(cross_compiling/armlinux) include(cross_compiling/armlinux)
include(cross_compiling/android) include(cross_compiling/android)
......
...@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
......
...@@ -32,9 +32,9 @@ void Run(const char* model_dir) { ...@@ -32,9 +32,9 @@ void Run(const char* model_dir) {
valid_places); valid_places);
auto* input_tensor = predictor.GetInput(0); auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({100, 100}))); input_tensor->Resize(DDim(std::vector<DDim::value_type>({3, 224, 224})));
auto* data = input_tensor->mutable_data<float>(); auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) { for (int i = 0; i < 3 * 224 * 224; i++) {
data[i] = i; data[i] = i;
} }
...@@ -65,6 +65,14 @@ USE_LITE_OP(feed); ...@@ -65,6 +65,14 @@ USE_LITE_OP(feed);
USE_LITE_OP(fetch); USE_LITE_OP(fetch);
USE_LITE_OP(io_copy); USE_LITE_OP(io_copy);
USE_LITE_OP(con2d);
// USE_LITE_OP(batch_norm);
USE_LITE_OP(relu);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(pool2d);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(softmax);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
...@@ -72,7 +80,15 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); ...@@ -72,7 +80,15 @@ USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(con2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_con2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); // USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); // USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
#endif // LITE_WITH_ARM #endif // LITE_WITH_ARM
......
...@@ -72,8 +72,9 @@ class LightPredictor { ...@@ -72,8 +72,9 @@ class LightPredictor {
// Create the kernels of the target places, and filter out the specific // Create the kernels of the target places, and filter out the specific
// kernel with the target alias. // kernel with the target alias.
for (auto& op : program.ops()) { for (auto& op : program.ops_) {
auto kernel_type = op->op_info()->GetAttr<std::string>(kKernelTypeAttr); lite::pb::OpDesc desc(op->op_info()->desc());
auto kernel_type = desc.GetAttr(kKernelTypeAttr).get<std::string>();
std::string op_type, alias; std::string op_type, alias;
Place place; Place place;
KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place);
...@@ -88,8 +89,8 @@ class LightPredictor { ...@@ -88,8 +89,8 @@ class LightPredictor {
insts.emplace_back(op, std::move(*it)); insts.emplace_back(op, std::move(*it));
} }
program_.reset(new RuntimeProgram(std::move(insts))); program_.reset(new RuntimeProgram(std::move(insts)));
CHECK(program.exec_scope()); CHECK(program.exec_scope_);
program_->set_exec_scope(program.exec_scope()); program_->set_exec_scope(program.exec_scope_);
} }
private: private:
......
...@@ -6,4 +6,31 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) ...@@ -6,4 +6,31 @@ if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM))
return() return()
endif() endif()
cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc elementwise.cc DEPS ${lite_kernel_deps} eigen3) # TODO(xxx): seperate them
cc_library(math_arm SRCS
funcs.cc
packed_sgemm.cc
softmax.cc
scale.cc
pooling.cc
elementwise.cc
sgemv.cc
type_trans.cpp
conv_impl.cc
conv_direct_3x3s1.cc
conv_direct_3x3s2.cc
conv_direct.cc
conv_depthwise_3x3_int7.cc
conv_depthwise_3x3_int8.cc
conv_depthwise_5x5s1_int8.cc
conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc
conv_depthwise_5x5s1.cc
conv_depthwise_5x5s2.cc
conv_depthwise.cc
conv_gemmlike.cc
conv_winograd_3x3.cc
conv_winograd.cc
split.cc
DEPS ${lite_kernel_deps} eigen3)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/pooling.h"
#include <algorithm>
#include <limits>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void pooling_basic(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
// no need to pad input tensor, border is zero pad inside this function
int kernel_h = ksize[0];
int kernel_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int size_channel_in = win * hin;
int size_channel_out = wout * hout;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
if (global_pooling) {
if (pooling_type == "max") { // Pooling_max
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; ++c) {
const float* data_in_channel =
data_in_batch + c * size_channel_in; // in address
data_out_batch[c] = data_in_channel[0];
for (int i = 0; i < size_channel_in; ++i) {
data_out_batch[c] = data_out_batch[c] > data_in_channel[i]
? data_out_batch[c]
: data_in_channel[i];
}
}
}
} else if (pooling_type == "avg") {
// Pooling_average_include_padding
// Pooling_average_exclude_padding
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; ++c) {
const float* data_in_channel =
data_in_batch + c * size_channel_in; // in address
float sum = 0.f;
for (int i = 0; i < size_channel_in; ++i) {
sum += data_in_channel[i];
}
data_out_batch[c] = sum / size_channel_in;
}
}
} else {
LOG(FATAL) << "not support";
}
return;
}
if (pooling_type == "max") {
// Pooling_max
for (int n = 0; n < num; ++n) {
float* data_out_channel = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int q = 0; q < chout; q++) {
float* data_out_row = data_out_channel + q * size_channel_out;
const float* data_in_channel = data_in_batch + q * size_channel_in;
for (int i = 0; i < hout; i++) {
for (int j = 0; j < wout; j++) {
int hstart = i * stride_h - pad_h;
int wstart = j * stride_w - pad_w;
int hend = std::min(hstart + kernel_h, hin + pad_h);
int wend = std::min(wstart + kernel_w, win + pad_w);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, hin);
wend = std::min(wend, win);
data_out_row[j] = data_in_channel[hstart * win + wstart];
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
data_out_row[j] = data_out_row[j] > data_in_channel[h * win + w]
? data_out_row[j]
: data_in_channel[h * win + w];
}
}
}
data_out_row += wout;
}
}
}
} else if (pooling_type == "avg") {
if (exclusive == false) {
// Pooling_average_include_padding
for (int n = 0; n < num; ++n) {
int pool_size =
kernel_w *
kernel_h; // (hend - hstart) * (wend - wstart); // problem
float* data_out_channel = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int q = 0; q < chout; q++) {
float* data_out_row = data_out_channel + q * size_channel_out;
const float* data_in_channel = data_in_batch + q * size_channel_in;
for (int i = 0; i < hout; i++) {
for (int j = 0; j < wout; j++) {
int hstart = i * stride_h - pad_h;
int wstart = j * stride_w - pad_w;
int hend = std::min(hstart + kernel_h, hin + pad_h);
int wend = std::min(wstart + kernel_w, win + pad_w);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, hin);
wend = std::min(wend, win);
int bh = kernel_h;
int bw = kernel_w;
if (wend == win) {
bw = wstart + kernel_w >= win + pad_w ? win + pad_w
: wstart + kernel_w;
bw -= wstart;
}
if (hend == hin) {
bh = hstart + kernel_h >= hin + pad_h ? hin + pad_h
: hstart + kernel_h;
bh -= hstart;
}
pool_size = bh * bw;
data_out_row[j] = data_in_channel[hstart * win + wstart];
float sum = 0.f;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += data_in_channel[h * win + w];
}
}
data_out_row[j] = sum / pool_size;
}
data_out_row += wout;
}
}
}
} else { // exclusive == true, Pooling_average_exclude_padding
for (int n = 0; n < num; ++n) {
float* data_out_channel = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int q = 0; q < chout; q++) {
float* data_out_row = data_out_channel + q * size_channel_out;
const float* data_in_channel = data_in_batch + q * size_channel_in;
for (int i = 0; i < hout; i++) {
for (int j = 0; j < wout; j++) {
int hstart = i * stride_h - pad_h;
int wstart = j * stride_w - pad_w;
int hend = std::min(hstart + kernel_h, hin + pad_h);
int wend = std::min(wstart + kernel_w, win + pad_w);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, hin);
wend = std::min(wend, win);
data_out_row[j] = data_in_channel[hstart * win + wstart];
float sum = 0.f;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += data_in_channel[h * win + w];
}
}
int pool_size = (hend - hstart) * (wend - wstart);
data_out_row[j] = sum / pool_size;
}
data_out_row += wout;
}
}
}
}
} else {
LOG(FATAL) << "not support";
}
}
void pooling_global(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int cnt = size_channel_in / 8;
#if 0
LOG(INFO) << "size_channel_in:" << size_channel_in;
LOG(INFO) << "cnt:" << cnt;
LOG(INFO) << "num:" << num;
LOG(INFO) << "chout:" << chout;
LOG(INFO) << "hout:" << hout;
LOG(INFO) << "wout:" << wout;
LOG(INFO) << "chin:" << chin;
LOG(INFO) << "hin:" << hin;
LOG(INFO) << "win:" << win;
LOG(INFO) << "pooling_type " << pooling_type;
#endif
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout;
const float* data_in_batch = data_in + n * chin * size_channel_in;
if (pooling_type == "max") {
#pragma omp parallel for
for (int c = 0; c < chout; ++c) {
const float* data_in_channel = data_in_batch + c * size_channel_in;
int i = 0;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmax = vdupq_n_f32(minval);
#ifdef __aarch64__
for (; i < cnt; i++) {
float32x4_t vdin1 = vld1q_f32(data_in_channel);
vmax = vmaxq_f32(vdin1, vmax);
float32x4_t vdin2 = vld1q_f32(data_in_channel + 4);
vmax = vmaxq_f32(vmax, vdin2);
data_in_channel += 8;
}
#else
int num = cnt;
if (num > 0) {
asm volatile(
"max_loop: @main loop\n"
"vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, "
"data_in_channel\n"
"vmax.f32 %q[vmax], %q[vmax], q0 @max vmax, "
"vmax, data_in_channel\n"
"vld1.f32 {d2-d3}, [%[data_in_channel]]! @ load 2nd 4 "
"data"
"vmax.f32 %q[vmax], %q[vmax], q1 @ compare 2nd "
"4 datas\n"
"subs %[num], #1 @subs num, 1\n"
"bne max_loop @bne num\n"
: [data_in_channel] "+r"(data_in_channel), [num] "+r"(num),
[vmax] "+w"(vmax)
:
: "cc", "memory", "q0", "q1");
}
#endif // __aarch64__
float32x2_t vmax_tmp =
vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
float tmp1 = vget_lane_f32(vmax_tmp, 0);
float tmp2 = vget_lane_f32(vmax_tmp, 1);
float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2;
for (i = cnt * 8; i < size_channel_in; ++i) {
/* code */
max_tmp = max_tmp > data_in_channel[0] ? max_tmp : data_in_channel[0];
data_in_channel++;
}
data_out_batch[c] = max_tmp;
}
} else {
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
const float* data_in_channel =
data_in_batch + c * size_channel_in; // in address
int i = 0;
float32x4_t vsum = vdupq_n_f32(0.0f);
#ifdef __aarch64__
for (; i < cnt; i++) { //
vsum = vaddq_f32(vld1q_f32(data_in_channel), vsum);
data_in_channel += 4;
}
#else
int num = cnt;
if (num > 0) {
asm volatile(
"add_loop: @main loop\n"
"vld1.f32 {d0-d1}, [%[data_in_channel]]! @load q1, "
"data_in_channel\n"
"vadd.f32 %q[vsum], %q[vsum], q0 @add vmax, "
"vmax, data_in_channel\n"
"subs %[num], #1 @subs num, 1\n"
"bne add_loop @bne num\n"
: [data_in_channel] "+r"(data_in_channel), [num] "+r"(num),
[vsum] "+w"(vsum)
:
: "cc", "memory", "q0");
}
#endif // __aarch64__
float32x2_t vsum_tmp =
vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum));
float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1);
for (i = cnt * 4; i < size_channel_in; i++) {
sum += data_in_channel[0];
data_in_channel++;
}
data_out_batch[c] = sum / size_channel_in;
}
}
}
}
void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int w_even = (win >> 1) << 1;
// int w_remains = w_in - w_even; // should be 0 or 1
int h_even = (hin >> 1) << 1;
// int h_remains = h_in - h_even; // should be 0 or 1
int w_unroll_size = (w_even >> 3) << 3;
// int w_unroll_remian = w_even - w_unroll_size;
int w_in_2 = win << 1;
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
int h = 0;
for (; h < h_even; h += 2) {
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
float32x4_t dr10 = vld1q_f32(&r1[w]);
float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
float32x4_t dmax1 = vmaxq_f32(dr00, dr10);
float32x4_t dmax2 = vmaxq_f32(dr01, dr11);
#ifdef __aarch64__
float32x4_t dmax = vpmaxq_f32(dmax1, dmax2);
#else
float32x2_t dmaxl =
vpmax_f32(vget_low_f32(dmax1), vget_high_f32(dmax1));
float32x2_t dmaxh =
vpmax_f32(vget_low_f32(dmax2), vget_high_f32(dmax2));
float32x4_t dmax = vcombine_f32(dmaxl, dmaxh);
#endif
vst1q_f32(&data_out_channel[w >> 1], dmax);
}
#else
w = w_unroll_size;
int num = w_unroll_size >> 3;
const float* dr0 = r0;
const float* dr1 = r1;
float* dr_out = data_out_channel;
if (num > 0) {
asm volatile(
"s2_max_loop: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n"
"vld1.f32 {d4-d7}, [%[dr1]]! @load q1, dr1\n"
"vmax.f32 q0, q0, q2 @max q0, q0, "
"q2\n"
"vmax.f32 q1, q1, q3 @max q1, q1, "
"q2\n"
"vpmax.f32 d4, d0, d1 @max d4, d0, "
"d1\n"
"vpmax.f32 d5, d2, d3 @max d5, d2, "
"d3\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, "
"dr_out\n"
"subs %[num], #1 @subs num, 1\n"
"bne s2_max_loop @bne num\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[num] "+r"(num)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
}
#endif // __aarch64__
for (; w < w_even; w += 2) {
data_out_channel[w >> 1] =
std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1]));
}
for (; w < win; ++w) { // run 0 or 1 time
data_out_channel[w >> 1] = std::max(r0[w], r1[w]);
}
r0 += w_in_2; // << 1;
r1 += w_in_2; // << 1;
data_out_channel += wout;
}
// process remain row (odd, last row)
for (; h < hin; h++) { // run 0 or 1 time
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
#ifdef __aarch64__
float32x4_t dmax = vpmaxq_f32(dr00, dr01);
#else
float32x2_t dmaxl =
vpmax_f32(vget_low_f32(dr00), vget_high_f32(dr00));
float32x2_t dmaxh =
vpmax_f32(vget_low_f32(dr01), vget_high_f32(dr01));
float32x4_t dmax = vcombine_f32(dmaxl, dmaxh);
#endif
float32x4_t dmax_cmp_zero = vmaxq_f32(dmax, vzero);
vst1q_f32(&data_out_channel[w >> 1], dmax_cmp_zero);
}
#else
w = w_unroll_size;
int num = w_unroll_size >> 3;
const float* dr0 = r0;
float* dr_out = data_out_channel;
if (num > 0) {
asm volatile(
"s2_max_loop1: @main "
"loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0, dr0\n"
"vpmax.f32 d4, d0, d1 @max d4, d0, "
"d1\n"
"vpmax.f32 d5, d2, d3 @max d5, d2, "
"d3\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2, "
"dr_out\n"
"subs %[num], #1 @subs num, 1\n"
"bne s2_max_loop1 @bne num\n"
: [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [num] "+r"(num)
:
: "cc", "memory", "q0", "q1", "q2");
}
#endif // __aarch64__
for (; w < w_even; w += 2) {
data_out_channel[w >> 1] = std::max(std::max(r0[w], r0[w + 1]), 0.f);
}
for (; w < win; ++w) { // run 0 or 1 time
data_out_channel[w >> 1] = std::max(r0[w], 0.f);
}
}
}
}
}
void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int w_even = (win >> 1) << 1;
// int w_remains = w_in - w_even; // should be 0 or 1
int h_even = (hin >> 1) << 1;
// int h_remains = h_in - h_even; // should be 0 or 1
int w_unroll_size = (w_even >> 3) << 3;
// int w_unroll_remian = w_even - w_unroll_size;
int w_in_2 = win << 1;
float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
int h = 0;
for (; h < h_even; h += 2) {
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
float32x4_t dr10 = vld1q_f32(&r1[w]);
float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
float32x4_t dsum1 = vaddq_f32(dr00, dr10);
float32x4_t dsum2 = vaddq_f32(dr01, dr11);
#ifdef __aarch64__
float32x4_t dsum = vpaddq_f32(dsum1, dsum2);
#else
float32x2_t dsuml =
vpadd_f32(vget_low_f32(dsum1), vget_high_f32(dsum1));
float32x2_t dsumh =
vpadd_f32(vget_low_f32(dsum2), vget_high_f32(dsum2));
float32x4_t dsum = vcombine_f32(dsuml, dsumh);
#endif
float32x4_t res = vmulq_f32(dsum, vcoef);
vst1q_f32(&data_out_channel[w >> 1], res);
}
#else
w = w_unroll_size;
int num = w_unroll_size >> 3;
const float* dr0 = r0;
const float* dr1 = r1;
float* dr_out = data_out_channel;
if (num > 0) {
asm volatile(
"1: @ main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, "
"dr0\n"
"vld1.f32 {d4-d7}, [%[dr1]]! @ load q1, "
"dr1\n"
"vadd.f32 q0, q0, q2 @ add q0, q0, "
"q2\n"
"vadd.f32 q1, q1, q3 @ add q1, q1, "
"q2\n"
"vpadd.f32 d4, d0, d1 @ add d4, d0, "
"d1\n"
"vpadd.f32 d5, d2, d3 @ add d5, d2, "
"d3\n"
"vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, "
"vcoef\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, "
"dr_out\n"
"subs %[num], #1 @ subs num, 1\n"
"bne 1b @ bne num\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[vcoef] "+w"(vcoef), [num] "+r"(num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(num), "w"(vcoef)
: "cc", "memory", "q0", "q1", "q2", "q3");
}
#endif // __aarch64__
for (; w < w_even; w += 2) {
data_out_channel[w >> 1] =
(r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) / 4.f;
}
for (; w < win; ++w) { // run 0 or 1 time
data_out_channel[w >> 1] = (r0[w] + r1[w]) / 4.f;
}
r0 += w_in_2; // << 1;
r1 += w_in_2; // << 1;
data_out_channel += wout;
}
// process remain row (odd, last row)
for (; h < hin; h++) { // run 0 or 1 time
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
#ifdef __aarch64__
float32x4_t dsum = vpaddq_f32(dr00, dr01);
#else
float32x2_t dsuml =
vpadd_f32(vget_low_f32(dr00), vget_high_f32(dr00));
float32x2_t dsumh =
vpadd_f32(vget_low_f32(dr01), vget_high_f32(dr01));
float32x4_t dsum = vcombine_f32(dsuml, dsumh);
#endif
float32x4_t res = vmulq_f32(dsum, vcoef);
vst1q_f32(&data_out_channel[w >> 1], res);
}
#else
w = w_unroll_size;
int num = w_unroll_size >> 3;
const float* dr0 = r0;
float* dr_out = data_out_channel;
if (num > 0) {
asm volatile(
"1: @ main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @ load q0, "
"dr0\n"
"vpadd.f32 d4, d0, d1 @ add d4, d0, "
"d1\n"
"vpadd.f32 d5, d2, d3 @ add d5, d2, "
"d3\n"
"vmul.f32 q2, q2, %q[vcoef] @ mul q2, q2, "
"vcoef\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @ vst1 q2, "
"dr_out\n"
"subs %[num], #1 @ subs num, 1\n"
"bne 1b @ bne num\n"
: [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [vcoef] "+w"(vcoef),
[num] "+r"(num)
: "r"(dr0), "r"(dr_out), "r"(num), "w"(vcoef)
: "cc", "memory", "q0", "q1", "q2");
}
#endif // __aarch64__
for (; w < w_even; w += 2) {
data_out_channel[w >> 1] = (r0[w] + r0[w + 1]) / 4.f;
}
for (; w < win; ++w) { // run 0 or 1 time
data_out_channel[w >> 1] = r0[w] / 4.f;
}
}
}
}
}
void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
// no need to pad input tensor, pad_size is not used, default border is zero
// padded
int ch_in = chin;
int h_in = hin;
int w_in = win;
int ch_out = chout;
int h_out = hout;
int w_out = wout;
int size_channel_out = w_out * h_out;
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int w_even = (w_in >> 1) << 1;
// int w_remains = w_in - w_even; // should be 0 or 1
int h_even = (h_in >> 1) << 1;
// int h_remains = h_in - h_even; // should be 0 or 1
// int w_unroll_size = (w_even >> 3) << 3;
// int w_unroll_remian = w_even - w_unroll_size;
int w_in_2 = w_in << 1;
int w_unroll_size = (w_in - 2) >> 2;
int w_unroll_remian = w_in - 2 - w_unroll_size * 4;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vzero = vdupq_n_f32(minval); // zero pad
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * ch_out * size_channel_out;
const float* data_in_batch = data_in + n * ch_in * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < ch_out; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + w_in;
const float* r2 = r1 + w_in;
int cnt_num = w_unroll_size; // w_in / 4
float* dr_out = data_out_channel;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 0;
int cnt = 1;
// left
data_out_channel[0] =
std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#ifdef __aarch64__
for (; w <= w_in - 6; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_34_56 =
vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
vst1q_f32(&data_out_channel[cnt], vmax);
cnt += 4;
}
#else
dr_out = dr_out + 1;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n"
"vmax.f32 q5, q0, q2 @max "
"r0_1234,r1_1234\n"
"vmax.f32 d12, d2, d6 @max "
"r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vpmax.f32 d2, d10, d11 @pmax d4, "
"max_1234, max_1234\n"
"vpmax.f32 d3, d0, d1 @pmax d4, "
"max_2345, max_2345\n"
"vpmax.f32 d6, d4, d5 @pmax d6, "
"max_3456, max_3456\n"
"vmax.f32 d8, d2, d3 @max d2, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d9, d3, d6 @max d2, "
"vmax_23_45, vmax_34_56\n"
"sub %[dr0], #8 @sub w, 8\n"
"sub %[dr1], #8 @sub w, 8\n"
// swap
"vmov.f32 s0, s17 @mov \n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s0 @mov \n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remian
w = w_unroll_size * 4;
for (int j = 0; j < w_unroll_remian; j++) {
float tmp_max = std::max(r0[j + w], r1[j + w]);
tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
data_out_channel[j + w + 1] = tmp_max;
}
// right
float tmp = std::max(r0[w_in - 2], r1[w_in - 2]);
tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1]));
data_out_channel[w_out - 1] = tmp;
// r0 = r1;
// r1 = r0 + w_in;
// r2 = r1 + w_in;
data_out_channel += w_out;
int h = 0;
for (; h < h_in - 2; h += 1) {
// deal with left pad
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
float maxr2 = std::max(r2[0], r2[1]);
data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w <= w_in - 6; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_34_56 =
vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
vst1q_f32(&data_out_channel[cnt], vmax);
cnt += 4;
}
#else
dr_out = data_out_channel + 1;
dr0 = r0;
dr1 = r1;
dr2 = r2;
cnt_num = w_unroll_size;
if (cnt_num > 0) {
asm volatile(
"1: @main "
"loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n"
"vmax.f32 q7, q0, q2 @max "
"r0_1234,r1_1234\n"
"vmax.f32 d16, d2, d6 @max "
"r0_5678,r1_5678\n"
"vmax.f32 q3, q7, q4 @max "
"r0_1234,r1_1234\n"
"vmax.f32 d12, d16, d10 @max "
"r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q3, q6, #1 @vext max_2345\n"
"vext.f32 q2, q3, q6, #2 @vext max_3456\n"
"vpmax.f32 d2, d6, d7 @pmax d4, "
"max_1234, max_1234\n"
"vpmax.f32 d3, d0, d1 @pmax d4, "
"max_2345, max_2345\n"
"vpmax.f32 d6, d4, d5 @pmax d6, "
"max_3456, max_3456\n"
"vmax.f32 d8, d2, d3 @max d2, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d9, d3, d6 @max d2, "
"vmax_23_45, vmax_34_56\n"
"sub %[dr0], #8 @sub w, 8\n"
"sub %[dr1], #8 @sub w, 8\n"
"sub %[dr2], #8 @sub w, 8\n"
// swap
"vmov.f32 s0, s17 @mov \n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s0 @mov \n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne 1b @ bne "
"s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8");
}
#endif
// remian
w = w_unroll_size * 4;
for (int j = 0; j < w_unroll_remian; j++) {
float tmp_max = std::max(r0[j + w], r1[j + w]);
tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1]));
tmp_max = std::max(tmp_max, r2[j + w + 2]);
data_out_channel[j + w + 1] = tmp_max;
}
// right
tmp = std::max(r0[w_in - 2], r1[w_in - 2]);
tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1]));
tmp = std::max(tmp, std::max(r2[w_in - 2], r2[w_in - 1]));
data_out_channel[w_out - 1] = tmp;
r0 = r1;
r1 = r2;
r2 = r1 + w_in;
data_out_channel += w_out;
}
// the last two line
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
data_out_channel[0] = std::max(maxr0, maxr1);
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w <= w_in - 6; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_34_56 =
vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
vst1q_f32(&data_out_channel[cnt], vmax);
cnt += 4;
}
#else
dr_out = data_out_channel + 1;
dr0 = r0;
dr1 = r1;
cnt_num = w_unroll_size;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n"
"vmax.f32 q5, q0, q2 @max "
"r0_1234,r1_1234\n"
"vmax.f32 d12, d2, d6 @max "
"r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vpmax.f32 d2, d10, d11 @pmax d4, "
"max_1234, max_1234\n"
"vpmax.f32 d3, d0, d1 @pmax d4, "
"max_2345, max_2345\n"
"vpmax.f32 d6, d4, d5 @pmax d6, "
"max_3456, max_3456\n"
"vmax.f32 d8, d2, d3 @max d2, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d9, d3, d6 @max d2, "
"vmax_23_45, vmax_34_56\n"
"sub %[dr0], #8 @sub w, 8\n"
"sub %[dr1], #8 @sub w, 8\n"
// swap
"vmov.f32 s0, s17 @mov \n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s0 @mov \n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remian
w = w_unroll_size * 4;
for (int j = 0; j < w_unroll_remian; j++) {
float tmp_max = std::max(r0[j + w], r1[j + w]);
tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
data_out_channel[j + w + 1] = tmp_max;
}
tmp = std::max(r0[w_in - 2], r1[w_in - 2]);
tmp = std::max(tmp, std::max(r0[w_in - 1], r1[w_in - 1]));
data_out_channel[w_out - 1] = tmp;
}
}
}
void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int w_in = win;
int h_in = hin;
int ch_in = chin;
int w_out = wout;
int h_out = hout;
int ch_out = chout;
int size_channel_out = w_out * h_out;
int size_channel_in = w_in * h_in;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int w_even = (w_in >> 1) << 1;
int h_even = (h_in >> 1) << 1;
int w_in_2 = w_in << 1;
int w_unroll_size = (w_in - 2) >> 2;
int w_unroll_remian = w_in - 2 - w_unroll_size * 4;
float32x4_t vzero = vdupq_n_f32(0.f); // zero pad
float32x4_t vcoef = vdupq_n_f32(1.f / 9.f); // zero pad
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * ch_out * size_channel_out;
const float* data_in_batch = data_in + n * ch_in * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < ch_out; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + w_in;
const float* r2 = r1 + w_in;
int cnt_num = w_unroll_size; // w_in / 4
float* dr_out = data_out_channel;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 0;
int cnt = 1;
// left
data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f;
// first row with zero pad
#ifdef __aarch64__
for (; w <= w_in - 6; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
vsum = vaddq_f32(vsum, vsum_3456);
vsum = vmulq_f32(vsum, vcoef);
vst1q_f32(&data_out_channel[cnt], vsum);
cnt += 4;
}
#else
dr_out = dr_out + 1;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n"
"vadd.f32 q5, q0, q2 @max "
"r0_1234,r1_1234\n"
"vadd.f32 d12, d2, d6 @max "
"r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vadd.f32 q1, q5, q0 @add 1234 + 2345\n"
"vadd.f32 q1, q1, q2 @add + 3456\n"
"vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n"
"sub %[dr0], #8 @sub w, 8\n"
"sub %[dr1], #8 @sub w, 8\n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remian
w = w_unroll_size * 4;
for (int j = 0; j < w_unroll_remian; j++) {
float tmp_sum = r0[j + w] + r1[j + w];
tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
data_out_channel[j + w + 1] = tmp_sum / 9.f;
}
// right
float tmp = r0[w_in - 2] + r1[w_in - 2];
tmp += (r0[w_in - 1] + r1[w_in - 1]);
data_out_channel[w_out - 1] = tmp / 9.f;
// r0 = r1;
// r1 = r0 + w_in;
// r2 = r1 + w_in;
data_out_channel += w_out;
int h = 0;
for (; h < h_in - 2; h += 1) {
// deal with left pad
float maxr0 = r0[0] + r0[1];
float maxr1 = r1[0] + r1[1];
float maxr2 = r2[0] + r2[1];
data_out_channel[0] = (maxr0 + maxr1 + maxr2) / 9.f;
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w <= w_in - 6; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
vsum = vaddq_f32(vsum, vsum_3456);
vsum = vmulq_f32(vsum, vcoef);
vst1q_f32(&data_out_channel[cnt], vsum);
cnt += 4;
}
#else
dr_out = data_out_channel + 1;
dr0 = r0;
dr1 = r1;
dr2 = r2;
cnt_num = w_unroll_size;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n"
"vadd.f32 q7, q0, q2 @max "
"r0_1234,r1_1234\n"
"vadd.f32 d16, d2, d6 @max "
"r0_5678,r1_5678\n"
"vadd.f32 q3, q7, q4 @max "
"r0_1234,r1_1234\n"
"vadd.f32 d12, d16, d10 @max "
"r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q3, q6, #1 @vext max_2345\n"
"vext.f32 q2, q3, q6, #2 @vext max_3456\n"
"vadd.f32 q1, q3, q0 @add 1234 + "
"2345\n"
"vadd.f32 q1, q1, q2 @add + 3456\n"
"vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n"
"sub %[dr0], #8 @sub w, 8\n"
"sub %[dr1], #8 @sub w, 8\n"
"sub %[dr2], #8 @sub w, 8\n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne 1b @bne "
"s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[vcoef] "+w"(vcoef)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8");
}
#endif
// remian
w = w_unroll_size * 4;
for (int j = 0; j < w_unroll_remian; j++) {
float tmp_sum = r0[j + w] + r1[j + w];
tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
tmp_sum += (r2[j + w + 1] + r2[j + w + 2]);
tmp_sum += r2[j + w];
data_out_channel[j + w + 1] = tmp_sum / 9.f;
}
// right
tmp = r0[w_in - 2] + r1[w_in - 2];
tmp += (r0[w_in - 1] + r1[w_in - 1]);
tmp += (r2[w_in - 2] + r2[w_in - 1]);
data_out_channel[w_out - 1] = tmp / 9.f;
r0 = r1;
r1 = r2;
r2 = r1 + w_in;
data_out_channel += w_out;
}
// the last two line
float maxr0 = (r0[0] + r0[1]);
float maxr1 = (r1[0] + r1[1]);
data_out_channel[0] = (maxr0 + maxr1) / 9.f;
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w <= w_in - 6; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
vsum = vaddq_f32(vsum, vsum_3456);
vsum = vmulq_f32(vsum, vcoef);
vst1q_f32(&data_out_channel[cnt], vsum);
cnt += 4;
}
#else
dr_out = data_out_channel + 1;
dr0 = r0;
dr1 = r1;
cnt_num = w_unroll_size;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7, dr1\n"
"vadd.f32 q5, q0, q2 @max "
"r0_1234,r1_1234\n"
"vadd.f32 d12, d2, d6 @max "
"r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vadd.f32 q1, q5, q0 @add 1234 + 2345\n"
"vadd.f32 q1, q1, q2 @add + 3456\n"
"vmul.f32 q4, q1, %q[vcoef] @mul * 1/9.f \n"
"sub %[dr0], #8 @sub w, 8\n"
"sub %[dr1], #8 @sub w, 8\n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [vcoef] "+w"(vcoef)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remian
w = w_unroll_size * 4;
for (int j = 0; j < w_unroll_remian; j++) {
float tmp_sum = r0[j + w] + r1[j + w];
tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
data_out_channel[j + w + 1] = tmp_sum / 9.f;
}
// right
tmp = r0[w_in - 2] + r1[w_in - 2];
tmp += (r0[w_in - 1] + r1[w_in - 1]);
data_out_channel[w_out - 1] = tmp / 9.f;
}
}
}
void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int kernel_h = ksize[0];
int kernel_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int pad_top = pad_h;
int pad_left = pad_w;
int w_needed = wout * 2 + 1;
int h_needed = hout * 2 + 1;
int pad_right = w_needed - win - pad_left;
int pad_bottom = h_needed - hin - pad_top;
int w_even = (win >> 1) << 1;
int h_even = (hin >> 1) << 1;
int w_in_2 = win << 1;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vzero = vdupq_n_f32(minval); // zero pad
int cnt_col = (win - 1) / 8;
// remain
int remain = ((win - 1) % 8) / 2;
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
float* dr_out = data_out_channel;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 1;
int cnt = 1;
int cnt_num = cnt_col;
int cnt_num1 = remain;
data_out_channel[0] =
std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#ifdef __aarch64__
for (; w < win - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&data_out_channel[cnt], vmax_123_345);
vst1_f32(&data_out_channel[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
vmax2 = vpmax_f32(vmax2, vmax2);
data_out_channel[cnt] = vget_lane_f32(vmax2, 0);
cnt++;
}
#else
dr0 = dr0 + 1;
dr1 = dr1 + 1;
dr_out = dr_out + 1;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, 0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vmax.f32 q6, q0, q3 @max "
"r0_1234,r1_1234\n"
"vmax.f32 q7, q1, q4 @max "
"r0_5678,r1_5678\n"
"vmax.f32 q8, q2, q5 @max "
"r0_9101112,r1_9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q7, q8, #1 @vext max_6789\n"
"vpmax.f32 d4, d12, d13 @pmax d4, "
"vmax_1234, vmax_1234\n"
"vpmax.f32 d6, d14, d15 @pmax d6, "
"vmax_5678, vmax_5678\n"
"vpmax.f32 d5, d0, d1 @pmax d5, "
"vmax_2345, vmax_2345\n"
"vpmax.f32 d7, d2, d3 @pmax d7, "
"vmax_6789, vmax_6789\n"
"vmax.f32 d8, d4, d5 @max d2, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d9, d6, d7 @max d2, "
"vmax_56_78, vmax_67_89\n"
"sub %[dr0], #16 @add w, 8\n"
"sub %[dr1], #16 @add w, 8\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"bne 1b @bne s3_max_loop\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp cnt_num, "
"0\n"
"ble 4f @ble exit\n"
"2: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vmov.f32 s7,s6 @movs7, s6\n"
"vmax.f32 q0, q0, q1 @max q0, q0, q1\n"
"vpmax.f32 d0, d0, d1 @pmax d0, d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0, d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"bne 2b @bne "
"s3_max_loop_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9");
}
// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1);
#endif
// int w = w_even - 1;
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp = std::max(tmp, std::max(r0[i], r1[i]));
}
data_out_channel[w_even >> 1] = tmp;
// cnt ++;
}
r0 = r1;
r1 = r0 + win;
r2 = r1 + win;
data_out_channel += wout;
int h = 2;
for (; h < h_even; h += 2) {
// deal with left pad
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
float maxr2 = std::max(r2[0], r2[1]);
data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < win - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&data_out_channel[cnt], vmax_123_345);
vst1_f32(&data_out_channel[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
data_out_channel[cnt] = vget_lane_f32(vmax, 0);
cnt++;
}
#else
dr_out = data_out_channel + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
dr2 = (r2 + 1);
cnt_num = cnt_col;
cnt_num1 = remain;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vmax.f32 q9, q0, q3 @max q0,q0,q2\n"
"vmax.f32 q10, q1, q4 @max q1,q1,q3\n"
"vmax.f32 q11, q2, q5 @max q1,q1,q3\n"
"vmax.f32 q0, q9, q6 @max q0,q0,q2 "
"1234\n"
"vmax.f32 q3, q10, q7 @max q1,q1,q3 "
"5678\n"
"vmax.f32 q1, q11, q8 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q4, q0, q3, #1 @vext 2345\n"
"vext.f32 q2, q3, q1, #1 @vext 6789\n"
"vpmax.f32 d10, d0, d1 @pmax d10, "
"vmax_1234, vmax_1234\n"
"vpmax.f32 d12, d6, d7 @pmax d12, "
"vmax_5678, vmax_5678\n"
"vpmax.f32 d11, d8, d9 @pmax d11, "
"vmax_2345, vmax_2345\n"
"vpmax.f32 d13, d4, d5 @pmax d13, "
"vmax_6789, vmax_6789\n"
"vmax.f32 d0, d10, d11 @pmax d0, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d1, d12, d13 @pmax d1, "
"vmax_56_78, vmax_67_89\n"
"sub %[dr0], #16 @add w, 8\n"
"sub %[dr1], #16 @add w, 8\n"
"sub %[dr2], #16 @add w, 8\n"
"vst1.f32 d0, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d1, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"bne 1b @bne "
"s3_max_loop_mid\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, "
"dr1\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vmov.f32 s7,s6 @movs7, s6\n"
"vmov.f32 s11,s10 @movs11, s10\n"
"vmax.f32 q0, q0, q1 @max q0, q0, "
"q1\n"
"vmax.f32 q0, q0, q2 @max q0, q0, "
"q2\n"
"vpmax.f32 d0, d0, d1 @pmax d0, "
"d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0, "
"d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"sub %[dr2], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs cnt_num, "
"#1\n"
"bne 2b @bne "
"s3_max_loop_mid_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num1] "+r"(cnt_num1)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(r0[i], r1[i]));
tmp = std::max(tmp, r2[i]);
}
data_out_channel[w_even >> 1] = tmp;
// cnt ++;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
data_out_channel += wout;
}
if (pad_bottom) {
// deal with bottom pad
// first row with zero pad
int hstart = (h >> 1) * stride_h - pad_h;
int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin);
if (hstart == hend - 1) { // only one lline
data_out_channel[0] = std::max(r0[0], r0[1]);
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < win - 8; w += 8) {
float32x4_t vmax_1234 = vld1q_f32(&r0[w]);
float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&data_out_channel[cnt], vmax_123_345);
vst1_f32(&data_out_channel[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0));
vmax = vpmax_f32(vmax, vmax);
data_out_channel[cnt] = vget_lane_f32(vmax, 0);
cnt++;
}
#else
dr_out = data_out_channel + 1;
dr0 = (r0 + 1);
cnt_num = cnt_col;
cnt_num1 = remain;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3, "
"dr0\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, "
"dr0\n"
"vext.f32 q4, q0, q1, #1 @vext q4, q0, "
"q1, 1 2345\n"
"vext.f32 q5, q1, q2, #1 @vext q5, q0, "
"q1, 1 6789\n"
"vpmax.f32 d12, d0, d1 @pmax d12, "
"vmax_1234, vmax_1234\n"
"vpmax.f32 d14, d2, d3 @pmax d14, "
"vmax_5678, vmax_5678\n"
"vpmax.f32 d13, d8, d9 @pmax d13, "
"vmax_2345, vmax_2345\n"
"vpmax.f32 d15, d10, d11 @pmax d15, "
"vmax_6789, vmax_6789\n"
"vmax.f32 d0, d12, d13 @max d0, "
"vmax_12_34,vmax_23_45\n"
"vmax.f32 d1, d14, d15 @pmax d2, "
"vmax_56_78, vmax_67_89\n"
"sub %[dr0], #16 @add w, 6\n"
"vst1.f32 d0, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d1, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"bne 1b @bne "
"s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vpmax.f32 d0, d0, d1 @pmax d0, "
"d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0, "
"d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"sub %[dr0], #8 @add w, 2\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"bne 2b @bne "
"s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, r0[i]);
}
data_out_channel[w_even >> 1] = tmp;
}
} else { // two lines
data_out_channel[0] =
std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < win - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&data_out_channel[cnt], vmax_123_345);
vst1_f32(&data_out_channel[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
vmax2 = vpmax_f32(vmax2, vmax2);
data_out_channel[cnt] = vget_lane_f32(vmax2, 0);
cnt++;
}
#else
dr_out = data_out_channel + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
cnt_num = cnt_col;
cnt_num1 = remain;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, "
"dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vmax.f32 q6, q0, q3 @max q0,q0,q2 "
"1234\n"
"vmax.f32 q7, q1, q4 @max q1,q1,q3 "
"5678\n"
"vmax.f32 q8, q2, q5 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7,
// s6\n"
"vext.f32 q0, q6, q7, #1 @vext q0, "
"2345\n"
"vext.f32 q1, q7, q8, #1 @vext q1, "
"6789\n"
"vpmax.f32 d4, d12, d13 @pmax d4, "
"vmax_1234, vmax_1234\n"
"vpmax.f32 d6, d14, d15 @pmax d6, "
"vmax_5678, vmax_5678\n"
"vpmax.f32 d5, d0, d1 @pmax d5, "
"vmax_2345, vmax_2345\n"
"vpmax.f32 d7, d2, d3 @pmax d7, "
"vmax_6789, vmax_6789\n"
"vmax.f32 d8, d4, d5 @max d2, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d9, d6, d7 @max d2, "
"vmax_56_78, vmax_67_89\n"
"sub %[dr0], #16 @add w, 8\n"
"sub %[dr1], #16 @add w, 8\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"bne 1b @bne "
"s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vmov.f32 s7,s6 @movs7, s6\n"
"vmax.f32 q0, q0, q1 @max q0, q0, "
"q1\n"
"vpmax.f32 d0, d0, d1 @pmax d0, "
"d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0, "
"d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"bne 2b @bne "
"s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp = std::max(tmp, std::max(r0[i], r1[i]));
}
data_out_channel[w_even >> 1] = tmp;
}
}
}
}
}
}
void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int kernel_h = ksize[0];
int kernel_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int pad_top = pad_h;
int pad_left = pad_w;
int w_needed = wout * 2 + 1;
int h_needed = hout * 2 + 1;
int pad_right = w_needed - win - pad_left;
int pad_bottom = h_needed - hin - pad_top;
int w_even = (win >> 1) << 1;
int h_even = (hin >> 1) << 1;
int w_in_2 = win << 1;
int w_unroll_size = (win - 1) / 8;
// remain
int w_unroll_remian = ((win - 1) % 8) / 2;
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
int cnt_num = w_unroll_size;
int cnt_num1 = w_unroll_remian;
float* dr_out = data_out_channel;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 1;
int cnt = 1;
float32x4_t vcoef = vdupq_n_f32(1.f / 9.f);
float32x4_t vzero = vdupq_n_f32(0.f);
data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f;
// first row with zero pad
#ifdef __aarch64__
for (; w < win - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&data_out_channel[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
vsum2 = vpadd_f32(vsum2, vsum2);
float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef));
data_out_channel[cnt] = vget_lane_f32(vrst, 0);
cnt++;
}
#else
dr0 = dr0 + 1;
dr1 = dr1 + 1;
dr_out = dr_out + 1;
// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1);
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, 0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vadd.f32 q6, q0, q3 @max "
"r0_1234,r1_1234\n"
"vadd.f32 q7, q1, q4 @max "
"r0_5678,r1_5678\n"
"vadd.f32 q8, q2, q5 @max "
"r0_9101112,r1_9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, 2345 \n"
"vadd.f32 q5, q7, q1 @add 5678, 4567 \n"
"vadd.f32 q4, q4, q2 @add 3456, sum1 \n"
"vadd.f32 q5, q5, q3 @add 6789, sum2 \n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s21 @mov \n"
"vmov.f32 s19, s23 @mov \n"
"vmul.f32 q4, q4, %q[vcoef] @mul \n"
"sub %[dr0], #16 @add w, 8\n"
"sub %[dr1], #16 @add w, 8\n"
"subs %[cnt_num], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"bne 1b @bne s3_max_loop\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp cnt_num, "
"0\n"
"ble 4f @ble exit\n"
"2: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0, q0, q1\n"
"vpadd.f32 d0, d0, d1 @padd d0, d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0, d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"bne 2b @bne s3_max_loop_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1),
[vcoef] "+w"(vcoef), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9");
}
// printf("cnt_num: %d, cnt_num1: %d \n",cnt_num, cnt_num1);
#endif
// int w = w_even - 1;
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = 0.f; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp += (r0[i] + r1[i]);
}
data_out_channel[w_even >> 1] = tmp / 9.f;
// cnt ++;
}
r0 = r1;
r1 = r0 + win;
r2 = r1 + win;
data_out_channel += wout;
int h = 2;
for (; h < h_even; h += 2) {
// deal with left pad
float sum0 = r0[0] + r0[1];
float sum1 = r1[0] + r1[1];
float sum2 = r2[0] + r2[1];
data_out_channel[0] = (sum0 + sum1 + sum2) / 9.f;
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < win - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&data_out_channel[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
vr2 = vsetq_lane_f32(0.f, vr2, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
vsum1 = vaddq_f32(vsum1, vr2);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
float32x2_t vsum = vpadd_f32(vsum2, vsum2);
data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f;
cnt++;
}
#else
dr_out = data_out_channel + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
dr2 = (r2 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vadd.f32 q9, q0, q3 @max q0,q0,q2\n"
"vadd.f32 q10, q1, q4 @max q1,q1,q3\n"
"vadd.f32 q11, q2, q5 @max q1,q1,q3\n"
"vadd.f32 q6, q9, q6 @max q0,q0,q2 "
"1234\n"
"vadd.f32 q7, q10, q7 @max q1,q1,q3 "
"5678\n"
"vadd.f32 q8, q11, q8 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, 2345 "
"\n"
"vadd.f32 q5, q7, q1 @add 5678, 4567 "
"\n"
"vadd.f32 q4, q4, q2 @add 3456, sum1 "
"\n"
"vadd.f32 q5, q5, q3 @add 6789, sum2 "
"\n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s21 @mov \n"
"vmov.f32 s19, s23 @mov \n"
"vmul.f32 q4, q4, %q[vcoef] @mul \n"
"sub %[dr0], #16 @add w, 8\n"
"sub %[dr1], #16 @add w, 8\n"
"sub %[dr2], #16 @add w, 8\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne 1b @bne s3_max_loop_mid\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, "
"dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n"
"vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0, q0, "
"q1\n"
"vadd.f32 q0, q0, q2 @add q0, q0, "
"q1\n"
"vpadd.f32 d0, d0, d1 @padd d0, "
"d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0, "
"d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"sub %[dr2], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"bne 2b @bne s3_max_loop_mid_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef),
[vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = 0.f;
for (int i = wstart; i < wend; i++) {
tmp += (r0[i] + r1[i] + r2[i]);
}
data_out_channel[w_even >> 1] = tmp / 9.f;
// cnt ++;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
data_out_channel += wout;
}
if (pad_bottom) {
// deal with bottom pad
// first row with zero pad
int hstart = (h >> 1) * stride_h - pad_h;
int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin);
if (hstart == hend - 1) { // only one lline
data_out_channel[0] = (r0[0] + r0[1]) / 9.f;
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < win - 8; w += 8) {
float32x4_t vsum_1234 = vld1q_f32(&r0[w]);
float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2),
vsum_123_345, 1);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1),
vsum_123_345, 2);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3),
vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&data_out_channel[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0));
vsum = vpadd_f32(vsum, vsum);
data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f;
cnt++;
}
#else
dr_out = data_out_channel + 1;
dr0 = (r0 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d12-d15}, [%[dr0]]! @load "
"d0-d3, dr0\n"
"vld1.f32 {d16-d17}, [%[dr0]]! @load "
"d0-d3, dr0\n"
"vext.f32 q0, q6, q7, #1 @vext "
"max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext "
"max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext "
"max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext "
"max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, "
"2345 \n"
"vadd.f32 q5, q7, q1 @add 5678, "
"4567 \n"
"vadd.f32 q4, q4, q2 @add 3456, "
"sum1 \n"
"vadd.f32 q5, q5, q3 @add 6789, "
"sum2 \n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s21 @mov \n"
"vmov.f32 s19, s23 @mov \n"
"vmul.f32 q4, q4, %q[vcoef] @mul \n"
"sub %[dr0], #16 @add w, 6\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vext.f32 q0, %q[vzero], q0, #3 @ ext "
"v0_0123\n"
"vpadd.f32 d0, d0, d1 @padd d0, "
"d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0, "
"d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 2\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1),
[vcoef] "+w"(vcoef), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = 0.f;
for (int i = wstart; i < wend; i++) {
tmp += r0[i];
}
data_out_channel[w_even >> 1] = tmp / 9.f;
}
} else { // two lines
data_out_channel[0] = (r0[0] + r0[1] + r1[0] + r1[1]) / 9.f;
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < win - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2),
vsum_123_345, 1);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1),
vsum_123_345, 2);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3),
vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&data_out_channel[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
vsum2 = vpadd_f32(vsum2, vsum2);
float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef));
data_out_channel[cnt] = vget_lane_f32(vrst, 0);
cnt++;
}
#else
dr_out = data_out_channel + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3, "
"dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vmax.f32 q6, q0, q3 @max q0,q0,q2 "
"1234\n"
"vmax.f32 q7, q1, q4 @max q1,q1,q3 "
"5678\n"
"vmax.f32 q8, q2, q5 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7,
// s6\n"
"vext.f32 q0, q6, q7, #1 @vext "
"max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext "
"max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext "
"max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext "
"max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, "
"2345 \n"
"vadd.f32 q5, q7, q1 @add 5678, "
"4567 \n"
"vadd.f32 q4, q4, q2 @add 3456, "
"sum1 \n"
"vadd.f32 q5, q5, q3 @add 6789, "
"sum2 \n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s21 @mov \n"
"vmov.f32 s19, s23 @mov \n"
"vmul.f32 q4, q4, %q[vcoef] @mul \n"
"sub %[dr0], #16 @add w, 8\n"
"sub %[dr1], #16 @add w, 8\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ ext "
"v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ ext "
"v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0, q0, "
"q1\n"
"vpadd.f32 d0, d0, d1 @padd d0, "
"d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0, "
"d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1),
[vcoef] "+w"(vcoef), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, win + pad_w), win);
float tmp = 0.f;
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp += (r0[i] + r1[i]);
}
data_out_channel[w_even >> 1] = tmp / 9.f;
}
}
}
}
}
}
void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int w_in = win;
int h_in = hin;
int ch_in = chin;
int w_out = wout;
int h_out = hout;
int ch_out = chout;
int kernel_h = ksize[0];
int kernel_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int size_channel_out = w_out * h_out;
int size_channel_in = w_in * h_in;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int pad_top = pad_h;
int pad_left = pad_w;
int w_needed = w_out * 2 + 1;
int h_needed = h_out * 2 + 1;
int pad_right = w_needed - w_in - pad_left;
int pad_bottom = h_needed - h_in - pad_top;
int w_even = ((w_in - 1) >> 1) << 1;
// int w_remains = w_in - w_even; // should be 0 or 1
int h_even = ((h_in - 1) >> 1) << 1;
// int h_remains = h_in - h_even; // should be 0 or 1
int w_unroll_size = w_in >> 3;
int w_unroll_remian = (w_in - w_unroll_size * 8 - 1) / 2;
int w_in_2 = w_in << 1;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vzero = vdupq_n_f32(minval); // zero pad
// printf("minval: %.2f\n", minval);
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * ch_out * size_channel_out;
const float* data_in_batch = data_in + n * ch_in * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < ch_out; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + w_in;
const float* r2 = r1 + w_in;
int cnt_num = w_unroll_size;
// w = w_in - 8;
int cnt_num1 = w_unroll_remian;
float* dr_out = data_out_channel;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 0;
int cnt = 0;
// data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0],
// r1[1]));
// first row with zero pad
// r0 = r1;
// r1 = r0 + w_in;
// r2 = r1 + w_in;
// data_out_channel += w_out;
int h = 0;
for (; h < h_even; h += 2) {
// deal with left pad
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
float maxr2 = std::max(r2[0], r2[1]);
// data_out_channel[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
w = 0;
cnt = 0;
for (; w < w_in - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&data_out_channel[cnt], vmax_123_345);
vst1_f32(&data_out_channel[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
data_out_channel[cnt] = vget_lane_f32(vmax, 0);
cnt++;
}
#else
dr_out = data_out_channel; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
dr2 = r2; // (r2 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n"
"vmax.f32 q9, q0, q3 @max q0,q0,q2\n"
"vmax.f32 q10, q1, q4 @max q1,q1,q3\n"
"vmax.f32 d22, d4, d10 @max q1,q1,q3\n"
"vmax.f32 q0, q9, q6 @max q0,q0,q2 "
"1234\n"
"vmax.f32 q3, q10, q7 @max q1,q1,q3 "
"5678\n"
"vmax.f32 d2, d22, d16 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q4, q0, q3, #1 @vext 2345\n"
"vext.f32 q2, q3, q1, #1 @vext 6789\n"
"vpmax.f32 d10, d0, d1 @pmax d10, "
"vmax_1234, vmax_1234\n"
"vpmax.f32 d12, d6, d7 @pmax d12, "
"vmax_5678, vmax_5678\n"
"vpmax.f32 d11, d8, d9 @pmax d11, "
"vmax_2345, vmax_2345\n"
"vpmax.f32 d13, d4, d5 @pmax d13, "
"vmax_6789, vmax_6789\n"
"vmax.f32 d0, d10, d11 @pmax d0, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d1, d12, d13 @pmax d1, "
"vmax_56_78, vmax_67_89\n"
"sub %[dr0], #8 @add w, 8\n"
"sub %[dr1], #8 @add w, 8\n"
"sub %[dr2], #8 @add w, 8\n"
"vst1.f32 d0, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d1, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"bne 1b @bne s3_max_loop_mid\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, "
"dr1\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vmov.f32 s7,s6 @movs7, s6\n"
"vmov.f32 s11,s10 @movs11, s10\n"
"vmax.f32 q0, q0, q1 @max q0, q0, "
"q1\n"
"vmax.f32 q0, q0, q2 @max q0, q0, "
"q2\n"
"vpmax.f32 d0, d0, d1 @pmax d0, "
"d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0, "
"d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"sub %[dr2], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs cnt_num, "
"#1\n"
"bne 2b @bne s3_max_loop_mid_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num1] "+r"(cnt_num1)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(r0[i], r1[i]));
tmp = std::max(tmp, r2[i]);
}
data_out_channel[w_even >> 1] = tmp;
// cnt ++;
}
r0 = r2;
r1 = r0 + w_in;
r2 = r1 + w_in;
data_out_channel += w_out;
}
if (pad_bottom) {
// deal with bottom pad
// first row with zero pad
// int hstart = (h >> 1) * stride_h - pad_h;
// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in);
// data_out_channel[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0],
// r1[1]));
#ifdef __aarch64__
w = 0;
cnt = 0;
for (; w < w_in - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&data_out_channel[cnt], vmax_123_345);
vst1_f32(&data_out_channel[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
vmax2 = vpmax_f32(vmax2, vmax2);
data_out_channel[cnt] = vget_lane_f32(vmax2, 0);
cnt++;
}
#else
dr_out = data_out_channel; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n"
"vmax.f32 q6, q0, q3 @max q0,q0,q2 "
"1234\n"
"vmax.f32 q7, q1, q4 @max q1,q1,q3 "
"5678\n"
"vmax.f32 d16, d4, d10 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext q0, 2345\n"
"vext.f32 q1, q7, q8, #1 @vext q1, 6789\n"
"vpmax.f32 d4, d12, d13 @pmax d4, "
"vmax_1234, vmax_1234\n"
"vpmax.f32 d6, d14, d15 @pmax d6, "
"vmax_5678, vmax_5678\n"
"vpmax.f32 d5, d0, d1 @pmax d5, "
"vmax_2345, vmax_2345\n"
"vpmax.f32 d7, d2, d3 @pmax d7, "
"vmax_6789, vmax_6789\n"
"vmax.f32 d8, d4, d5 @max d2, "
"vmax_12_34, vmax_23_45\n"
"vmax.f32 d9, d6, d7 @max d2, "
"vmax_56_78, vmax_67_89\n"
"sub %[dr0], #8 @add w, 8\n"
"sub %[dr1], #8 @add w, 8\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vmov.f32 s7,s6 @movs7, s6\n"
"vmax.f32 q0, q0, q1 @max q0, q0, "
"q1\n"
"vpmax.f32 d0, d0, d1 @pmax d0, "
"d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0, "
"d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp = std::max(tmp, std::max(r0[i], r1[i]));
}
data_out_channel[w_even >> 1] = tmp;
}
}
}
}
}
void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
int w_in = win;
int h_in = hin;
int ch_in = chin;
int w_out = wout;
int h_out = hout;
int ch_out = chout;
int kernel_h = ksize[0];
int kernel_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int size_channel_out = w_out * h_out;
int size_channel_in = w_in * h_in;
float* data_out = static_cast<float*>(dout);
const float* data_in = static_cast<const float*>(din);
int pad_top = pad_h;
int pad_left = pad_w;
int w_needed = w_out * 2 + 1;
int h_needed = h_out * 2 + 1;
int pad_right = w_needed - w_in - pad_left;
int pad_bottom = h_needed - h_in - pad_top;
int w_even = ((w_in - 1) >> 1) << 1;
int h_even = ((h_in - 1) >> 1) << 1;
int w_in_2 = w_in << 1;
int w_unroll_size = w_in >> 3;
int w_unroll_remian = (w_even - w_unroll_size * 8 - 1) / 2;
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * ch_out * size_channel_out;
const float* data_in_batch = data_in + n * ch_in * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < ch_out; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + w_in;
const float* r2 = r1 + w_in;
int cnt_num = w_unroll_size;
// w = w_in - 8;
int cnt_num1 = w_unroll_remian;
float* dr_out = data_out_channel;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
float32x4_t vcoef = vdupq_n_f32(1.f / 9.f);
float32x4_t vzero = vdupq_n_f32(0.f);
int h = 0;
for (; h < h_even; h += 2) {
// LOG(INFO) << "h: " << h<<", dr0:" << r0 <<", dr1: "<<r1 << ",dr2: "<<r2;
// deal with left pad
// float sum0 = r0[0] + r0[1];
// float sum1 = r1[0] + r1[1];
// float sum2 = r2[0] + r2[1];
// data_out_channel[0] = (sum0 + sum1 + sum2) / 9.f;
#if 1 // def __aarch64__
int w = 0;
int cnt = 0;
for (; w < w_in - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&data_out_channel[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
vr2 = vsetq_lane_f32(0.f, vr2, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
vsum1 = vaddq_f32(vsum1, vr2);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
float32x2_t vsum = vpadd_f32(vsum2, vsum2);
data_out_channel[cnt] = vget_lane_f32(vsum, 0) / 9.f;
cnt++;
}
#else
dr_out = data_out_channel; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
dr2 = r2; // (r2 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
// LOG(INFO) << "cnt_num: " << cnt_num <<"cnt_num1: "<< cnt_num1;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble loop3_ave_p0 @ble "
"exit\n"
"s3_ave_loop_mid_p0: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr1\n"
"vadd.f32 q9, q0, q3 @max q0,q0,q2\n"
"vadd.f32 q10, q1, q4 @max q1,q1,q3\n"
"vadd.f32 d22, d4, d10 @max q1,q1,q3\n"
"vadd.f32 q6, q9, q6 @max q0,q0,q2 "
"1234\n"
"vadd.f32 q7, q10, q7 @max q1,q1,q3 "
"5678\n"
"vadd.f32 d16, d22, d16 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, 2345 "
"\n"
"vadd.f32 q5, q7, q1 @add 5678, 4567 "
"\n"
"vadd.f32 q4, q4, q2 @add 3456, sum1 "
"\n"
"vadd.f32 q5, q5, q3 @add 6789, sum2 "
"\n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s21 @mov \n"
"vmov.f32 s19, s23 @mov \n"
"vmul.f32 q4, q4, %q[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 8\n"
"sub %[dr1], #8 @add w, 8\n"
"sub %[dr2], #8 @add w, 8\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne s3_ave_loop_mid_p0 @bne "
"s3_max_loop_mid\n"
"loop3_ave_p0: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble exit1_ave_p0 @ble "
"exit1\n"
"s3_ave_loop_mid_1_p0: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3, "
"dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n"
"vext.f32 q2, %q[vzero], q2, #3 @ ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0, q0, "
"q1\n"
"vadd.f32 q0, q0, q2 @add q0, q0, "
"q1\n"
"vpadd.f32 d0, d0, d1 @padd d0, "
"d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0, "
"d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"sub %[dr2], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs cnt_num, "
"#1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"bne s3_ave_loop_mid_1_p0 @bne "
"s3_max_loop_mid_1\n"
"exit1_ave_p0: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num1] "+r"(cnt_num1), [vcoef] "+w"(vcoef),
[vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num1)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in);
float tmp = 0.f;
int pool_size = 3 * (wend - wstart);
for (int i = wstart; i < wend; i++) {
tmp += (r0[i] + r1[i] + r2[i]);
}
data_out_channel[w_even >> 1] = tmp / pool_size;
// cnt ++;
}
r0 = r2;
r1 = r0 + w_in;
r2 = r1 + w_in;
data_out_channel += w_out;
}
if (pad_bottom) {
// deal with bottom pad
// first row with zero pad
// int hstart = (h >> 1) * stride_h - pad_h;
// int hend = std::min(std::min(hstart + kernel_h, h_in + pad_h),h_in);
// data_out_channel[0] =(r0[0] + r0[1] + r1[0] + r1[1]) / 9.f;
#if 1 // def __aarch64__
int w = 0;
int cnt = 0;
vcoef = vdupq_n_f32(1.f / 6.f);
for (; w < w_in - 8; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&data_out_channel[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
vsum2 = vpadd_f32(vsum2, vsum2);
float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef));
data_out_channel[cnt] = vget_lane_f32(vrst, 0);
cnt++;
}
#else
dr_out = data_out_channel; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
cnt_num = w_unroll_size;
cnt_num1 = w_unroll_remian;
// LOG(INFO) << "dr0:" << dr0 <<", dr1: "<<dr1 << ",dr2: "<<dr2;
if (cnt_num > 0 || cnt_num1 > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, "
"0\n"
"ble 2f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, "
"dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d3, dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n"
"vadd.f32 q6, q0, q3 @max q0,q0,q2 "
"1234\n"
"vadd.f32 q7, q1, q4 @max q1,q1,q3 "
"5678\n"
"vadd.f32 d16, d4, d10 @max q1,q1,q3 "
"9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, 2345 "
"\n"
"vadd.f32 q5, q7, q1 @add 5678, 4567 "
"\n"
"vadd.f32 q4, q4, q2 @add 3456, sum1 "
"\n"
"vadd.f32 q5, q5, q3 @add 6789, sum2 "
"\n"
"vmov.f32 s17, s18 @mov \n"
"vmov.f32 s18, s21 @mov \n"
"vmov.f32 s19, s23 @mov \n"
"vmul.f32 q4, q4, %q[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 8\n"
"sub %[dr1], #8 @add w, 8\n"
"subs %[cnt_num], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, "
"dr_out\n"
"bne 1b @bne s3_max_loop_bot\n"
"2: @loop \n"
"cmp %[cnt_num1], #0 @cmp "
"cnt_num, 0\n"
"ble 3f @ble exit\n"
"4: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1, "
"dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3, "
"dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0, q0, "
"q1\n"
"vpadd.f32 d0, d0, d1 @padd d0, "
"d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0, "
"d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul \n"
"sub %[dr0], #8 @add w, 6\n"
"sub %[dr1], #8 @add w, 6\n"
"subs %[cnt_num1], #1 @subs "
"cnt_num, #1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0], "
"dr_out\n"
"bne 4b @bne s3_max_loop_bot_1\n"
"3: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num1] "+r"(cnt_num1),
[vcoef] "+w"(vcoef), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num1)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9");
}
#endif
if (pad_right) {
// deal with right pad
int wstart = (w_even >> 1) * stride_w - pad_w;
int wend = std::min(std::min(wstart + kernel_w, w_in + pad_w), w_in);
float tmp = 0.f;
int pool_size = 2 * (wend - wstart);
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp += (r0[i] + r1[i]);
}
data_out_channel[w_even >> 1] = tmp / pool_size;
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
// !pooling fp32 Op
void pooling_basic(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling_global(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling2x2s2_max(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling2x2s2_ave(const void* din, void* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s1p1_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s1p1_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p1_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p0_max(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p1_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling3x3s2p0_ave(const void* din, void* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/split.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void split_cpy<float>(const float* din, float* dout, int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr;
dout_ptr++;
din_ptr++;
}
}
}
template <>
void split<float>(const float* din, std::vector<lite::Tensor*>* dout,
const int axis, const std::vector<int>& in_strides) {
int input_offset = 0;
for (auto out : *dout) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
for (int i = out_dim.size() - 2; i >= 0; --i) {
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
float* out_data = out->mutable_data<float>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
for (int i = 0; i < before; ++i) {
split_cpy(din + input_offset + i * in_after, out_data + i * out_after,
out_after);
}
input_offset += out_strides[axis];
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void split_cpy(const T* din, T* dout, int num);
template <typename T>
void split(const T* din, std::vector<lite::Tensor*>* dout, const int axis,
const std::vector<int>& in_strides);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/saturate.h"
#include <arm_neon.h>
#include <string.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename dtype>
void int32_to_dtype(const int* din, dtype* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size);
void fp32_to_int8(const float* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for
for (int j = 0; j < loop_size; ++j) {
float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size;
signed char* dout_c = dout + j * inner_size;
if (cnt > 0) {
int cnt_loop = cnt;
const float* din_ptr = din_c;
signed char* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n"
"ldp q2, q3, [%[in]], #32 \n"
"0: \n" /* main loop */
"fmul v4.4s, v0.4s, %[scale].4s \n"
"fmul v5.4s, v1.4s, %[scale].4s \n"
"fmul v6.4s, v2.4s, %[scale].4s \n"
"fmul v7.4s, v3.4s, %[scale].4s \n"
"ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n"
"FCVTAS v9.4s, v5.4s \n"
"FCVTAS v10.4s, v6.4s \n"
"FCVTAS v11.4s, v7.4s \n"
"ldp q2, q3, [%[in]], #32 \n"
"sqxtn v4.4h, v8.4s \n"
"sqxtn2 v4.8h, v9.4s \n"
"sqxtn v5.4h, v10.4s \n"
"sqxtn2 v5.8h, v11.4s \n"
"sqxtn v8.8b, v4.8h \n"
"sqxtn2 v8.16b, v5.8h \n"
"str q8, [%[out]], #16 \n"
"bne 0b \n"
: [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop)
: [scale] "w" (vscale)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"0: @ main loop\n"
"vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q2, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q3, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q6, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q7, %q[vnoff], q11 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vmla.f32 q6, q2, %q[vscale] @ mul scale\n"
"vmla.f32 q7, q3, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vcvt.s32.f32 q2, q6 @ cvt to int32\n"
"vcvt.s32.f32 q3, q7 @ cvt to int32\n"
"vqmovn.s32 d8, q0 @ cnt to int16\n"
"vqmovn.s32 d9, q1 @ cnt to int16\n"
"vqmovn.s32 d10, q2 @ cnt to int16\n"
"vqmovn.s32 d11, q3 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vqmovn.s16 d12, q4 @ cnt to int8\n"
"vqmovn.s16 d13, q5 @ cnt to int8\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"vst1.32 {d12-d13}, [%[dout]]! @ write to output\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ to main loop\n"
:[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop)
:[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"
);
#endif
}
const float* din_r = din_c + 16 * cnt;
signed char* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(inv_scale * din_r[i]));
}
}
}
void fp32_to_int16(const float* din, int16_t* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 8;
int remain = inner_size & 7;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for
for (int j = 0; j < loop_size; ++j) {
float inv_scale = 1.f / scale[j % axis_size];
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vscale = vdupq_n_f32(inv_scale);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
const float* din_c = din + j * inner_size;
int16_t* dout_c = dout + j * inner_size;
if (cnt > 0) {
int cnt_loop = cnt;
const float* din_ptr = din_c;
int16_t* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n"
"0: \n" /* main loop */
"fmul v4.4s, v0.4s, %[scale].4s \n"
"fmul v5.4s, v1.4s, %[scale].4s \n"
"ldp q0, q1, [%[in]], #32 \n"
"subs %[cnt], %[cnt], #1 \n"
"FCVTAS v8.4s, v4.4s \n"
"FCVTAS v9.4s, v5.4s \n"
"sqxtn v4.4h, v8.4s \n"
"sqxtn2 v4.8h, v9.4s \n"
"str q4, [%[out]], #16 \n"
"bne 0b \n"
: [in] "+r" (din_ptr), [out] "+r" (dout_ptr), [cnt] "+r" (cnt_loop)
: [scale] "w" (vscale)
: "v0", "v1", "v4", "v5", "v8", "v9"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"0: @ main loop\n"
"vand.i32 q4, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q5, q4, q4 @ set offset, 0.5\n"
"vand.i32 q6, q4, q4 @ set offset, 0.5\n"
"vand.i32 q7, q4, q4 @ set offset, 0.5\n"
"vcgt.f32 q8, q0, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q1, %q[vzero] @ get mask > 0, in1\n"
"vbif.f32 q4, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q5, %q[vnoff], q9 @ get right offset\n"
"vmla.f32 q4, q0, %q[vscale] @ mul scale\n"
"vmla.f32 q5, q1, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q0, q4 @ cvt to int32\n"
"vcvt.s32.f32 q1, q5 @ cvt to int32\n"
"vqmovn.s32 d8, q0 @ cnt to int16\n"
"vqmovn.s32 d9, q1 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vst1.32 {d8-d9}, [%[dout]]! @ write to output\n"
"subs %[cnt], #1 @ loop count -1\n"
"bne 0b @ to main loop\n"
:[dout]"+r"(dout_ptr), [din]"+r"(din_ptr), [cnt]"+r"(cnt_loop)
:[vscale]"w"(vscale), [vpoff]"w"(vpoff), [vnoff]"w"(vnoff), [vzero]"w"(vzero)
:"q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9"
);
#endif
}
const float* din_r = din_c + 8 * cnt;
int16_t* dout_r = dout_c + 8 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int16_t>(roundf(inv_scale * din_r[i]));
}
}
}
void int8_to_fp32(const signed char* in, float* out, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const signed char* din_c = in + n * inner_size;
float* dout_c = out + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) {
int loop = cnt;
const signed char* din_ptr = din_c;
float* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/
"0: \n" /* main loop */
"sshll v2.8h, v0.8b, #0 \n" /* trans to int16*/
"sshll v3.8h, v1.8b, #0 \n" /* trans to int16*/
"sshll v4.4s, v2.4h, #0 \n" /* trans to int32*/
"sshll2 v5.4s, v2.8h, #0 \n" /* trans to int32*/
"sshll v6.4s, v3.4h, #0 \n" /* trans to int32*/
"sshll2 v7.4s, v3.8h, #0 \n" /* trans to int32*/
"ldp d0, d1, [%[in]], #16 \n" /* load 16 int8*/
"scvtf v8.4s, v4.4s \n" /* trans to fp32*/
"scvtf v9.4s, v5.4s \n" /* trans to fp32*/
"scvtf v10.4s, v6.4s \n" /* trans to fp32*/
"scvtf v11.4s, v7.4s \n" /* trans to fp32*/
"subs %[loop], %[loop], #1 \n"
"fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/
"fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/
"fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/
"fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/
"stp q4, q5, [%[out]], #32 \n" /* write to memory*/
"stp q6, q7, [%[out]], #32 \n" /* write to memory*/
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n"
"0: @ main loop\n"
"vmovl.s8 q2, d0 @ trans to int16\n"
"vmovl.s8 q3, d1 @ trans to int16\n"
"vmovl.s16 q4, d4 @ trans to int32\n"
"vmovl.s16 q5, d5 @ trans to int32\n"
"vmovl.s16 q6, d6 @ trans to int32\n"
"vmovl.s16 q7, d7 @ trans to int32\n"
"vcvt.f32.s32 q0, q4 @ trans to fp32\n"
"vcvt.f32.s32 q1, q5 @ trans to fp32\n"
"vcvt.f32.s32 q2, q6 @ trans to fp32\n"
"vcvt.f32.s32 q3, q7 @ trans to fp32\n"
"vmul.f32 q4, q0, %q[scale] @ mul with scale\n"
"vmul.f32 q5, q1, %q[scale] @ mul with scale\n"
"vmul.f32 q6, q2, %q[scale] @ mul with scale\n"
"vmul.f32 q7, q3, %q[scale] @ mul with scale\n"
"vld1.32 {d0-d1}, [%[in]]! @ load 16 int8\n"
"subs %[loop], #1 \n"
"vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n"
"vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"
);
#endif //__aarch64__
}
const signed char* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
}
void int16_to_fp32(const short* in, float* out, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const short* din_c = in + n * inner_size;
float* dout_c = out + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) {
int loop = cnt;
const short* din_ptr = din_c;
float* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/
"0: \n" /* main loop */
"sshll v4.4s, v0.4h, #0 \n" /* trans to int32*/
"sshll2 v5.4s, v0.8h, #0 \n" /* trans to int32*/
"sshll v6.4s, v1.4h, #0 \n" /* trans to int32*/
"sshll2 v7.4s, v1.8h, #0 \n" /* trans to int32*/
"ldp q0, q1, [%[in]], #32 \n" /* load 16 int16*/
"scvtf v8.4s, v4.4s \n" /* trans to fp32*/
"scvtf v9.4s, v5.4s \n" /* trans to fp32*/
"scvtf v10.4s, v6.4s \n" /* trans to fp32*/
"scvtf v11.4s, v7.4s \n" /* trans to fp32*/
"subs %[loop], %[loop], #1 \n"
"fmul v4.4s, v8.4s, %[scale].4s \n" /* mul with scale*/
"fmul v5.4s, v9.4s, %[scale].4s \n" /* mul with scale*/
"fmul v6.4s, v10.4s, %[scale].4s \n" /* mul with scale*/
"fmul v7.4s, v11.4s, %[scale].4s \n" /* mul with scale*/
"stp q4, q5, [%[out]], #32 \n" /* write to memory*/
"stp q6, q7, [%[out]], #32 \n" /* write to memory*/
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[in]]! @ load 16 int16\n"
"0: @ main loop\n"
"vmovl.s16 q4, d0 @ trans to int32\n"
"vmovl.s16 q5, d1 @ trans to int32\n"
"vmovl.s16 q6, d2 @ trans to int32\n"
"vmovl.s16 q7, d3 @ trans to int32\n"
"vcvt.f32.s32 q0, q4 @ trans to fp32\n"
"vcvt.f32.s32 q1, q5 @ trans to fp32\n"
"vcvt.f32.s32 q2, q6 @ trans to fp32\n"
"vcvt.f32.s32 q3, q7 @ trans to fp32\n"
"vmul.f32 q4, q0, %q[scale] @ mul with scale\n"
"vmul.f32 q5, q1, %q[scale] @ mul with scale\n"
"vmul.f32 q6, q2, %q[scale] @ mul with scale\n"
"vmul.f32 q7, q3, %q[scale] @ mul with scale\n"
"vld1.32 {d0-d3}, [%[in]]! @ load 16 int8\n"
"subs %[loop], #1 \n"
"vst1.f32 {d8-d11}, [%[out]]! @ write to memory\n"
"vst1.f32 {d12-d15}, [%[out]]! @ write to memory\n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"
);
#endif //__aarch64__
}
const short* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
}
void int32_to_fp32(const int* din, float* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = axis_size * outer_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const int* din_c = din + n * inner_size;
float* dout_c = dout + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
if (cnt > 0) {
int loop = cnt;
const int* din_ptr = din_c;
float* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"ldp q0, q1, [%[in]], #32 \n"
"ldp q2, q3, [%[in]], #32 \n"
"0: \n"
"scvtf v4.4s, v0.4s \n"
"scvtf v5.4s, v1.4s \n"
"scvtf v6.4s, v2.4s \n"
"scvtf v7.4s, v3.4s \n"
"ldp q0, q1, [%[in]], #32 \n"
"fmul v8.4s, v4.4s, %[scale].4s \n"
"fmul v9.4s, v5.4s, %[scale].4s \n"
"fmul v10.4s, v6.4s, %[scale].4s \n"
"fmul v11.4s, v7.4s, %[scale].4s \n"
"ldp q2, q3, [%[in]], #32 \n"
"stp q8, q9, [%[out]], #32 \n"
"stp q10, q11, [%[out]], #32 \n"
"subs %[loop], %[loop], #1 \n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11"
);
#else
asm volatile(
"vld1.s32 {d0-d3}, [%[in]]! \n"
"vld1.s32 {d4-d7}, [%[in]]! \n"
"0: \n"
"vcvt.f32.s32 q4, q0 \n"
"vcvt.f32.s32 q5, q1 \n"
"vcvt.f32.s32 q6, q2 \n"
"vcvt.f32.s32 q7, q3 \n"
"vld1.s32 {d0-d3}, [%[in]]! \n"
"vmul.f32 q8, q4, %q[scale] \n"
"vmul.f32 q9, q5, %q[scale] \n"
"vmul.f32 q10, q6, %q[scale] \n"
"vmul.f32 q11, q7, %q[scale] \n"
"vld1.s32 {d4-d7}, [%[in]]! \n"
"subs %[loop], #1 \n"
"vst1.f32 {d16-d19}, [%[out]]! \n"
"vst1.f32 {d20-d23}, [%[out]]! \n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"
);
#endif //__aarch64__
}
const int* din_r = din_c + 16 * cnt;
float* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = in_scale * din_r[i];
}
}
}
void int32_to_int8(const int* din, signed char* dout, const float* scale, \
int axis_size, long long outer_size, long long inner_size) {
int cnt = inner_size / 16;
int remain = inner_size & 15;
long long loop_size = outer_size * axis_size;
#pragma omp parallel for
for (long long n = 0; n < loop_size; ++n) {
float in_scale = scale[n % axis_size];
const int* din_c = din + n * inner_size;
signed char* dout_c = dout + n * inner_size;
float32x4_t vscale = vdupq_n_f32(in_scale);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vpoff = vdupq_n_f32(0.5f);
float32x4_t vnoff = vdupq_n_f32(-0.5f);
if (cnt > 0) {
int loop = cnt;
const int* din_ptr = din_c;
signed char* dout_ptr = dout_c;
#ifdef __aarch64__
asm volatile(
"0: \n"
"ld1 {v0.4s, v1.4s}, [%[in]], #32 \n"
"ld1 {v2.4s, v3.4s}, [%[in]], #32 \n"
"scvtf v4.4s, v0.4s \n"
"scvtf v5.4s, v1.4s \n"
"scvtf v6.4s, v2.4s \n"
"scvtf v7.4s, v3.4s \n"
"fmul v0.4s, v4.4s, %[scale].4s \n"
"fmul v1.4s, v5.4s, %[scale].4s \n"
"fmul v2.4s, v6.4s, %[scale].4s \n"
"fmul v3.4s, v7.4s, %[scale].4s \n"
"fcvtas v4.4s, v0.4s \n"
"fcvtas v5.4s, v1.4s \n"
"fcvtas v6.4s, v2.4s \n"
"fcvtas v7.4s, v3.4s \n"
"sqxtn v0.4h, v4.4s \n"
"sqxtn2 v0.8h, v5.4s \n"
"sqxtn v1.4h, v6.4s \n"
"sqxtn2 v1.8h, v7.4s \n"
"sqxtn v2.8b, v0.8h \n"
"sqxtn2 v2.16b, v1.8h \n"
"st1 {v2.16b}, [%[out]], #16 \n"
"subs %[loop], %[loop], #1 \n"
"bne 0b \n"
:[loop] "+r" (loop), [in] "+r" (din_ptr), [out] "+r" (dout_ptr)
:[scale] "w" (vscale)
:"v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
);
#else
asm volatile(
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"0: @ main loop\n"
"vcvt.f32.s32 q4, q0 @ cvt to float\n"
"vcvt.f32.s32 q5, q1 @ cvt to float\n"
"vcvt.f32.s32 q6, q2 @ cvt to float\n"
"vcvt.f32.s32 q7, q3 @ cvt to float\n"
"vand.i32 q0, %q[vpoff], %q[vpoff] @ set offset, 0.5\n"
"vand.i32 q1, q0, q0 @ set offset, 0.5\n"
"vand.i32 q2, q0, q0 @ set offset, 0.5\n"
"vand.i32 q3, q0, q0 @ set offset, 0.5\n"
"vcgt.f32 q8, q4, %q[vzero] @ get mask > 0, in0\n"
"vcgt.f32 q9, q5, %q[vzero] @ get mask > 0, in1\n"
"vcgt.f32 q10, q6, %q[vzero] @ get mask > 0, in2\n"
"vcgt.f32 q11, q7, %q[vzero] @ get mask > 0, in3\n"
"vbif.f32 q0, %q[vnoff], q8 @ get right offset\n"
"vbif.f32 q1, %q[vnoff], q9 @ get right offset\n"
"vbif.f32 q2, %q[vnoff], q10 @ get right offset\n"
"vbif.f32 q3, %q[vnoff], q11 @ get right offset\n"
"vmla.f32 q0, q4, %q[vscale] @ mul scale\n"
"vmla.f32 q1, q5, %q[vscale] @ mul scale\n"
"vmla.f32 q2, q6, %q[vscale] @ mul scale\n"
"vmla.f32 q3, q7, %q[vscale] @ mul scale\n"
"vcvt.s32.f32 q4, q0 @ cvt to int32\n"
"vcvt.s32.f32 q5, q1 @ cvt to int32\n"
"vcvt.s32.f32 q6, q2 @ cvt to int32\n"
"vcvt.s32.f32 q7, q3 @ cvt to int32\n"
"vqmovn.s32 d16, q4 @ cnt to int16\n"
"vqmovn.s32 d17, q5 @ cnt to int16\n"
"vqmovn.s32 d18, q6 @ cnt to int16\n"
"vqmovn.s32 d19, q7 @ cnt to int16\n"
"vld1.32 {d0-d3}, [%[din]]! @ load in0~in7\n"
"vqmovn.s16 d8, q8 @ cnt to int8\n"
"vqmovn.s16 d9, q9 @ cnt to int8\n"
"vld1.32 {d4-d7}, [%[din]]! @ load in8~in16\n"
"vst1.32 {d8-d9}, [%[dout]]! @ write to output\n"
"subs %[loop], #1 @ loop count -1\n"
"bne 0b @ to main loop\n"
:[loop] "+r" (loop), [din] "+r" (din_ptr), [dout] "+r" (dout_ptr)
:[vscale] "w" (vscale), [vzero] "w"(vzero), [vnoff] "w" (vnoff), [vpoff] "w" (vpoff)
:"q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"
);
#endif //__aarch64__
}
const int* din_r = din_c + 16 * cnt;
int8_t* dout_r = dout_c + 16 * cnt;
for (int i = 0; i < remain; ++i) {
dout_r[i] = saturate_cast<int8_t>(roundf(in_scale * din_r[i]));
}
}
}
void int32_to_int32(const int* din, int* dout, const float* scale, \
int axis_size, long long outer_size, long long inner_size) {
int size_all = outer_size * axis_size * inner_size;
memmove(dout, din, size_all*sizeof(int));
}
template <>
void int32_to_dtype(const int* din, float* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
return int32_to_fp32(din, dout, scale, axis_size, outer_size, inner_size);
}
template <>
void int32_to_dtype(const int* din, signed char* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
return int32_to_int8(din, dout, scale, axis_size, outer_size, inner_size);
}
template <>
void int32_to_dtype(const int* din, int* dout, const float* scale,
int axis_size, long long outer_size, long long inner_size) {
return int32_to_int32(din, dout, scale, axis_size, outer_size, inner_size);
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
...@@ -65,6 +65,8 @@ class Buffer { ...@@ -65,6 +65,8 @@ class Buffer {
TargetCopy(target_, data_, other.data_, nbytes); TargetCopy(target_, data_, other.data_, nbytes);
} }
~Buffer() { Free(); }
private: private:
// memory it actually malloced. // memory it actually malloced.
size_t space_{0}; size_t space_{0};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pattern_matcher.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
namespace mir {
void BuildGraph(SSAGraph* g) {
g->mutable_nodes().emplace_back();
Node& o1 = g->mutable_nodes().back();
o1.AsStmt().op_type = "op1";
g->mutable_nodes().emplace_back();
Node& o2 = g->mutable_nodes().back();
o2.AsStmt().op_type = "op2";
g->mutable_nodes().emplace_back();
Node& o3 = g->mutable_nodes().back();
o3.AsStmt().op_type = "op3";
g->mutable_nodes().emplace_back();
Node& o4 = g->mutable_nodes().back();
o4.AsStmt().op_type = "op4";
g->mutable_nodes().emplace_back();
Node& o5 = g->mutable_nodes().back();
o5.AsStmt().op_type = "op5";
g->mutable_nodes().emplace_back();
Node& v1 = g->mutable_nodes().back();
v1.AsArg("var1");
g->mutable_nodes().emplace_back();
Node& v2 = g->mutable_nodes().back();
v2.AsArg("var2");
g->mutable_nodes().emplace_back();
Node& v3 = g->mutable_nodes().back();
v3.AsArg("var3");
g->mutable_nodes().emplace_back();
Node& v4 = g->mutable_nodes().back();
v4.AsArg("var4");
// o1->v1->o2
o1.outlinks.push_back(&v1);
o2.inlinks.push_back(&v1);
v1.inlinks.push_back(&o1);
v1.outlinks.push_back(&o2);
// o2->v2->o3
// o2->v2->o4
o2.outlinks.push_back(&v2);
o3.inlinks.push_back(&v2);
o4.inlinks.push_back(&v2);
v2.inlinks.push_back(&o2);
v2.outlinks.push_back(&o3);
v2.outlinks.push_back(&o4);
// o2->v3->o5
o2.outlinks.push_back(&v3);
o5.inlinks.push_back(&v3);
v3.inlinks.push_back(&o2);
v3.outlinks.push_back(&o5);
// o3-v4->o5
o3.outlinks.push_back(&v4);
o5.inlinks.push_back(&v4);
v4.inlinks.push_back(&o3);
v4.outlinks.push_back(&o5);
}
TEST(PMPattern, NewNode) {
PMPattern x;
auto* n = x.NewNode([](const Node* x) { return true; });
ASSERT_TRUE(n);
ASSERT_EQ(x.nodes_.size(), 1UL);
}
TEST(PMPattern, AddEdge) {
PMPattern x;
auto* a = x.NewNode([](const Node* x) { return true; });
auto* b = x.NewNode([](const Node* x) { return true; });
ASSERT_TRUE(a);
ASSERT_TRUE(b);
x.AddEdge(a, b);
ASSERT_EQ(x.nodes_.size(), 2UL);
ASSERT_EQ(x.edges_.size(), 1UL);
ASSERT_EQ(x.edges_.front().first, a);
ASSERT_EQ(x.edges_.front().second, b);
ASSERT_EQ(x.nodes().size(), 2UL);
ASSERT_EQ(x.edges().size(), 1UL);
ASSERT_EQ(x.edges().front().first, a);
ASSERT_EQ(x.edges().front().second, b);
}
TEST(PatternMatcher, MarkPMNodesInGraph) {
PatternMatcher x;
// mark o2, o3, v2
// The pattern is a graph:
// o2(a node named o2) -> v2(a node named v2)
// v2 -> o3(a node named o3)
auto* o2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op2";
});
auto* o3 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op3";
});
auto* v2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsArg() && node->arg()->name == "var2";
});
ASSERT_FALSE(o2->Tell(nullptr));
ASSERT_FALSE(o3->Tell(nullptr));
ASSERT_FALSE(v2->Tell(nullptr));
x.pattern_.AddEdge(o2, v2);
x.pattern_.AddEdge(v2, o3);
ASSERT_EQ(x.pattern_.edges().size(), 2UL);
ASSERT_EQ(x.pattern_.edges()[0].first, o2);
ASSERT_EQ(x.pattern_.edges()[0].second, v2);
ASSERT_EQ(x.pattern_.edges()[1].first, v2);
ASSERT_EQ(x.pattern_.edges()[1].second, o3);
SSAGraph graph;
BuildGraph(&graph);
x.MarkPMNodesInGraph(&graph);
ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL);
auto subgraphs = x.DetectPatterns();
ASSERT_EQ(subgraphs.size(), 1UL);
}
TEST(PatternMatcher, MultiSubgraph) {
SSAGraph graph;
BuildGraph(&graph);
PatternMatcher x;
// The pattern is a graph:
// op -> var
auto* any_op = x.mutable_pattern()->NewNode(
[](const Node* node) {
return node->IsStmt() && (node->stmt()->op_type == "op2" ||
node->stmt()->op_type == "op3");
},
"OP0");
auto* any_var =
x.mutable_pattern()
->NewNode([](const Node* node) { return node->IsArg(); }, "VAR")
->AsIntermediate();
auto* any_op1 = x.mutable_pattern()->NewNode(
[](const Node* node) { return node->IsStmt(); }, "OP1");
x.mutable_pattern()->AddEdge(any_op, any_var);
x.mutable_pattern()->AddEdge(any_var, any_op1);
int count = 0;
PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s,
SSAGraph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> "
<< s.at(any_var)->arg()->name << " -> "
<< s.at(any_op1)->stmt()->op_type;
count++;
};
x(&graph, handle);
// 1. Detect op3 -> var4 -> op5
// 2. Detect op2 -> var2 -> op3
// 3. Detect op2 -> var2 -> op4
// 4. Detect op2 -> var3 -> op5
// But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2
ASSERT_GE(count, 1);
ASSERT_LE(count, 2);
}
TEST(PatternMatcher, IntermediateCheck) {
SSAGraph graph;
BuildGraph(&graph);
// o2->v2->o3
// o2->v2->o4
// check o2+o3 fuse, should fail because v2 also link to o4.
PatternMatcher matcher;
auto* op2 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op2";
},
"op2");
auto* op3 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op3";
},
"op3");
auto* v2 = matcher.mutable_pattern()
->NewNode(
[](const Node* x) {
return x && x->IsArg() && x->arg()->name == "var2";
},
"var2")
->AsIntermediate();
v2->LinksFrom({op2}).LinksTo({op3});
int count = 0;
matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) {
++count;
});
EXPECT_EQ(count, 0);
count = 0;
v2->AsInput();
matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) {
++count;
});
ASSERT_EQ(count, 1);
}
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -91,9 +91,9 @@ class KernelRegistry final { ...@@ -91,9 +91,9 @@ class KernelRegistry final {
void Register(const std::string &name, void Register(const std::string &name,
typename KernelRegistryForTarget<Target, Precision, typename KernelRegistryForTarget<Target, Precision,
Layout>::creator_t &&creator) { Layout>::creator_t &&creator) {
// VLOG(3) << "register for " << TargetToStr(Target) << ":" VLOG(3) << "register for " << TargetToStr(Target) << ":"
//<< PrecisionToStr(Precision) << "//" << PrecisionToStr(Precision) << "//"
//<< GetKernelOffset<Target, Precision, Layout>(); << GetKernelOffset<Target, Precision, Layout>();
using kernel_registor_t = using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>; KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()]; auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()];
...@@ -153,6 +153,12 @@ class KernelRegistor : public lite::Registor<KernelType> { ...@@ -153,6 +153,12 @@ class KernelRegistor : public lite::Registor<KernelType> {
public: public:
KernelRegistor(const std::string &op_type, const std::string &alias) KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([=] { : Registor<KernelType>([=] {
<<<<<<< HEAD
=======
VLOG(3) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision)
<< " " << DataLayoutToStr(layout) << " alias " << alias;
>>>>>>> gitlab/develop
KernelRegistry::Global().Register<target, precision, layout>( KernelRegistry::Global().Register<target, precision, layout>(
op_type, [=]() -> std::unique_ptr<KernelType> { op_type, [=]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType); std::unique_ptr<KernelType> x(new KernelType);
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* looks the same. * looks the same.
*/ */
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
......
...@@ -9,12 +9,18 @@ cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) ...@@ -9,12 +9,18 @@ cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3)
cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm) lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm)
lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm)
lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm)
lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm)
lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm)
set(arm_kernels set(arm_kernels
fc_compute_arm fc_compute_arm
...@@ -22,6 +28,11 @@ set(arm_kernels ...@@ -22,6 +28,11 @@ set(arm_kernels
mul_compute_arm mul_compute_arm
scale_compute_arm scale_compute_arm
softmax_compute_arm softmax_compute_arm
elementwise_add_compute_arm) conv_compute_arm
elementwise_add_compute_arm
pool_compute_arm
split_compute_arm
)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ConvCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
auto& ctx = this->ctx_->template As<ARMContext>();
int win = x_dims[3]; // nchw
int hin = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int ow = o_dims[3];
int oh = o_dims[2];
int oc = o_dims[1];
int kh = w_dims[2]; // oihw
int kw = w_dims[3];
int pad = param.paddings[0];
int stride = param.strides[0];
const auto* i_data = param.x->data<float>();
const auto* w_data = param.filter->data<float>();
const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
auto* o_data = param.output->mutable_data<float>();
bool kps_equal = (param.paddings[0] == param.paddings[1]) &&
(param.strides[0] == param.strides[1]) && (kw == kh);
bool no_dilation = (param.dilations[0] == 1) && (param.dilations[1] == 1);
bool flag_dw_3x3 =
(kw == 3 && (pad == 0 || pad == 1) && (stride == 1 || stride == 2));
bool flag_dw_5x5 =
(kw == 5 && stride == 1) || (kw == 5 && stride == 2 && pad == 2);
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
// select conv impl
if (param.groups == ic && ic == oc && kps_equal && no_dilation && flag_dw) {
// dw conv impl
impl_ = new lite::arm::math::DepthwiseConv<PRECISION(kFloat)>;
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
if (ic >= 32 && oc >= 32 && oh > 16 && ow > 16) {
// winograd conv impl
impl_ = new lite::arm::math::WinogradConv<PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
} else {
// direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv";
}
} else if (param.groups == 1 && kw == 3 && stride == 2 && kps_equal &&
no_dilation) {
// direct conv impl
impl_ = new lite::arm::math::DirectConv<PRECISION(kFloat)>;
VLOG(3) << "invoking direct conv";
} else {
impl_ = new lite::arm::math::GemmLikeConv<PRECISION(kFloat)>;
VLOG(3) << "invoking gemm like conv";
}
CHECK(this->impl_->create(param, &ctx));
}
void ConvCompute::Run() {
auto& param = this->Param<param_t>();
CHECK(impl_);
impl_->run(param);
// if (this->act_ != nullptr) {
// this->act_->run(outputs, outputs, param.activation_param);
// }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::ConvCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/operators/conv_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
~ConvCompute() {
if (impl_ != nullptr) {
delete impl_;
}
}
private:
lite::arm::math::ImplBase<TARGET(kARM), PRECISION(kFloat), param_t>* impl_{
nullptr};
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/conv_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename dtype>
void conv_compute_ref(const operators::ConvParam& param) {
auto input = param.x;
auto filter = param.filter;
auto output = param.output;
DDim input_dims = param.x->dims();
DDim filter_dims = param.filter->dims();
DDim output_dims = param.output->dims();
std::vector<int> paddings = param.paddings;
std::vector<int> strides = param.strides;
std::vector<int> dilations = param.dilations;
int groups = param.groups;
auto input_data = param.x->data<float>();
auto output_data = param.output->mutable_data<float>();
auto filter_data = param.filter->mutable_data<float>();
const float* bias_data = nullptr;
if (param.bias != nullptr) {
bias_data = param.bias->mutable_data<float>();
}
bool flag_bias = bias_data != nullptr;
bool flag_relu = false; // TODO(hong19860320) param.relu
int num = input_dims[0];
int chout = output_dims[1];
int hout = output_dims[2];
int wout = output_dims[3];
int chin = input_dims[1];
int hin = input_dims[2];
int win = input_dims[3];
int out_c_group = chout / groups;
int in_c_group = chin / groups;
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
int padding_h = paddings[0];
int padding_w = paddings[1];
int kernel_h = filter_dims[2];
int kernel_w = filter_dims[3];
for (int n = 0; n < num; ++n) {
for (int g = 0; g < groups; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < hout; ++oh) {
for (int ow = 0; ow < wout; ++ow) {
int out_idx = n * groups * out_c_group * hout * wout +
g * out_c_group * hout * wout + oc * hout * wout +
oh * wout + ow;
output_data[out_idx] =
flag_bias ? static_cast<float>(bias_data[g * out_c_group + oc])
: 0.f;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - padding_w + kw * (dilation_w);
int ih = oh * stride_h - padding_h + kh * (dilation_h);
if (iw < 0 || iw >= win) continue;
if (ih < 0 || ih >= hin) continue;
int iidx = n * chin * hin * win + g * in_c_group * hin * win +
ic * hin * win + ih * win + iw;
int widx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
output_data[out_idx] +=
(dtype)input_data[iidx] * (dtype)filter_data[widx];
}
}
}
if (flag_relu) {
output_data[out_idx] =
output_data[out_idx] > 0.f ? output_data[out_idx] : 0.f;
}
}
}
}
}
}
}
TEST(conv_arm, retrive_op) {
auto conv = KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"conv2d");
ASSERT_FALSE(conv.empty());
ASSERT_TRUE(conv.front());
}
TEST(conv_arm, init) {
ConvCompute conv;
ASSERT_EQ(conv.precision(), PRECISION(kFloat));
ASSERT_EQ(conv.target(), TARGET(kARM));
}
TEST(conv_arm, compute) {
DeviceInfo::Init();
for (auto n : {1, 2}) {
for (auto ic : {6, 32 /*, 128*/}) {
for (auto oc : {6, 32 /*, 128*/}) {
for (auto ih : {9, 18 /*, 56 , 112, 224, 512*/}) {
for (auto iw : {9, 18 /*, 56, 112, 224, 512*/}) {
for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1, 2}) {
for (auto ks : {1, 3, 5}) {
int group = 1;
if (depthwise) { // depthwise convolution ?
group = oc = ic;
}
// get input, filter and output shape
std::vector<int64_t> input_shape = {n, ic, ih, iw};
std::vector<int64_t> filter_shape = {oc, ic / group,
ks, ks};
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
std::vector<int64_t> output_shape({n, oc, oh, ow});
// resize input, filter and output
Tensor input;
Tensor filter;
Tensor bias;
Tensor output;
Tensor output_ref;
input.Resize(input_shape);
filter.Resize(filter_shape);
output.Resize(output_shape);
output_ref.Resize(output_shape);
VLOG(3) << "input: " << input.dims();
VLOG(3) << "filter: " << filter.dims()
<< " padding:" << padding
<< " stride:" << stride
<< " dilation:" << dilation;
VLOG(3) << "output: " << output.dims();
auto* input_data = input.mutable_data<float>();
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] =
i * 0.001f /
static_cast<float>(filter.dims().production());
}
// prepare kernel params and run
ConvCompute conv;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
conv.SetContext(std::move(ctx));
operators::ConvParam param;
param.x = &input;
param.filter = &filter;
param.output = &output;
param.bias = nullptr;
if (flag_bias) {
bias.Resize({oc});
auto* bias_data = bias.mutable_data<float>();
for (int i = 0; i < bias.dims().production(); i++) {
bias_data[i] = static_cast<float>(i);
}
param.bias = &bias;
}
// TODO(hong19860320) param.relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
std::vector<int>({dilation, dilation});
param.groups = group;
conv.SetParam(param);
conv.Launch();
// invoking ref implementation and compare results
param.output = &output_ref;
conv_compute_ref<float>(param);
auto* output_ref_data =
output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i],
1e-3);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/pool_compute.h"
#include <string>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void PoolCompute::Run() {
auto& param = Param<operators::PoolParam>();
auto& in_dims = param.x->dims();
auto& out_dims = param.output->dims();
const float* din = param.x->data<float>();
float* dout = param.output->mutable_data<float>();
std::vector<int>& ksize = param.ksize;
std::vector<int>& strides = param.strides;
std::vector<int>& paddings = param.paddings;
std::string& pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
bool ceil_mode = param.ceil_mode;
bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format;
if (param.global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_dims[i + 2]);
}
}
#if 0
for (int i = 0; i < in_dims.size(); ++i) {
LOG(INFO) << "in_dims[" << i << "]:" << in_dims[i];
}
for (int i = 0; i < out_dims.size(); ++i) {
LOG(INFO) << "out_dims[" << i << "]:" << out_dims[i];
}
for (int i = 0; i < ksize.size(); ++i) {
LOG(INFO) << "ksize[" << i << "]:" << ksize[i];
}
for (int i = 0; i < strides.size(); ++i) {
LOG(INFO) << "strides[" << i << "]:" << strides[i];
}
for (int i = 0; i < paddings.size(); ++i) {
LOG(INFO) << "paddings[" << i << "]:" << paddings[i];
}
LOG(INFO) << "global_pooling:" << global_pooling;
LOG(INFO) << "exclusive:" << exclusive;
LOG(INFO) << "adaptive:" << adaptive;
LOG(INFO) << "ceil_mode:" << ceil_mode;
LOG(INFO) << "use_quantizer:" << use_quantizer;
LOG(INFO) << "data_format:" << data_format;
LOG(INFO) << "din:" << din;
LOG(INFO) << "dout:" << dout;
#endif
// global
if (global_pooling == true) {
lite::arm::math::pooling_global(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1]) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 1 &&
strides[0] == strides[1] && paddings[0] == 1) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p1_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == 0) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p0_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == 1) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p1_ave(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
} else {
lite::arm::math::pooling_basic(
din, dout, out_dims[0], out_dims[1], out_dims[2], out_dims[3],
in_dims[1], in_dims[2], in_dims[3], ksize, strides, paddings,
global_pooling, exclusive, adaptive, ceil_mode, use_quantizer,
pooling_type);
}
return;
}
TargetType PoolCompute::target() const { return TARGET(kARM); }
PrecisionType PoolCompute::precision() const { return PRECISION(kFloat); }
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(pool, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::PoolCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/operators/pool_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class PoolCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::PoolParam;
void Run() override;
TargetType target() const override;
PrecisionType precision() const override;
virtual ~PoolCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/pool_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <string>
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void pool_compute_ref(const operators::PoolParam& param) {
auto& in_dims = param.x->dims();
auto& out_dims = param.output->dims();
const float* src_ptr = param.x->data<const float>();
float* dst_ptr = param.output->mutable_data<float>();
std::vector<int> ksize = param.ksize;
std::vector<int> strides = param.strides;
std::vector<int> paddings = param.paddings;
std::string pooling_type = param.pooling_type;
bool global_pooling = param.global_pooling;
bool exclusive = param.exclusive;
bool adaptive = param.adaptive;
bool ceil_mode = param.ceil_mode;
bool use_quantizer = param.use_quantizer;
std::string data_format = param.data_format;
int in_n = in_dims[0];
int in_c = in_dims[1];
int in_h = in_dims[2];
int in_w = in_dims[3];
int size_in_n = in_c * in_h * in_w;
int size_in_c = in_h * in_w;
int out_h = out_dims[2];
int out_w = out_dims[3];
int size_out_n = in_c * out_h * out_w;
int size_out_c = out_h * out_w;
int window_h = ksize[0];
int window_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
if (global_pooling == true) {
ksize[0] = in_h;
ksize[1] = in_w;
}
#if 0
for (int i = 0; i < ksize.size(); ++i) {
LOG(INFO) << "ksize[" << i << "]:" << ksize[i];
}
for (int i = 0; i < strides.size(); ++i) {
LOG(INFO) << "strides[" << i << "]:" << strides[i];
}
for (int i = 0; i < paddings.size(); ++i) {
LOG(INFO) << "paddings[" << i << "]:" << paddings[i];
}
LOG(INFO) << "in nchw:" << in_n << ", " << in_c << ", " << in_h << ", "
<< in_w;
LOG(INFO) << "size_in_n:" << size_in_n;
LOG(INFO) << "size_out_c:" << size_out_c;
LOG(INFO) << "out_h:" << out_h;
LOG(INFO) << "out_w:" << out_w;
LOG(INFO) << "size_out_n:" << size_out_n;
LOG(INFO) << "size_out_c:" << size_out_c;
LOG(INFO) << "window_h:" << window_h;
LOG(INFO) << "window_w:" << window_w;
LOG(INFO) << "stride_h:" << stride_h;
LOG(INFO) << "stride_w:" << stride_w;
LOG(INFO) << "pad_h:" << pad_h;
LOG(INFO) << "pad_w:" << pad_w;
#endif
for (int ind_n = 0; ind_n < in_n; ++ind_n) {
for (int ind_c = 0; ind_c < in_c; ++ind_c) {
for (int ind_h = 0; ind_h < out_h; ++ind_h) {
int sh = ind_h * stride_h;
int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int ind_w = 0; ind_w < out_w; ++ind_w) {
int sw = ind_w * stride_w;
int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
float result = static_cast<float>(0);
int dst_ind =
ind_n * size_out_n + ind_c * size_out_c + ind_h * out_w + ind_w;
for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) {
int src_ind =
ind_n * size_in_n + ind_c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) {
result = src_ptr[src_ind];
} else {
if (pooling_type == "max") {
result =
result >= src_ptr[src_ind] ? result : src_ptr[src_ind];
}
if (pooling_type == "avg" && exclusive == false) {
// Pooling_average_include_padding
result += src_ptr[src_ind];
}
if (pooling_type == "avg" && exclusive == true) {
// Pooling_average_include_padding
result += src_ptr[src_ind];
}
}
}
}
if (pooling_type == "avg" && exclusive == false) {
// Pooling_average_include_padding
// result /= param.window_h * param.window_w;
// LOG(ERROR)<<"cpu"<<param.window_h * param.window_w;
int bh = window_h;
int bw = window_w;
if (ew == in_w) {
bw = sw + window_w >= in_w + pad_w ? in_w + pad_w : sw + window_w;
bw -= sw;
}
if (eh == in_h) {
bh = sh + window_h >= in_h + pad_h ? in_h + pad_h : sh + window_h;
bh -= sh;
}
result /= bh * bw;
}
if (pooling_type == "avg" && exclusive == true) {
// Pooling_average_exclude_padding
result /= (ew - sw) * (eh - sh);
}
dst_ptr[dst_ind] = result;
}
}
}
}
}
TEST(pool_arm, init) {
PoolCompute pool;
ASSERT_EQ(pool.precision(), PRECISION(kFloat));
ASSERT_EQ(pool.target(), TARGET(kARM));
}
TEST(pool_arm, compute) {
PoolCompute pool;
operators::PoolParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
for (auto pooling_type : {"avg", "max"}) {
for (auto global_pooling : {true}) {
for (auto stride : {2}) {
for (auto pad : {0}) {
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11, 4, 1024}) {
for (auto h : {3, 1, 11, 4, 1}) {
for (auto w : {1, 3, 4, 12, 1}) {
VLOG(3) << "n:" << n << " c:" << c << " h:" << h << " w:" << w
<< " stride:" << stride << " pad:" << pad
<< " pooling_type:" << pooling_type
<< " global_pooling:" << global_pooling;
// init x, output
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, 1, 1})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, 1, 1})));
auto* x_data = x.mutable_data<float>();
for (int i = 0; i < x.dims().production(); ++i) {
x_data[i] = i;
}
// fill param
param.x = &x;
param.output = &output;
param.pooling_type = pooling_type;
param.ksize = {h, w};
param.global_pooling = global_pooling;
param.strides = {stride, stride};
param.paddings = {pad, pad};
param.exclusive = true;
param.adaptive = false;
param.ceil_mode = false;
param.use_quantizer = false;
// compute
pool.SetParam(param);
pool.Run();
#if 0
LOG(INFO) << "n:" << n << " c:" << c << " h:" << h << " w:" << w
<< " end";
std::cout << "n:" << n << " c:" << c << " h:" << h << " w:" << w
<< " end" << std::endl;
for (int i = 0; i < param.ksize.size(); ++i) {
std::cout << " ksize[" << i << "]:" << param.ksize[i];
}
std::cout << "\n";
for (int i = 0; i < param.strides.size(); ++i) {
std::cout << " strides[" << i << "]:" << param.strides[i];
}
std::cout << "\n";
for (int i = 0; i < param.paddings.size(); ++i) {
std::cout << " paddings[" << i << "]:" << param.paddings[i];
}
std::cout << "\n";
#endif
// compute ref
// output_ref.Resize(output.dims());
param.output = &output_ref;
pool_compute_ref(param);
VLOG(3) << "pool_compute_ref(param) end";
// compare
auto* output_data = output.mutable_data<float>();
auto* output_ref_data = output_ref.mutable_data<float>();
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i],
1); // 1e-5);
}
VLOG(3) << "compare pass";
}
}
}
}
} // pad
} // stride
} // global_pooling
} // pooling_type
}
TEST(pool, retrive_op) {
auto pool =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("pool");
ASSERT_FALSE(pool.empty());
ASSERT_TRUE(pool.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/split_compute.h"
#include <vector>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SplitCompute::Run() {
auto& param = Param<operators::SplitParam>();
const float* din = param.x->data<float>();
auto* dout = param.output;
auto in_dim = param.x->dims();
std::vector<int> in_strides(in_dim.size());
in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
for (int i = in_dim.size() - 2; i >= 0; --i) {
in_strides[i] = in_strides[i + 1] * in_dim[i];
}
lite::arm::math::split(din, dout, param.axis, in_strides);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(split, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::SplitCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SplitCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~SplitCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/kernels/arm/split_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void splite_resize_out(const lite::Tensor* din,
std::vector<lite::Tensor*>* dout, int axis, int num,
const std::vector<int>& sections) {
for (auto out : *dout) delete out;
dout->clear();
auto in_dims = din->dims();
int outs_number;
if (num > 0) {
outs_number = num;
} else {
outs_number = sections.size();
}
for (int i = 0; i < outs_number; i++) {
dout->push_back(new lite::Tensor);
}
std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
int out_axis_dim = in_dims[axis] / num;
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = out_axis_dim;
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = sections[i];
outs_dims.push_back(dim);
}
}
for (int j = 0; j < outs_dims.size(); ++j) {
(*dout)[j]->Resize(outs_dims[j]);
}
}
template <typename dtype>
void split_compute_ref(const operators::SplitParam& param) {
const dtype* din = param.x->mutable_data<const dtype>();
auto& dout = param.output;
auto in_dim = param.x->dims();
int axis = param.axis;
std::vector<int> in_strides(in_dim.size());
in_strides[in_dim.size() - 1] = in_dim[in_dim.size() - 1];
for (int i = in_dim.size() - 2; i >= 0; --i) {
in_strides[i] = in_strides[i + 1] * in_dim[i];
}
int input_offset = 0;
for (auto out : *dout) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
for (int i = out_dim.size() - 2; i >= 0; --i) {
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
dtype* out_data = out->mutable_data<dtype>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
for (int i = 0; i < before; ++i) {
std::memcpy(out_data + i * out_after, din + input_offset + i * in_after,
sizeof(dtype) * out_after);
}
input_offset += out_strides[axis];
}
}
TEST(split_arm, init) {
SplitCompute split;
ASSERT_EQ(split.precision(), PRECISION(kFloat));
ASSERT_EQ(split.target(), TARGET(kARM));
}
TEST(split_arm, compute) {
SplitCompute split;
operators::SplitParam param;
lite::Tensor x;
std::vector<lite::Tensor*> output;
std::vector<lite::Tensor*> output_ref;
for (auto n : {1, 3, 4}) {
for (auto c : {1, 3, 4}) {
for (auto h : {1, 3, 4}) {
for (auto w : {1, 3, 4}) {
for (auto axis : {0, 1, 2, 3}) {
for (auto num : {0, 1, 2, 3}) {
for (auto sections :
{std::vector<int>{1, 1, 1}, std::vector<int>{2, 2},
std::vector<int>{1, 2}}) {
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
x.Resize(x_dim);
if ((num != 0 && x_dim[axis] % num != 0) ||
(num == 0 && x_dim[axis] % sections.size() != 0))
continue;
auto* x_data = x.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i;
}
splite_resize_out(&x, &output, axis, num, sections);
splite_resize_out(&x, &output_ref, axis, num, sections);
param.x = &x;
param.axis = axis;
param.num = num;
param.sections = &sections;
param.output = &output;
split.SetParam(param);
split.Run();
param.output = &output_ref;
split_compute_ref<float>(param);
for (int i = 0; i < output.size(); i++) {
float* output_data = output[i]->mutable_data<float>();
float* output_ref_data = output_ref[i]->mutable_data<float>();
for (int j = 0; j < output[i]->dims().production(); j++) {
EXPECT_NEAR(output_data[j], output_ref_data[j], 1e-5);
}
}
}
}
}
}
}
}
}
}
TEST(split, retrive_op) {
auto split =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>("split");
ASSERT_FALSE(split.empty());
ASSERT_TRUE(split.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(split, kARM, kFloat, kNCHW, def);
...@@ -19,5 +19,6 @@ USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); ...@@ -19,5 +19,6 @@ USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); USE_LITE_KERNEL(feed, kARM, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def);
set(op_DEPS ${tensor_lite} op_lite op_params_lite) set(op_DEPS ${tensor_lite} op_lite op_params_lite)
cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS})
cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS})
cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS}) cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS})
cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS}) cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS})
cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS})
...@@ -17,10 +19,11 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) ...@@ -17,10 +19,11 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS})
set(ops_lite set(ops_lite
conv_op_lite
pool_op_lite
fc_op_lite fc_op_lite
relu_op_lite relu_op_lite
mul_op_lite mul_op_lite
...@@ -36,14 +39,16 @@ set(ops_lite ...@@ -36,14 +39,16 @@ set(ops_lite
activation_ops_lite activation_ops_lite
dropout_op_lite dropout_op_lite
concat_op_lite concat_op_lite
conv_op_lite split_op_lite
pool_op_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite DEPS fc_op_lite memory_lite
X86_DEPS fc_compute_x86 X86_DEPS fc_compute_x86
ARM_DEPS fc_compute_arm) ARM_DEPS fc_compute_arm)
lite_cc_test(test_pool_op_lite SRCS pool_op_test.cc
DEPS pool_op_lite memory_lite
ARM_DEPS pool_compute_arm)
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite) lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
......
...@@ -24,31 +24,49 @@ bool ConvOpLite::CheckShape() const { ...@@ -24,31 +24,49 @@ bool ConvOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.filter); CHECK_OR_FALSE(param_.filter);
return true; // bias is optional.
}
bool ConvOpLite::InferShape() const { const auto in_dims = param_.x->dims();
auto in_dims = param_.x->dims(); const auto filter_dims = param_.filter->dims();
auto filter_dims = param_.filter->dims();
std::vector<int> strides = param_.strides;
std::vector<int> paddings = param_.paddings;
int groups = param_.groups;
std::vector<int> dilations = param_.dilations;
CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5); CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5);
CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size());
CHECK_OR_FALSE(in_dims.size() - strides.size() == 2U); CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U);
CHECK_EQ_OR_FALSE(paddings.size(), strides.size()); CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size());
CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * groups);
CHECK_EQ_OR_FALSE(filter_dims[0] % groups, 0); CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups);
CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0);
CHECK_EQ_OR_FALSE(filter_dims.size(), 4UL);
return true;
}
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
CHECK_GT_OR_FALSE(output_size, 0);
return output_size;
}
bool ConvOpLite::InferShape() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < param_.strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(
dilations[i], paddings[i], ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], param_.dilations[i],
strides[i])); param_.paddings[i], param_.strides[i]));
} }
// Set output dims
param_.output->Resize(lite::DDim(output_shape)); param_.output->Resize(lite::DDim(output_shape));
// share LoD
// param_.output->set_lod(param_.x->lod());
return true; return true;
} }
......
...@@ -26,63 +26,53 @@ namespace paddle { ...@@ -26,63 +26,53 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
CHECK_OR_FALSE(output_size > 0);
return output_size;
}
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
class ConvOpLite : public OpLite { class ConvOpLite : public OpLite {
public: public:
ConvOpLite() {} ConvOpLite() {}
explicit ConvOpLite(const std::string& type) : OpLite(type) {} explicit ConvOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShape() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto X = op_desc.Input("Input").front(); auto input = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front(); auto filter = op_desc.Input("Filter").front();
auto Bias = op_desc.Input("Bias").front(); auto out = op_desc.Output("Out").front();
// auto ResidualData = op_desc.Input("ResidualData"); param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
auto Out = op_desc.Output("Output").front(); param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out));
param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.bias = scope->FindVar(Bias)->GetMutable<lite::Tensor>();
// param_.residualData =
// scope->FindVar(ResidualData)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings"); param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations"); param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>()));
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto residual_data_var =
scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor *>(
&(residual_data_var->Get<lite::Tensor>()));
}
}
return true; return true;
} }
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conv2d"; } std::string DebugString() const override { return "conv2d"; }
private: private:
......
...@@ -124,8 +124,8 @@ struct ConcatParam { ...@@ -124,8 +124,8 @@ struct ConcatParam {
struct ConvParam { struct ConvParam {
lite::Tensor* x{}; lite::Tensor* x{};
lite::Tensor* filter{}; lite::Tensor* filter{};
lite::Tensor* bias{}; lite::Tensor* bias{nullptr};
lite::Tensor* residualData{}; lite::Tensor* residualData{nullptr};
lite::Tensor* output{}; lite::Tensor* output{};
std::vector<int> strides{1, 1}; std::vector<int> strides{1, 1};
std::vector<int> paddings{0, 0}; std::vector<int> paddings{0, 0};
...@@ -174,6 +174,15 @@ struct DropoutParam { ...@@ -174,6 +174,15 @@ struct DropoutParam {
std::string dropout_implementation{"downgrade_in_infer"}; std::string dropout_implementation{"downgrade_in_infer"};
}; };
// For Split op
struct SplitParam {
lite::Tensor* x{};
std::vector<lite::Tensor*>* output{};
int axis{-1};
int num{0};
std::vector<int>* sections;
};
/// ----------------------- element wise operators ---------------------- /// ----------------------- element wise operators ----------------------
struct ElementwiseParam { struct ElementwiseParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
......
...@@ -19,6 +19,27 @@ namespace paddle { ...@@ -19,6 +19,27 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
bool PoolOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
const auto& x_dims = param_.x->dims();
const auto& ksize = param_.ksize;
const auto& strides = param_.strides;
const auto& paddings = param_.paddings;
// "Pooling intput should be 4-D or 5-D tensor."
CHECK_OR_FALSE(x_dims.size() == 4 || x_dims.size() == 5);
// Input size and pooling size should be consistent.
CHECK_OR_FALSE(x_dims.size() - ksize.size() == 2U);
// Strides size and pooling size should be the same.
CHECK_OR_FALSE(ksize.size() == strides.size());
// Paddings size and pooling size should be the same.
CHECK_OR_FALSE(ksize.size() == paddings.size());
return true;
}
int PoolOutputSize(int input_size, int filter_size, int padding, int stride, int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
bool ceil_mode) { bool ceil_mode) {
int output_size; int output_size;
...@@ -28,46 +49,35 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride, ...@@ -28,46 +49,35 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
output_size = output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1; (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
} }
CHECK_OR_FALSE(output_size > 0);
return output_size; return output_size;
} }
bool PoolOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
bool PoolOpLite::InferShape() const { bool PoolOpLite::InferShape() const {
const auto input_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
CHECK_OR_FALSE(input_dims.size() == 4 || input_dims.size() == 5); std::vector<int>& ksize = param_.ksize;
if (param_.global_pooling) { if (param_.global_pooling) {
param_.ksize.resize(static_cast<size_t>(input_dims.size()) - 2); ksize.resize(static_cast<size_t>(x_dims.size()) - 2);
for (size_t i = 0; i < param_.ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
param_.paddings[i] = 0; param_.paddings[i] = 0;
param_.ksize[i] = static_cast<int>(input_dims[i + 2]); ksize[i] = static_cast<int>(x_dims[i + 2]);
} }
} }
CHECK_OR_FALSE(input_dims.size() - param_.ksize.size() == 2U); std::vector<int64_t> output_shape({x_dims[0], x_dims[1]});
CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.strides.size());
CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.paddings.size());
std::vector<int64_t> output_shape({input_dims[0], input_dims[1]});
if (param_.adaptive) { if (param_.adaptive) {
output_shape.insert(output_shape.end(), param_.ksize.begin(), output_shape.insert(output_shape.end(), param_.ksize.begin(),
param_.ksize.end()); param_.ksize.end());
} else { } else {
for (size_t i = 0; i < param_.ksize.size(); ++i) { for (size_t i = 0; i < param_.ksize.size(); ++i) {
output_shape.push_back( output_shape.push_back(
PoolOutputSize(input_dims[i + 2], param_.ksize[i], param_.paddings[i], PoolOutputSize(x_dims[i + 2], param_.ksize[i], param_.paddings[i],
param_.strides[i], param_.ceil_mode)); param_.strides[i], param_.ceil_mode));
} }
} }
// share LoD
// param_.output->set_lod(param_.input->lod());
param_.output->Resize(lite::DDim(output_shape)); param_.output->Resize(lite::DDim(output_shape));
// ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
// ctx->ShareLoD("X", "Out");
return true; return true;
} }
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h" #include "paddle/fluid/lite/core/scope.h"
...@@ -35,24 +37,32 @@ class PoolOpLite : public OpLite { ...@@ -35,24 +37,32 @@ class PoolOpLite : public OpLite {
bool InferShape() const override; bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<Tensor>(); CHECK(scope->FindVar(x));
param_.output = scope->FindVar(out)->GetMutable<Tensor>(); CHECK(scope->FindVar(out));
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.pooling_type = op_desc.GetAttr<std::string>("pooling_type"); param_.pooling_type = op_desc.GetAttr<std::string>("pooling_type");
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize"); param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings"); param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.ceil_mode = op_desc.GetAttr<bool>("ceil_mode");
param_.exclusive = op_desc.GetAttr<bool>("exclusive");
param_.adaptive = op_desc.GetAttr<bool>("adaptive"); param_.adaptive = op_desc.GetAttr<bool>("adaptive");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling"); param_.ceil_mode = op_desc.GetAttr<bool>("ceil_mode");
param_.use_quantizer = op_desc.GetAttr<bool>("use_quantizer");
// param_.data_format = op_desc.GetAttr<bool>("data_format");
return true; return true;
} }
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pool"; } std::string DebugString() const override { return "pool"; }
private: private:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/pool_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(pool_op_lite, test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize(DDim(std::vector<int64_t>({1, 3, 224, 224})));
output->Resize(DDim(std::vector<int64_t>{1, 3, 112, 112}));
// set data
for (int i = 0; i < 1 * 3 * 224 * 224; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 1 * 3 * 112 * 112; i++) {
output->mutable_data<float>()[i] = 0.;
}
// prepare op desc
cpp::OpDesc desc;
desc.SetType("pool");
desc.SetInput("X", {"x"});
desc.SetOutput("Out", {"output"});
std::string pooling_type("max");
desc.SetAttr("pooling_type", pooling_type);
// desc.SetAttr("ksize", static_cast<std::vector<int>>({2, 2}));
std::vector<int> ksize{2, 2};
desc.SetAttr("ksize", ksize);
bool global_pooling{false};
desc.SetAttr("global_pooling", global_pooling);
std::vector<int> strides{1, 1};
desc.SetAttr("strides", strides);
std::vector<int> paddings{0, 0};
desc.SetAttr("paddings", paddings);
bool exclusive{true};
desc.SetAttr("exclusive", exclusive);
bool adaptive{false};
desc.SetAttr("adaptive", adaptive);
bool ceil_mode{false};
desc.SetAttr("ceil_mode", ceil_mode);
bool use_quantizer{false};
desc.SetAttr("use_quantizer", use_quantizer);
PoolOpLite pool("pool");
pool.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}});
pool.Attach(desc, &scope);
auto kernels = pool.CreateKernels({Place{TARGET(kARM), PRECISION(kFloat)}});
LOG(INFO) << "kernels.size(): " << kernels.size();
ASSERT_FALSE(kernels.empty());
}
} // namespace operators
} // namespace lite
} // namespace paddle
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL(pool, kARM, kFloat, kNCHW, def);
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/split_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SplitOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
CHECK_OR_FALSE(param_.axis >= -static_cast<int>(x_rank) &&
param_.axis < static_cast<int>(x_rank));
return true;
}
bool SplitOp::InferShape() const {
const auto &outs = param_.output;
auto in_dims = param_.x.dims();
int axis = param_.axis;
int num = param_.num;
const auto &sections = param_.sections;
const int outs_number = outs.size();
std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
int out_axis_dim = in_dims[axis] / num;
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = out_axis_dim;
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = sections[i];
outs_dims.push_back(dim);
}
}
for (int j = 0; j < outs_dims.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
return true;
}
bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.axis = opdesc.GetAttr<int>("axis");
param_.num = opdesc.GetAttr<int>("num");
param_.sections = opdesc.GetAttr<std::vector<int>>("sections");
param_.x = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
auto outs = op_desc.Output("Out");
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SoftmaxOp : public OpLite {
public:
SplitOp() {}
explicit SplitOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "split"; }
private:
mutable SplitParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -34,7 +34,6 @@ class Any { ...@@ -34,7 +34,6 @@ class Any {
CHECK(type_ == typeid(T).hash_code()); CHECK(type_ == typeid(T).hash_code());
} else { } else {
type_ = typeid(T).hash_code(); type_ = typeid(T).hash_code();
data_ = new T;
deleter_ = [&] { delete static_cast<T*>(data_); }; deleter_ = [&] { delete static_cast<T*>(data_); };
} }
data_ = new T; data_ = new T;
...@@ -55,10 +54,16 @@ class Any { ...@@ -55,10 +54,16 @@ class Any {
bool valid() const { return data_; } bool valid() const { return data_; }
// ~Any() {
// if (valid()) {
// deleter_();
// }
// }
private: private:
static size_t kInvalidType; static size_t kInvalidType;
size_t type_{kInvalidType}; size_t type_{kInvalidType};
void* data_{}; void* data_{nullptr};
std::function<void()> deleter_; std::function<void()> deleter_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册