未验证 提交 31150e99 编写于 作者: H HappyAngel 提交者: GitHub

[arm]add 2x2s2p1 pooling (#3705)

* fix pooling bug and improve speed

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop
上级 49a60805
此差异已折叠。
...@@ -76,7 +76,7 @@ void pooling1x1s2p0_max(const float* din, ...@@ -76,7 +76,7 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_max(const float* din, void pooling2x2s2p0_max(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
...@@ -88,7 +88,32 @@ void pooling2x2s2_max(const float* din, ...@@ -88,7 +88,32 @@ void pooling2x2s2_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_avg(const float* din, void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
// Max-pooling entry point for the 2x2-window, stride-2, padding-1 case.
// The ARM pool dispatch calls it when ksize[0] == 2, strides[0] == 2,
// paddings[0] == 1 and the pooling type is "max".
//
// Parameters (roles taken from how the dispatch passes its arguments):
//   din                    - read-only float input buffer.
//   dout                   - float output buffer.
//   num, chout, hout, wout - out_dims[0], out_dims[1], out_dims[2], out_dims[3].
//   chin, hin, win         - in_dims[1], in_dims[2], in_dims[3].
//   pad_bottom, pad_right  - paddings[1] and paddings[3] at the call site.
//
// NOTE(review): the dims look like N/C/H/W order given the parameter names;
// the kernel body is not visible here, so confirm the memory layout there.
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
......
...@@ -50,13 +50,14 @@ class PoolingPE : public PE { ...@@ -50,13 +50,14 @@ class PoolingPE : public PE {
PoolingArgs args = {0}; PoolingArgs args = {0};
args.mode = param_.type; args.mode = param_.type;
auto paddings = *param_.paddings;
args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height));
args.image.address = input->data<float16>(); args.image.address = input->data<float16>();
args.image.channels = input->shape().channel(); args.image.channels = input->shape().channel();
args.image.height = input->shape().height(); args.image.height = input->shape().height();
args.image.width = input->shape().width(); args.image.width = input->shape().width();
args.image.pad_height = param_.paddings[0]; args.image.pad_height = paddings[0];
args.image.pad_width = param_.paddings[1]; args.image.pad_width = paddings[2];
args.image.scale_address = input->scale(); args.image.scale_address = input->scale();
args.output.address = output->mutableData<float16>(); args.output.address = output->mutableData<float16>();
args.output.scale_address = output->scale(); args.output.scale_address = output->scale();
...@@ -69,8 +70,7 @@ class PoolingPE : public PE { ...@@ -69,8 +70,7 @@ class PoolingPE : public PE {
param_.poolingArgs = args; param_.poolingArgs = args;
// use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 // use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1
// && // && (k_width > 7 || k_height > 7);
// (k_width > 7 || k_height > 7);
use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 &&
(k_width > 255 || k_height > 255); (k_width > 255 || k_height > 255);
// use_cpu_ = param_.type == AVERAGE; // use_cpu_ = param_.type == AVERAGE;
...@@ -86,12 +86,13 @@ class PoolingPE : public PE { ...@@ -86,12 +86,13 @@ class PoolingPE : public PE {
float* image_addr = float_input.mutableData<float>(FP32, input->shape()); float* image_addr = float_input.mutableData<float>(FP32, input->shape());
float_input.copyFrom(input); float_input.copyFrom(input);
float16* data_out = output->data<float16>(); float16* data_out = output->data<float16>();
auto paddings = *param_.paddings;
int image_height = input->shape().height(); int image_height = input->shape().height();
int image_width = input->shape().width(); int image_width = input->shape().width();
int image_channels = input->shape().channel(); int image_channels = input->shape().channel();
int image_pad_h = param_.paddings[0]; int image_pad_h = paddings[0];
int image_pad_w = param_.paddings[1]; int image_pad_w = paddings[2];
int kernel_height = param_.kernelSize[1]; int kernel_height = param_.kernelSize[1];
int kernel_width = param_.kernelSize[0]; int kernel_width = param_.kernelSize[0];
int kernel_step_h = param_.strides[0]; int kernel_step_h = param_.strides[0];
......
...@@ -58,6 +58,7 @@ void PoolCompute::Run() { ...@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0; paddings[2 * i] = 0;
...@@ -107,7 +108,7 @@ void PoolCompute::Run() { ...@@ -107,7 +108,7 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2p0_max(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
...@@ -120,7 +121,7 @@ void PoolCompute::Run() { ...@@ -120,7 +121,7 @@ void PoolCompute::Run() {
paddings[3]); paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, lite::arm::math::pooling2x2s2p0_avg(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
...@@ -134,8 +135,38 @@ void PoolCompute::Run() { ...@@ -134,8 +135,38 @@ void PoolCompute::Run() {
paddings[3]); paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, lite::arm::math::pooling3x3s1p1_max(din,
dout, dout,
...@@ -165,7 +196,7 @@ void PoolCompute::Run() { ...@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din, lite::arm::math::pooling3x3s1p0_max(din,
dout, dout,
...@@ -195,7 +226,7 @@ void PoolCompute::Run() { ...@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, lite::arm::math::pooling3x3s2p0_max(din,
dout, dout,
...@@ -225,7 +256,7 @@ void PoolCompute::Run() { ...@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, lite::arm::math::pooling3x3s2p1_max(din,
dout, dout,
......
...@@ -678,15 +678,9 @@ void resize(const uint8_t* src, ...@@ -678,15 +678,9 @@ void resize(const uint8_t* src,
} else if (srcFormat == NV12 || srcFormat == NV21) { } else if (srcFormat == NV12 || srcFormat == NV21) {
nv21_resize(src, dst, srcw, srch, dstw, dsth); nv21_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) { } else if (srcFormat == BGR || srcFormat == RGB) {
bgr_resize(src, dst, srcw, srch, dstw, dsth); bgr_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) { } else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4; w_in = srcw * 4;
w_out = dstw * 4; w_out = dstw * 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册