Commit afa836d9 authored by H hjchen2

Refactor pooling implementation

Parent 889c8ebc
......@@ -102,6 +102,11 @@ enum ActivationType {
Sigmoid = 6,
};
enum PoolingType {
Max = 0,
Avg = 1,
};
extern const char *G_OP_TYPE_CONV;
extern const char *G_OP_TYPE_BATCHNORM;
extern const char *G_OP_TYPE_BOX_CODER;
......
......@@ -17,103 +17,53 @@ limitations under the License. */
#include <string>
#include <vector>
#include "common/types.h"
#include "operators/math/pooling.h"
namespace paddle_mobile {
namespace operators {
using framework::Tensor;
template <typename T, typename S>
void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
if (pooling_type == "max") {
math::PoolFunctor<CPU, math::MaxPool<T>, T> pool2d_forward;
math::MaxPool<T> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
} else if (pooling_type == "avg") {
math::PoolFunctor<CPU, math::AvgPool<T, S>, T> pool2d_forward;
math::AvgPool<T, S> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
}
}
template <typename P>
void PoolCompute(const PoolParam<CPU> &param) {
const Tensor *in_x = param.Input();
Tensor *out = param.Output();
std::string pooling_type = param.PoolingType();
const framework::Tensor *input = param.Input();
framework::Tensor *output = param.Output();
const std::string &pooling_type = param.PoolingType();
std::vector<int> ksize = param.Ksize();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
if (ksize.size() != 2) {
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
<< "Pool op only supports 2D and 3D input.";
}
if (param.isGlobalPooling()) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
ksize[i] = static_cast<int>(input->dims()[i + 2]);
}
}
if (in_x->type() == typeid(int8_t)) {
if (pooling_type == "max" && ksize[0] == 3 && ksize[0] == ksize[1]) {
if (strides[0] == strides[1] && strides[0] == 1) {
math::Pool3x3Maxs1_int8(in_x, out, paddings[0], paddings[1]);
} else if (strides[0] == strides[1] && strides[0] == 2) {
math::Pool3x3Maxs2_int8(in_x, out, paddings[0], paddings[1]);
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling3x3<Max, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling3x3<Max, 2>()(*input, paddings, output);
} else {
math::Pool3x3Max_int8(strides, paddings, in_x, out);
math::Pooling<Max>()(*input, ksize, strides, paddings, output);
}
} else if (pooling_type == "avg" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling3x3<Avg, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling3x3<Avg, 2>()(*input, paddings, output);
} else {
math::Pooling<Avg>()(*input, ksize, strides, paddings, output);
}
} else {
PoolBasic<int8_t, int32_t>(pooling_type, ksize, strides, paddings, in_x,
out);
// Others
}
} else {
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
} else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
}
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) {
#if __ARM_NEON
#if __aarch64__
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
#else
/// todo: fix bug in Pool2x2
if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avgs2p0(strides, paddings, in_x, out);
}
#endif
#else
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
#endif // __ARM_NEON
if (pooling_type == "max") {
math::Pooling<Max>()(*input, ksize, strides, paddings, output);
} else if (pooling_type == "avg") {
math::Pooling<Avg>()(*input, ksize, strides, paddings, output);
} else {
PoolBasic<float, float>(pooling_type, ksize, strides, paddings, in_x,
out);
// Others
}
}
}
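// For reference, a minimal, self-contained sketch of the dispatch rule used
// above: a specialized 3x3 kernel is chosen when the kernel is 3x3 and the
// stride is a uniform 1 or 2, otherwise the generic Pooling functor handles
// the case. The helper name SelectPoolingKernel is illustrative only and is
// not part of this operator.
#include <string>
#include <vector>
inline std::string SelectPoolingKernel(const std::string &pooling_type,
                                       const std::vector<int> &ksize,
                                       const std::vector<int> &strides) {
  const bool is_3x3 = (ksize[0] == 3 && ksize[0] == ksize[1]);
  const bool uniform_stride = (strides[0] == strides[1]);
  if (is_3x3 && uniform_stride && (strides[0] == 1 || strides[0] == 2)) {
    return "Pooling3x3<" + pooling_type + ", " + std::to_string(strides[0]) +
           ">";
  }
  return "Pooling<" + pooling_type + ">";  // generic fallback path
}
// e.g. SelectPoolingKernel("max", {3, 3}, {2, 2}) returns "Pooling3x3<max, 2>".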
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#include "operators/math/pool_2x2.h"
#include <algorithm>
#include <vector>
namespace paddle_mobile {
namespace operators {
namespace math {
#define FLT_MAX __FLT_MAX__
void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
const Tensor *input, Tensor *output) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int ksize_height = 2;
const int ksize_width = 2;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
int w1 = input_width / 16;
int _w1 = input_width % 16;
int w2 = _w1 / 4;
int _w2 = _w1 % 4;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < input_height; ph += 2) {
const float *in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width;
const float *in_ptr2 = in_ptr1 + input_width;
if (ph != input_height && ph + 1 >= input_height) {
in_ptr2 = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * input_width));
memset(static_cast<void *>(const_cast<float *>(in_ptr2)), -FLT_MAX,
sizeof(float) * input_width);
}
float *out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width;
#if __ARM_NEON
#if __aarch64__
#else
asm volatile(
"subs %[w1], %[w1], #1 \n\t"
"blt end_w1_%= \n\t"
"loop_w1_%=: \n\t"
"pld [%[in_ptr1], #64] \n\t"
"pld [%[in_ptr2], #64] \n\t"
"vld1.f32 {q0, q1}, [%[in_ptr1]]! \n\t"
"vld1.f32 {q2, q3}, [%[in_ptr2]]! \n\t"
"vld1.f32 {q6, q7}, [%[in_ptr1]]! \n\t"
"vld1.f32 {q8, q9}, [%[in_ptr2]]! \n\t"
"vmax.f32 q0, q0, q2 \n\t"
"vmax.f32 q1, q1, q3 \n\t"
"vmax.f32 q6, q6, q8 \n\t"
"vmax.f32 q7, q7, q9 \n\t"
"vpmax.f32 d8, d0, d1 \n\t"
"vpmax.f32 d9, d2, d3 \n\t"
"vpmax.f32 d10, d12, d13 \n\t"
"vpmax.f32 d11, d14, d15 \n\t"
"vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
"subs %[w1], %[w1], #1 \n\t"
"bge loop_w1_%= \n\t"
"end_w1_%=: \n\t"
"subs %[w2], %[w2], #1 \n\t"
"blt end_w2_%= \n\t"
"loop_w2_%=: \n\t"
"vld1.f32 {q0}, [%[in_ptr1]]! \n\t"
"vld1.f32 {q1}, [%[in_ptr2]]! \n\t"
"vmax.f32 q0, q0, q1 \n\t"
"vpmax.f32 d4, d0, d1 \n\t"
"vst1.32 {d4}, [%[out_ptr]]! \n\t"
"subs %[w2], %[w2], #1 \n\t"
"bge loop_w2_%= \n\t"
"end_w2_%=: \n\t"
:
: [w1] "r"(w1), [w2] "r"(w2), [in_ptr1] "r"(in_ptr1),
[in_ptr2] "r"(in_ptr2), [out_ptr] "r"(out_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9");
#endif
#endif
if (_w2 != 0) {
in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width + 16 * w1 +
4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) {
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
} else if (_w2 == 2) {
float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
*out_ptr = (temp > temp1) ? temp : temp1;
} else if (_w2 == 3) {
float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
*out_ptr = (temp > temp1) ? temp : temp1;
out_ptr++;
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
}
}
}
}
}
}
void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
const Tensor *input, Tensor *output) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
const int input_width = input->dims()[3];
const int output_channels = output->dims()[1];
int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int ksize_height = 2;
const int ksize_width = 2;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
int w1 = input_width / 16;
int _w1 = input_width % 16;
int w2 = _w1 / 4;
int _w2 = _w1 % 4;
float quarter = 0.25;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < input_height; ph += 2) {
const float *in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width;
const float *in_ptr2 = in_ptr1 + input_width;
if (ph + 1 >= input_height) {
in_ptr2 = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * input_width));
memset(static_cast<void *>(const_cast<float *>(in_ptr2)), 0,
sizeof(float) * input_width);
}
float *out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width;
#if __ARM_NEON
#if __aarch64__
#else
asm volatile(
"subs %[w1], %[w1], #1 \n\t"
"blt end_w1_%= \n\t"
"loop_w1_%=: \n\t"
"pld [%[in_ptr1], #64] \n\t"
"pld [%[in_ptr2], #64] \n\t"
"vmov.f32 d0[0], %[quarter] \n\t"
"vld1.f32 {q1, q2}, [%[in_ptr1]]! \n\t"
"vld1.f32 {q3, q4}, [%[in_ptr2]]! \n\t"
"vld1.f32 {q7, q8}, [%[in_ptr1]]! \n\t"
"vld1.f32 {q9, q10}, [%[in_ptr2]]! \n\t"
"vadd.f32 q1, q1, q3 \n\t"
"vadd.f32 q2, q2, q4 \n\t"
"vadd.f32 q7, q7, q9 \n\t"
"vadd.f32 q8, q8, q10 \n\t"
"vpadd.f32 d10, d2, d3 \n\t"
"vpadd.f32 d11, d4, d5 \n\t"
"vpadd.f32 d12, d14, d15 \n\t"
"vpadd.f32 d13, d16, d17 \n\t"
"vmul.f32 q5, q5, d0[0] \n\t"
"vmul.f32 q6, q6, d0[0] \n\t"
"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
"subs %[w1], %[w1], #1 \n\t"
"bge loop_w1_%= \n\t"
"end_w1_%=: \n\t"
"subs %[w2], %[w2], #1 \n\t"
"blt end_w2_%= \n\t"
"loop_w2_%=: \n\t"
"vld1.f32 {q1}, [%[in_ptr1]]! \n\t"
"vld1.f32 {q2}, [%[in_ptr2]]! \n\t"
"vadd.f32 q1, q1, q2 \n\t"
"vpadd.f32 d4, d2, d3 \n\t"
"vmul.f32 d4, d4, d0[0] \n\t"
"vst1.32 {d4}, [%[out_ptr]]! \n\t"
"subs %[w2], %[w2], #1 \n\t"
"bge loop_w2_%= \n\t"
"end_w2_%=: \n\t"
:
: [w1] "r"(w1), [w2] "r"(w2), [in_ptr1] "r"(in_ptr1),
[in_ptr2] "r"(in_ptr2), [out_ptr] "r"(out_ptr),
[quarter] "r"(quarter)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10");
#endif
#endif
if (_w2 != 0) {
in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width + 16 * w1 +
4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) {
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
} else if (_w2 == 2) {
float temp = 0;
temp += *in_ptr1;
temp += *in_ptr2;
in_ptr1++;
in_ptr2++;
temp += *in_ptr1;
temp += *in_ptr2;
*out_ptr = 0.25 * temp;
} else if (_w2 == 3) {
float temp = 0;
temp += *in_ptr1++;
temp += *in_ptr2++;
temp += *in_ptr1++;
temp += *in_ptr2++;
*out_ptr = 0.25 * temp;
out_ptr++;
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
}
}
}
}
}
}
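// A scalar reference of what the two NEON kernels above compute: a 2x2 window,
// stride 2, no padding, on NCHW float data flattened per channel. This sketch
// truncates odd heights and widths instead of padding a synthetic row the way
// the assembly does; the helper name Pool2x2S2Reference is illustrative only.
#include <algorithm>
inline void Pool2x2S2Reference(const float *in, float *out, int channels,
                               int in_h, int in_w, bool max_pool) {
  const int out_h = in_h / 2, out_w = in_w / 2;
  for (int c = 0; c < channels; ++c) {
    const float *src = in + c * in_h * in_w;
    float *dst = out + c * out_h * out_w;
    for (int oh = 0; oh < out_h; ++oh) {
      for (int ow = 0; ow < out_w; ++ow) {
        // The four elements of the 2x2 window.
        const float a = src[(2 * oh) * in_w + 2 * ow];
        const float b = src[(2 * oh) * in_w + 2 * ow + 1];
        const float c0 = src[(2 * oh + 1) * in_w + 2 * ow];
        const float d = src[(2 * oh + 1) * in_w + 2 * ow + 1];
        dst[oh * out_w + ow] =
            max_pool ? std::max(std::max(a, b), std::max(c0, d))
                     : 0.25f * (a + b + c0 + d);
      }
    }
  }
}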
//}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#pragma once
#include "framework/tensor.h"
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::vector;
void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
const Tensor *input, Tensor *output);
void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
const Tensor *in_x, Tensor *out);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
This diff is collapsed.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#pragma once
#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON
namespace paddle_mobile {
namespace operators {
namespace math {
void Pool3x3Avgs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Maxs1p1(const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Max(std::vector<int> strides, std::vector<int> paddings,
const framework::Tensor *input, framework::Tensor *output);
void Pool3x3Avg(std::vector<int> strides, std::vector<int> paddings,
const framework::Tensor *in_x, framework::Tensor *out);
void Pool3x3Maxs1_int8(const framework::Tensor *input,
framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Maxs2_int8(const framework::Tensor *input,
framework::Tensor *output, int32_t pad_h, int32_t pad_w);
void Pool3x3Max_int8(const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *input, framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
This diff is collapsed.
......@@ -15,87 +15,68 @@ limitations under the License. */
#ifdef POOL_OP
#include "operators/math/pooling.h"
#include <algorithm>
#include <vector>
#include "common/types.h"
#ifdef _OPENMP
#include <omp.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename PoolProcess, typename T>
class PoolFunctor<CPU, PoolProcess, T> {
public:
void operator()(const framework::Tensor &input, const std::vector<int> &ksize,
const std::vector<int> &strides,
const std::vector<int> &paddings, PoolProcess pool_process,
framework::Tensor *output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T *input_data = input.data<T>();
T *output_data = output->mutable_data<T>();
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
#pragma omp parallel for
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
auto ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.compute(input_data[h * input_width + w], &ele);
}
template <PoolingType P>
void Pooling<P>::operator()(const framework::Tensor &input,
const std::vector<int> &kernel_size,
const std::vector<int> &strides,
const std::vector<int> &paddings,
framework::Tensor *output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int ksize_height = kernel_size[0];
const int ksize_width = kernel_size[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const float *input_data = input.data<float>();
float *output_data = output->mutable_data<float>();
const size_t input_spatial_size = input_height * input_width;
const size_t output_spatial_size = output_height * output_width;
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
int channel = i * output_channels + c;
const float *input_ptr = input_data + channel * input_spatial_size;
float *output_ptr = output_data + channel * output_spatial_size;
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
PoolingVal<P> val;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
val += input_ptr[h * input_width + w];
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(static_cast<float>(pool_size), &ele);
output_data[ph * output_width + pw] = static_cast<T>(ele);
}
output_data[ph * output_width + pw] = val.Value();
}
input_data += input_stride;
output_data += output_stride;
}
}
}
};
}
template struct Pooling<Max>;
template struct Pooling<Avg>;
template class PoolFunctor<CPU, math::AvgPool<float, float>, float>;
template class PoolFunctor<CPU, math::MaxPool<float>, float>;
template class PoolFunctor<CPU, math::AvgPool<int8_t, int32_t>, int8_t>;
template class PoolFunctor<CPU, math::MaxPool<int8_t>, int8_t>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // POOL_OP
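// A standalone worked example of the clamped-window behaviour above: with a
// 3x3 kernel, stride 1, pad 1 on a 3x3 input, the corner window is clipped to
// four in-bounds elements and the average divides by that count, matching
// PoolingVal<Avg>::Value(). This snippet is illustrative only and not part of
// the library.
#include <algorithm>
#include <cstdio>
int main() {
  const int H = 3, W = 3, K = 3, S = 1, P = 1;
  const float in[H * W] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  float out[H * W] = {0};
  for (int ph = 0; ph < H; ++ph) {
    int hstart = std::max(ph * S - P, 0);
    int hend = std::min(ph * S - P + K, H);
    for (int pw = 0; pw < W; ++pw) {
      int wstart = std::max(pw * S - P, 0);
      int wend = std::min(pw * S - P + K, W);
      float sum = 0.f;
      int count = 0;
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          sum += in[h * W + w];
          ++count;
        }
      }
      out[ph * W + pw] = sum / count;  // divide by in-bounds count, not K * K
    }
  }
  std::printf("%.2f\n", out[0]);  // prints 3.00 = (1 + 2 + 4 + 5) / 4
  return 0;
}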
......@@ -16,75 +16,143 @@ limitations under the License. */
#pragma once
#include <climits>
#include <algorithm>
#include <cmath>
#include "common/log.h"
#include <limits>
#include <vector>
#include "common/types.h"
#include "framework/tensor.h"
#include "pool_2x2.h"
#include "pool_3x3.h"
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
#define FLT_MAX __FLT_MAX__
/*
* \brief Extracting simple operations from pooling.
* Both MaxPool and AvgPool need "initial", "compute" and "finalize"
* operation.
* MaxPool initializes temp variable to the negative maximum to find the
* maximum value in the pooling field.
* AvgPool initializes temp variable to the zero to accumulate all values
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template <typename T>
class MaxPool {
public:
inline T initial() {
if (typeid(T) == typeid(int8_t)) {
return static_cast<T>(-SCHAR_MAX);
template <PoolingType P = Max>
struct PoolingVal {
float val;
int count;
PoolingVal() {
    val = -std::numeric_limits<float>::max();  // most negative float; min() is the smallest positive value
count = 0;
}
inline PoolingVal<P> &operator+=(const float &x) {
val = std::max(val, x);
count += 1;
return *this;
}
float Value() const {
if (count > 0) {
return val;
}
return static_cast<T>(-FLT_MAX);
return 0.f;
}
inline void compute(const T &x, T *y) { *y = *y > x ? *y : x; }
inline void finalize(const T &pool_field, T *y) {}
};
template <typename Itype, typename Otype>
class AvgPool {
public:
inline Otype initial() { return static_cast<Otype>(0); }
inline void compute(const Itype &x, Otype *y) { *y += x; }
inline void finalize(const float &pool_field, Otype *y) {
if (typeid(Itype) == typeid(int8_t)) {
float tmp = *y / pool_field;
if (tmp > SCHAR_MAX) {
*y = SCHAR_MAX;
} else if (tmp < -SCHAR_MAX) {
*y = -SCHAR_MAX;
} else {
*y = static_cast<Otype>(std::round(tmp));
}
} else {
*y /= pool_field;
template <>
struct PoolingVal<Avg> {
float val;
int count;
PoolingVal() {
val = 0.f;
count = 0;
}
inline PoolingVal<Avg> &operator+=(const float &x) {
val += x;
count += 1;
return *this;
}
float Value() const {
if (count > 0) {
return val / count;
}
return 0.f;
}
};
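// A minimal usage sketch of the accumulators above: operator+= folds one
// in-window element into the running value and Value() finalizes it (the max
// of the seen elements, or their sum divided by count). The wrapper name
// PoolingValExample is illustrative only and not part of this header.
inline float PoolingValExample() {
  PoolingVal<Max> m;
  m += 3.f;
  m += -1.f;
  m += 7.f;  // m.Value() == 7.f
  PoolingVal<Avg> a;
  a += 2.f;
  a += 4.f;  // a.Value() == 3.f
  // An empty accumulator (count == 0) returns 0.f in both specializations.
  return m.Value() + a.Value();  // 10.f
}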
template <typename DeviceType, typename PoolProcess, typename T>
class PoolFunctor {
public:
void operator()(const framework::Tensor &input, const std::vector<int> &ksize,
const std::vector<int> &strides,
const std::vector<int> &paddings, PoolProcess pool_compute,
framework::Tensor *output);
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
template <PoolingType P = Max>
inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) {
return vmaxq_f32(x1, x2);
}
template <>
inline float32x4_t vPoolPreq_f32<Avg>(const float32x4_t &x1,
const float32x4_t &x2) {
return vaddq_f32(x1, x2);
}
template <PoolingType P = Max>
inline float32x4_t vPoolPostq_f32(const float32x4_t &x) {
return x;
}
template <>
inline float32x4_t vPoolPostq_f32<Avg>(const float32x4_t &x) {
float32x4_t avg = vdupq_n_f32(1.f / 9);
return vmulq_f32(avg, x);
}
#endif // __ARM_NEON__
template <PoolingType P = Max>
inline float PoolPre(const float &x1, const float &x2) {
return std::max(x1, x2);
}
template <>
inline float PoolPre<Avg>(const float &x1, const float &x2) {
return x1 + x2;
}
template <PoolingType P = Max>
inline float PoolPost(const float &x) {
return x;
}
template <>
inline float PoolPost<Avg>(const float &x) {
return 1.f / 9 * x;
}
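// Note that the <Avg> post-processing above multiplies by a fixed 1/9, i.e. it
// assumes a full, unclipped 3x3 window, so these helpers are intended for the
// Pooling3x3 specializations. A scalar sketch of reducing one 3x3 window with
// PoolPre/PoolPost follows; the name Reduce3x3Window is illustrative only.
template <PoolingType P>
inline float Reduce3x3Window(const float window[9]) {
  float acc = window[0];
  for (int k = 1; k < 9; ++k) {
    acc = PoolPre<P>(acc, window[k]);  // running max for Max, running sum for Avg
  }
  return PoolPost<P>(acc);  // identity for Max, multiply by 1/9 for Avg
}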
template <PoolingType P>
struct Pooling {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &kernel_size,
const std::vector<int> &strides,
const std::vector<int> &paddings,
framework::Tensor *output);
};
template <PoolingType P, int Stride>
struct Pooling2x2 {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &paddings,
framework::Tensor *output);
};
template <PoolingType P, int Stride>
struct Pooling3x3 {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &paddings,
framework::Tensor *output);
};
template <PoolingType P, int Stride>
struct Pooling5x5 {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &paddings,
framework::Tensor *output);
};
template <PoolingType P, int Stride>
struct Pooling7x7 {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &paddings,
framework::Tensor *output);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
This diff is collapsed.
......@@ -14,10 +14,13 @@ limitations under the License. */
#include <iostream>
#include "../test_include.h"
#include "operators/kernel/central-arm-func/pool_arm_func.h"
#include "operators/math/pooling.h"
#include "operators/pool_op.h"
namespace paddle_mobile {
namespace math = operators::math;
static int PoolOutputSize(int input_size, int filter_size, int padding,
int stride, bool ceil_mode) {
int output_size;
......@@ -30,70 +33,6 @@ static int PoolOutputSize(int input_size, int filter_size, int padding,
return output_size;
}
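// The body of PoolOutputSize is collapsed in this diff. For reference, a
// sketch of the conventional formula such a helper computes (this is an
// assumption about the elided code, not a copy of it):
static int PoolOutputSizeReference(int input_size, int filter_size, int padding,
                                   int stride, bool ceil_mode) {
  if (ceil_mode) {
    return (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
  }
  return (input_size - filter_size + 2 * padding) / stride + 1;
}
// e.g. PoolOutputSizeReference(7, 3, 0, 3, false) == 2, and 3 with ceil_mode.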
template <typename T>
static void PoolAvgPad0(std::vector<int> ksize, std::vector<int> strides,
const framework::Tensor *input,
framework::Tensor *out) {
const int32_t batch_size = input->dims()[0];
const int32_t input_c = input->dims()[1];
const int32_t input_h = input->dims()[2];
const int32_t input_w = input->dims()[3];
const int32_t out_c = out->dims()[1];
const int32_t out_h = out->dims()[2];
const int32_t out_w = out->dims()[3];
const int32_t kernel_h = ksize[0];
const int32_t kernel_w = ksize[1];
const int32_t stride_h = strides[0];
const int32_t stride_w = strides[1];
const int32_t inputdata_channel_stride = input_h * input_w;
const int32_t input_batch_stride = input_c * inputdata_channel_stride;
const int32_t outputdata_channel_stride = out_h * out_w;
const int32_t output_batch_stride = out_c * outputdata_channel_stride;
T *out_data = out->mutable_data<T>();
const T *input_data = input->data<T>();
const T **rows = new const T *[kernel_h];
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < out_c; ++j) {
const T *img_in = input_data + j * inputdata_channel_stride;
T *img_out = out_data + j * outputdata_channel_stride;
for (int k = 0; k < out_h; ++k) {
for (int m = 0; m < kernel_h; ++m) {
rows[m] = img_in + (stride_h * k + m) * input_w;
}
int32_t left = out_w;
while (left > 0) {
float tmp = 0;
for (int m = 0; m < kernel_h; ++m) {
for (int l = 0; l < kernel_w; ++l) {
tmp += rows[m][l];
}
}
if (typeid(T) == typeid(int8_t)) {
tmp = tmp / (kernel_h * kernel_w);
if (tmp < -127) {
*img_out = -127;
} else if (tmp > 127) {
*img_out = 127;
} else {
*img_out = static_cast<T>(std::round(tmp));
}
} else {
*img_out = static_cast<T>(tmp / (kernel_h * kernel_w));
}
for (int m = 0; m < kernel_h; ++m) {
rows[m] += stride_w;
}
img_out++;
left--;
}
}
}
input_data += input_batch_stride;
out_data += output_batch_stride;
}
delete[] rows;
}
template <typename T, int CeilMode, int PoolType, int Kernel, int Pad,
int Stride>
int TestPoolOp(int in_channels, int in_height, int in_width) {
......@@ -149,41 +88,27 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
framework::Tensor output_cmp;
output_cmp.mutable_data<T>(output_shape);
if (pooling_type == "avg" && pad_h == 0 && pad_h == pad_w) {
PoolAvgPad0<T>(std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, input, &output_cmp);
if (pooling_type == "avg") {
math::Pooling<Avg>()(*input, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w},
std::vector<int>{pad_h, pad_w}, &output_cmp);
} else {
if (typeid(T) == typeid(int8_t)) {
operators::PoolBasic<int8_t, int32_t>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
} else {
operators::PoolBasic<float, float>(
pooling_type, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w}, std::vector<int>{pad_h, pad_w},
input, &output_cmp);
}
math::Pooling<Max>()(*input, std::vector<int>{kernel_h, kernel_w},
std::vector<int>{stride_h, stride_w},
std::vector<int>{pad_h, pad_w}, &output_cmp);
}
// compare results
int eq = 0;
int neq = 0;
auto output = output_var->template Get<framework::LoDTensor>();
const T *output_data = output->data<T>();
T *output_cmp_data = output_cmp.data<T>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
"The execution of test_pool_op is failed!");
if (output_data[i] == output_cmp_data[i]) {
++eq;
} else {
++neq;
}
"output[%d] = %d, output_cmp[%d] = %d", i,
output_data[i], i, output_cmp_data[i]);
}
std::cout << "eq = " << eq << ", neq = " << neq << std::endl;
delete op;
return 0;
}
} // namespace paddle_mobile
......@@ -202,7 +127,6 @@ int main(int argc, char *argv[]) {
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
#if __ARM_NEON
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
......@@ -213,67 +137,16 @@ int main(int argc, char *argv[]) {
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
#endif
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 0, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 1, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height,
in_width);
// kernel = 3, pad = 3, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 3
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
in_width);
// kernel = 5, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
in_width);
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 2>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
......@@ -284,9 +157,4 @@ int main(int argc, char *argv[]) {
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
in_width);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
in_width);
}