pool_arm_func.h 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef POOL_OP
#pragma once

#include <string>
#include <vector>
#include "operators/math/pooling.h"

namespace paddle_mobile {
namespace operators {
using framework::Tensor;

Z
ZhenWang 已提交
26 27 28 29
template <typename T, typename S>
void PoolBasic(std::string pooling_type, std::vector<int> ksize,
               std::vector<int> strides, std::vector<int> paddings,
               const Tensor *in_x, Tensor *out) {
30
  if (pooling_type == "max") {
Z
ZhenWang 已提交
31 32
    math::PoolFunctor<CPU, math::MaxPool<T>, T> pool2d_forward;
    math::MaxPool<T> pool_process;
33 34 35
    pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);

  } else if (pooling_type == "avg") {
Z
ZhenWang 已提交
36 37
    math::PoolFunctor<CPU, math::AvgPool<T, S>, T> pool2d_forward;
    math::AvgPool<T, S> pool_process;
38 39 40
    pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
  }
}
Z
ZhenWang 已提交
41

42
template <typename P>
N
nhzlx 已提交
43
void PoolCompute(const PoolParam<CPU> &param) {
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
  const Tensor *in_x = param.Input();
  Tensor *out = param.Output();
  std::string pooling_type = param.PoolingType();

  std::vector<int> ksize = param.Ksize();

  std::vector<int> strides = param.Strides();

  std::vector<int> paddings = param.Paddings();
  if (ksize.size() != 2) {
    LOG(paddle_mobile::LogLevel::kLOG_ERROR)
        << "Pool op only supports 2D and 3D input.";
  }
  if (param.isGlobalPooling()) {
    for (size_t i = 0; i < ksize.size(); ++i) {
      paddings[i] = 0;
      ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
    }
N
nhzlx 已提交
62
  }
Z
ZhenWang 已提交
63 64 65 66 67 68
  if (in_x->type() == typeid(int8_t)) {
    if (pooling_type == "max" && ksize[0] == 3 && ksize[0] == ksize[1]) {
      if (strides[0] == strides[1] && strides[0] == 1) {
        math::Pool3x3Maxs1_int8(in_x, out, paddings[0], paddings[1]);
      } else if (strides[0] == strides[1] && strides[0] == 2) {
        math::Pool3x3Maxs2_int8(in_x, out, paddings[0], paddings[1]);
69
      } else {
Z
ZhenWang 已提交
70
        math::Pool3x3Max_int8(strides, paddings, in_x, out);
71
      }
Z
ZhenWang 已提交
72 73 74
    } else {
      PoolBasic<int8_t, int32_t>(pooling_type, ksize, strides, paddings, in_x,
                                 out);
75
    }
Z
ZhenWang 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
  } else {
    if (ksize[0] == 3 && ksize[0] == ksize[1]) {
      if (pooling_type == "max") {
        if (strides[0] == strides[1] && strides[0] == 1 &&
            paddings[0] == paddings[1] && paddings[1] == 1) {
          math::Pool3x3Maxs1p1(in_x, out);
        } else {
          math::Pool3x3Max(strides, paddings, in_x, out);
        }
      } else if (pooling_type == "avg") {
        if (strides[0] == strides[1] && strides[0] == 1 &&
            paddings[0] == paddings[1] && paddings[1] == 1) {
          math::Pool3x3Avgs1p1(in_x, out);
        } else {
          math::Pool3x3Avg(strides, paddings, in_x, out);
        }
      }
93

Z
ZhenWang 已提交
94 95 96
    } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
               strides[0] == strides[1] && paddings[0] == paddings[1] &&
               paddings[1] == 0) {
97
#if __ARM_NEON
L
liuruilong 已提交
98
#if __aarch64__
Z
ZhenWang 已提交
99 100
      PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
                              out);
L
liuruilong 已提交
101
#else
Z
ZhenWang 已提交
102 103 104 105 106 107
      /// todo: fix bug in Pool2x2
      if (pooling_type == "max") {
        math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
      } else if (pooling_type == "avg") {
        math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
      }
L
liuruilong 已提交
108
#endif
109
#else
Z
ZhenWang 已提交
110 111
      PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
                              out);
112 113
#endif  // __ARM_NEON

Z
ZhenWang 已提交
114 115 116 117
    } else {
      PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
                              out);
    }
118 119 120 121 122 123
  }
}

}  // namespace operators
}  // namespace paddle_mobile
#endif