提交 8bd4b2be 编写于 作者: L liuqi

Change the pooling functor init logic.

上级 47eece0b
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
// //
#include "mace/kernels/pooling.h" #include "mace/kernels/pooling.h"
#include <arm_neon.h>
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
...@@ -61,9 +59,15 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()( ...@@ -61,9 +59,15 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
const index_t *input_shape, const index_t *input_shape,
float *output, float *output,
const index_t *output_shape) { const index_t *output_shape) {
int paddings[2];
std::vector<index_t> filter_shape = {input_shape[1], input_shape[0],
kernels_[0], kernels_[1]};
kernels::CalPaddingSize(input_shape, filter_shape.data(), this->dilations_,
strides_, this->padding_, paddings);
#ifdef __COPY_MAKE_PADDING #ifdef __COPY_MAKE_PADDING
Tensor padded_input; Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input); ConstructInputWithPadding(input, input_shape, paddings, &padded_input);
input = padded_input.data<float>(); input = padded_input.data<float>();
input_shape = padded_input.shape().data(); input_shape = padded_input.shape().data();
#endif #endif
...@@ -76,14 +80,14 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()( ...@@ -76,14 +80,14 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape); PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
#else #else
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape,
paddings_); paddings);
#endif #endif
} else { // AVG_POOL_2x2s2x2 } else { // AVG_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING #ifdef __COPY_MAKE_PADDING
PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape); PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
#else #else
PoolingAvgNeonK2x2S2x2(input, input_shape, output, output_shape, PoolingAvgNeonK2x2S2x2(input, input_shape, output, output_shape,
paddings_); paddings);
#endif #endif
} }
} else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 && } else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 &&
...@@ -94,19 +98,19 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()( ...@@ -94,19 +98,19 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape); PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
#else #else
PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape, PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape,
paddings_); paddings);
#endif #endif
} else { // AVG_POOL_3x3s2x2 } else { // AVG_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING #ifdef __COPY_MAKE_PADDING
PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape); PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
#else #else
PoolingAvgNeonK3x3S2x2(input, input_shape, output, output_shape, PoolingAvgNeonK3x3S2x2(input, input_shape, output, output_shape,
paddings_); paddings);
#endif #endif
} }
} else { // not implement yet } else { // not implement yet
PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_, PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_,
paddings_, dilations_)( padding_, dilations_)(
input, input_shape, output, output_shape); input, input_shape, output, output_shape);
} }
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <limits> #include <limits>
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace { namespace mace {
...@@ -22,12 +23,12 @@ struct PoolingFunctor { ...@@ -22,12 +23,12 @@ struct PoolingFunctor {
PoolingFunctor(const PoolingType pooling_type, PoolingFunctor(const PoolingType pooling_type,
const int *kernels, const int *kernels,
const int *strides, const int *strides,
const int *paddings, const Padding padding,
const int *dilations) const int *dilations)
: pooling_type_(pooling_type), : pooling_type_(pooling_type),
kernels_(kernels), kernels_(kernels),
strides_(strides), strides_(strides),
paddings_(paddings), padding_(padding),
dilations_(dilations) {} dilations_(dilations) {}
void operator()(const T *input, void operator()(const T *input,
...@@ -54,9 +55,14 @@ struct PoolingFunctor { ...@@ -54,9 +55,14 @@ struct PoolingFunctor {
int dilation_h = dilations_[0]; int dilation_h = dilations_[0];
int dilation_w = dilations_[1]; int dilation_w = dilations_[1];
int paddings[2];
std::vector<index_t> filter_shape = {input_shape[1], input_shape[0],
kernels_[0], kernels_[1]};
kernels::CalPaddingSize(input_shape, filter_shape.data(), this->dilations_,
strides_, this->padding_, paddings);
// The upper-left-most offset of the padded input // The upper-left-most offset of the padded input
int padded_h_start = 0 - paddings_[0] / 2; int padded_h_start = 0 - paddings[0] / 2;
int padded_w_start = 0 - paddings_[1] / 2; int padded_w_start = 0 - paddings[1] / 2;
if (pooling_type_ == MAX) { if (pooling_type_ == MAX) {
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
...@@ -116,7 +122,7 @@ struct PoolingFunctor { ...@@ -116,7 +122,7 @@ struct PoolingFunctor {
const PoolingType pooling_type_; const PoolingType pooling_type_;
const int *kernels_; const int *kernels_;
const int *strides_; const int *strides_;
const int *paddings_; const Padding padding_;
const int *dilations_; const int *dilations_;
}; };
...@@ -127,6 +133,13 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()( ...@@ -127,6 +133,13 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
float *output, float *output,
const index_t *output_shape); const index_t *output_shape);
// Declares the explicit specialization of PoolingFunctor::operator() for the
// DeviceType::OPENCL / float instantiation. Only the declaration is given
// here (no body), so the definition has to come from another source file.
//
// Parameters mirror the other specializations declared in this header:
//   input        - read-only pointer to the float input data.
//   input_shape  - read-only pointer to the input dimensions
//                  (NOTE(review): layout/rank is not visible here; confirm
//                  against the caller).
//   output       - pointer to the float buffer the result is written to.
//   output_shape - read-only pointer to the output dimensions.
template <>
void PoolingFunctor<DeviceType::OPENCL, float>::operator()(
const float *input,
const index_t *input_shape,
float *output,
const index_t *output_shape);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
...@@ -62,6 +62,22 @@ cc_test( ...@@ -62,6 +62,22 @@ cc_test(
], ],
) )
# Test target built from pooling_test.cc.
# Compiled with -std=c++11 and linked with -fopenmp; statically linked
# (linkstatic = 1) against the local :ops and :test targets plus gtest_main.
cc_test(
name = "pooling_test",
testonly = 1,
srcs = glob(
["pooling_test.cc"],
),
copts = ["-std=c++11"],
# NOTE(review): if_android presumably contributes "-ldl" only for Android
# builds; confirm against its definition.
linkopts = ["-fopenmp"] + if_android(["-ldl"]),
linkstatic = 1,
deps = [
":ops",
":test",
"@gtest//:gtest_main",
],
)
cc_test( cc_test(
name = "ops_benchmark", name = "ops_benchmark",
testonly = 1, testonly = 1,
......
...@@ -19,7 +19,9 @@ class PoolingOp : public ConvPool2dOpBase<D, T> { ...@@ -19,7 +19,9 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")), kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")),
pooling_type_( pooling_type_(
static_cast<PoolingType>(OperatorBase::GetSingleArgument<int>( static_cast<PoolingType>(OperatorBase::GetSingleArgument<int>(
"pooling_type", static_cast<int>(AVG)))){}; "pooling_type", static_cast<int>(AVG)))),
functor_(pooling_type_, kernels_.data(), ConvPool2dOpBase::strides_.data(),
ConvPool2dOpBase::padding_, ConvPool2dOpBase::dilations_.data()){};
bool Run() override { bool Run() override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
...@@ -40,10 +42,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> { ...@@ -40,10 +42,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
paddings.data()); paddings.data());
output->Resize(output_shape); output->Resize(output_shape);
auto pooling_func = kernels::PoolingFunctor<D, T>( functor_(input->data<float>(), input->shape().data(),
pooling_type_, kernels_.data(), this->strides_.data(), paddings.data(),
this->dilations_.data());
pooling_func(input->data<float>(), input->shape().data(),
output->mutable_data<float>(), output->shape().data()); output->mutable_data<float>(), output->shape().data());
return true; return true;
}; };
...@@ -51,6 +50,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> { ...@@ -51,6 +50,7 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
protected: protected:
std::vector<int> kernels_; std::vector<int> kernels_;
PoolingType pooling_type_; PoolingType pooling_type_;
kernels::PoolingFunctor<D, T> functor_;
OP_INPUT_TAGS(INPUT); OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT); OP_OUTPUT_TAGS(OUTPUT);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册