未验证 提交 31150e99 编写于 作者: H HappyAngel 提交者: GitHub

[arm]add 2x2s2p1 pooling (#3705)

* fix pooling bug and speed

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop
上级 49a60805
此差异已折叠。
......@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom,
int pad_right);
void pooling2x2s2_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling2x2s2p0_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_max(const float* din,
float* dout,
......
......@@ -50,13 +50,14 @@ class PoolingPE : public PE {
PoolingArgs args = {0};
args.mode = param_.type;
auto paddings = *param_.paddings;
args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height));
args.image.address = input->data<float16>();
args.image.channels = input->shape().channel();
args.image.height = input->shape().height();
args.image.width = input->shape().width();
args.image.pad_height = param_.paddings[0];
args.image.pad_width = param_.paddings[1];
args.image.pad_height = paddings[0];
args.image.pad_width = paddings[2];
args.image.scale_address = input->scale();
args.output.address = output->mutableData<float16>();
args.output.scale_address = output->scale();
......@@ -69,8 +70,7 @@ class PoolingPE : public PE {
param_.poolingArgs = args;
// use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1
// &&
// (k_width > 7 || k_height > 7);
// && (k_width > 7 || k_height > 7);
use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 &&
(k_width > 255 || k_height > 255);
// use_cpu_ = param_.type == AVERAGE;
......@@ -86,12 +86,13 @@ class PoolingPE : public PE {
float* image_addr = float_input.mutableData<float>(FP32, input->shape());
float_input.copyFrom(input);
float16* data_out = output->data<float16>();
auto paddings = *param_.paddings;
int image_height = input->shape().height();
int image_width = input->shape().width();
int image_channels = input->shape().channel();
int image_pad_h = param_.paddings[0];
int image_pad_w = param_.paddings[1];
int image_pad_h = paddings[0];
int image_pad_w = paddings[2];
int kernel_height = param_.kernelSize[1];
int kernel_width = param_.kernelSize[0];
int kernel_step_h = param_.strides[0];
......
......@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling;
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0;
......@@ -107,35 +108,65 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
lite::arm::math::pooling2x2s2p0_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
lite::arm::math::pooling2x2s2p0_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din,
dout,
......@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) {
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din,
dout,
......@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din,
dout,
......@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) {
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din,
dout,
......
......@@ -678,15 +678,9 @@ void resize(const uint8_t* src,
} else if (srcFormat == NV12 || srcFormat == NV21) {
nv21_resize(src, dst, srcw, srch, dstw, dsth);
return;
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) {
bgr_resize(src, dst, srcw, srch, dstw, dsth);
return;
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4;
w_out = dstw * 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册