提交 1949fb0c 编写于 作者: H hjchen2

Add pooling2x2 neon implementation

上级 2bbf01d1
...@@ -15,7 +15,8 @@ limitations under the License. */ ...@@ -15,7 +15,8 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#include "operators/kernel/pool_kernel.h" #include "operators/kernel/pool_kernel.h"
#include "../central-arm-func/pool_arm_func.h" #include "operators/kernel/central-arm-func/pool_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -28,7 +29,8 @@ template <> ...@@ -28,7 +29,8 @@ template <>
void PoolKernel<CPU, float>::Compute(const PoolParam<CPU> &param) { void PoolKernel<CPU, float>::Compute(const PoolParam<CPU> &param) {
PoolCompute<float>(param); PoolCompute<float>(param);
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif #endif // POOL_OP
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#pragma once #pragma once
#include <string> #include <string>
...@@ -54,8 +55,24 @@ void PoolCompute(const PoolParam<CPU> &param) { ...@@ -54,8 +55,24 @@ void PoolCompute(const PoolParam<CPU> &param) {
} else { } else {
math::Pooling<AVG>()(*input, ksize, strides, paddings, output); math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
} }
} else { }
// Others } else if (ksize[0] == 2 && ksize[0] == ksize[1]) {
if (pooling_type == "max" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling2x2<MAX, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling2x2<MAX, 2>()(*input, paddings, output);
} else {
math::Pooling<MAX>()(*input, ksize, strides, paddings, output);
}
} else if (pooling_type == "avg" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling2x2<AVG, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling2x2<AVG, 2>()(*input, paddings, output);
} else {
math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
}
} }
} else { } else {
if (pooling_type == "max") { if (pooling_type == "max") {
......
...@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#if defined(__ARM_NEON__) && !defined(__aarch64__) #if defined(__ARM_NEON__) && !defined(__aarch64__)
#include "operators/math/depthwise_conv5x5.h" #include "operators/math/depthwise_conv5x5.h"
...@@ -81,7 +79,6 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, ...@@ -81,7 +79,6 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start; int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w; float *output_ptr = output + h_output * output_w;
// border left // border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
...@@ -111,6 +108,8 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, ...@@ -111,6 +108,8 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
_sum = vdupq_n_f32(0.f); _sum = vdupq_n_f32(0.f);
int remain_start = valid_w_start + (output_tiles << 2); int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w; int input_w_offset = remain_start * Stride_w - padding_w;
float *output_ptr0 = output_ptr + remain_start;
for (int h_in = h_start; h_in < h_end; ++h_in) { for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start; int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>( Depth5x5NormalRowLoadInput<Stride_w>(
...@@ -123,14 +122,14 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, ...@@ -123,14 +122,14 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
} }
switch (remain) { switch (remain) {
case 1: case 1:
vst1_lane_f32(output_ptr + remain_start, vget_low_f32(_sum), 0); vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
break; break;
case 2: case 2:
vst1_f32(output_ptr + remain_start, vget_low_f32(_sum)); vst1_f32(output_ptr0, vget_low_f32(_sum));
break; break;
case 3: case 3:
vst1_f32(output_ptr + remain_start, vget_low_f32(_sum)); vst1_f32(output_ptr0, vget_low_f32(_sum));
vst1_lane_f32(output_ptr + remain_start + 2, vget_high_f32(_sum), 0); vst1_lane_f32(output_ptr0 + 2, vget_high_f32(_sum), 0);
break; break;
} }
} }
......
此差异已折叠。
...@@ -169,28 +169,55 @@ int main(int argc, char *argv[]) { ...@@ -169,28 +169,55 @@ int main(int argc, char *argv[]) {
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2"; << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// // kernel = 5, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO)
// LOG(paddle_mobile::kLOG_INFO) << "float, pooling_type=max, kernel=2, pad=0, stride=1";
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
// stride=1"; LOG(paddle_mobile::kLOG_INFO)
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height, << "float, pooling_type=max, kernel=2, pad=1, stride=1";
// in_width); paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
// // kernel = 5, pad = 0, stride = 2 LOG(paddle_mobile::kLOG_INFO)
// LOG(paddle_mobile::kLOG_INFO) << "float, pooling_type=max, kernel=2, pad=2, stride=1";
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
// stride=1"; LOG(paddle_mobile::kLOG_INFO)
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 2>(in_channels, in_height, << "float, pooling_type=max, kernel=2, pad=5, stride=1";
// in_width); paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
// stride=1"; paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height, LOG(paddle_mobile::kLOG_INFO)
// in_width); << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
// // kernel = 7, pad = 0, stride = 4 paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
// stride=4"; paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height, LOG(paddle_mobile::kLOG_INFO)
// in_width); << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=0, stride=2";
paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=1, stride=2";
paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=2, stride=2";
paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=2, pad=5, stride=2";
paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=0, stride=2";
paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=1, stride=2";
paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=2, stride=2";
paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=2, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册