未验证 提交 31150e99 编写于 作者: H HappyAngel 提交者: GitHub

[arm]add 2x2s2p1 pooling (#3705)

* fix pooling bug and speed

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop
上级 49a60805
此差异已折叠。
...@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din, ...@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_max(const float* din, void pooling2x2s2p0_max(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_avg(const float* din, void pooling2x2s2p0_avg(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive, bool exclusive,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_max(const float* din, void pooling3x3s1p1_max(const float* din,
float* dout, float* dout,
......
...@@ -50,13 +50,14 @@ class PoolingPE : public PE { ...@@ -50,13 +50,14 @@ class PoolingPE : public PE {
PoolingArgs args = {0}; PoolingArgs args = {0};
args.mode = param_.type; args.mode = param_.type;
auto paddings = *param_.paddings;
args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height));
args.image.address = input->data<float16>(); args.image.address = input->data<float16>();
args.image.channels = input->shape().channel(); args.image.channels = input->shape().channel();
args.image.height = input->shape().height(); args.image.height = input->shape().height();
args.image.width = input->shape().width(); args.image.width = input->shape().width();
args.image.pad_height = param_.paddings[0]; args.image.pad_height = paddings[0];
args.image.pad_width = param_.paddings[1]; args.image.pad_width = paddings[2];
args.image.scale_address = input->scale(); args.image.scale_address = input->scale();
args.output.address = output->mutableData<float16>(); args.output.address = output->mutableData<float16>();
args.output.scale_address = output->scale(); args.output.scale_address = output->scale();
...@@ -69,8 +70,7 @@ class PoolingPE : public PE { ...@@ -69,8 +70,7 @@ class PoolingPE : public PE {
param_.poolingArgs = args; param_.poolingArgs = args;
// use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 // use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1
// && // && (k_width > 7 || k_height > 7);
// (k_width > 7 || k_height > 7);
use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 &&
(k_width > 255 || k_height > 255); (k_width > 255 || k_height > 255);
// use_cpu_ = param_.type == AVERAGE; // use_cpu_ = param_.type == AVERAGE;
...@@ -86,12 +86,13 @@ class PoolingPE : public PE { ...@@ -86,12 +86,13 @@ class PoolingPE : public PE {
float* image_addr = float_input.mutableData<float>(FP32, input->shape()); float* image_addr = float_input.mutableData<float>(FP32, input->shape());
float_input.copyFrom(input); float_input.copyFrom(input);
float16* data_out = output->data<float16>(); float16* data_out = output->data<float16>();
auto paddings = *param_.paddings;
int image_height = input->shape().height(); int image_height = input->shape().height();
int image_width = input->shape().width(); int image_width = input->shape().width();
int image_channels = input->shape().channel(); int image_channels = input->shape().channel();
int image_pad_h = param_.paddings[0]; int image_pad_h = paddings[0];
int image_pad_w = param_.paddings[1]; int image_pad_w = paddings[2];
int kernel_height = param_.kernelSize[1]; int kernel_height = param_.kernelSize[1];
int kernel_width = param_.kernelSize[0]; int kernel_width = param_.kernelSize[0];
int kernel_step_h = param_.strides[0]; int kernel_step_h = param_.strides[0];
......
...@@ -58,6 +58,7 @@ void PoolCompute::Run() { ...@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0; paddings[2 * i] = 0;
...@@ -107,35 +108,65 @@ void PoolCompute::Run() { ...@@ -107,35 +108,65 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2p0_max(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
out_dims[2], out_dims[2],
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
paddings[1], paddings[1],
paddings[3]); paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, lite::arm::math::pooling2x2s2p0_avg(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
out_dims[2], out_dims[2],
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive, exclusive,
paddings[1], paddings[1],
paddings[3]); paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, lite::arm::math::pooling3x3s1p1_max(din,
dout, dout,
...@@ -165,7 +196,7 @@ void PoolCompute::Run() { ...@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din, lite::arm::math::pooling3x3s1p0_max(din,
dout, dout,
...@@ -195,7 +226,7 @@ void PoolCompute::Run() { ...@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, lite::arm::math::pooling3x3s2p0_max(din,
dout, dout,
...@@ -225,7 +256,7 @@ void PoolCompute::Run() { ...@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, lite::arm::math::pooling3x3s2p1_max(din,
dout, dout,
......
...@@ -678,15 +678,9 @@ void resize(const uint8_t* src, ...@@ -678,15 +678,9 @@ void resize(const uint8_t* src,
} else if (srcFormat == NV12 || srcFormat == NV21) { } else if (srcFormat == NV12 || srcFormat == NV21) {
nv21_resize(src, dst, srcw, srch, dstw, dsth); nv21_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) { } else if (srcFormat == BGR || srcFormat == RGB) {
bgr_resize(src, dst, srcw, srch, dstw, dsth); bgr_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) { } else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4; w_in = srcw * 4;
w_out = dstw * 4; w_out = dstw * 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册