未验证 提交 7eb4a391 编写于 作者: H HappyAngel 提交者: GitHub

improve pooling speed in gaze model. (#3881)

* improve pooling speed in gaze. test=develoop

* fix format test=develop
上级 6d787479
此差异已折叠。
...@@ -76,7 +76,7 @@ void pooling1x1s2p0_max(const float* din, ...@@ -76,7 +76,7 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_max(const float* din, void pooling2x2s2p0_max(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
...@@ -88,7 +88,32 @@ void pooling2x2s2_max(const float* din, ...@@ -88,7 +88,32 @@ void pooling2x2s2_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_avg(const float* din, void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
......
...@@ -58,6 +58,7 @@ void PoolCompute::Run() { ...@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0; paddings[2 * i] = 0;
...@@ -107,7 +108,7 @@ void PoolCompute::Run() { ...@@ -107,7 +108,7 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2p0_max(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
...@@ -120,7 +121,7 @@ void PoolCompute::Run() { ...@@ -120,7 +121,7 @@ void PoolCompute::Run() {
paddings[3]); paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, lite::arm::math::pooling2x2s2p0_avg(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
...@@ -134,8 +135,38 @@ void PoolCompute::Run() { ...@@ -134,8 +135,38 @@ void PoolCompute::Run() {
paddings[3]); paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, lite::arm::math::pooling3x3s1p1_max(din,
dout, dout,
...@@ -165,7 +196,7 @@ void PoolCompute::Run() { ...@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din, lite::arm::math::pooling3x3s1p0_max(din,
dout, dout,
...@@ -195,7 +226,7 @@ void PoolCompute::Run() { ...@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, lite::arm::math::pooling3x3s2p0_max(din,
dout, dout,
...@@ -225,7 +256,7 @@ void PoolCompute::Run() { ...@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, lite::arm::math::pooling3x3s2p1_max(din,
dout, dout,
...@@ -276,7 +307,6 @@ void PoolCompute::Run() { ...@@ -276,7 +307,6 @@ void PoolCompute::Run() {
use_quantizer, use_quantizer,
pooling_type); pooling_type);
} }
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册