Commit 6defe6ef authored by W wuchenghui

add max pooling 2x2 3x3

Parent 999aa5db
@@ -7,69 +7,103 @@
namespace mace {
namespace kernels {

void CalcPaddingAndOutputSize(const index_t *input_shape,   // NCHW
                              const index_t *filter_shape,  // OIHW
                              const int *dilations,
                              const int *strides,
                              Padding padding,
                              index_t *output_shape,
                              int *padding_size) {
  MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
             "Invalid dilations, must >= 1");
  MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
                 (dilations[1] == 1 || strides[1] == 1),
             "If dilations > 1, strides should be 1");
  MACE_CHECK_NOTNULL(output_shape);
  MACE_CHECK_NOTNULL(padding_size);
  /*
   * Convolution/pooling arithmetic:
   * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
   * For details, see https://arxiv.org/pdf/1603.07285.pdf or
   * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
   */
  padding_size[0] = 0;
  padding_size[1] = 0;

  index_t output_height, output_width;
  index_t kernel_height = filter_shape[2];
  index_t kernel_width = filter_shape[3];
  index_t output_channels = filter_shape[0];

  index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1;
  index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;

  switch (padding) {
    case VALID:
      output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
      output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
      break;
    case SAME:
      output_height = (input_shape[2] - 1) / strides[0] + 1;
      output_width = (input_shape[3] - 1) / strides[1] + 1;
      break;
    case FULL:
      output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
      output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
      break;
    default:
      MACE_CHECK(false, "Unsupported padding type: ", padding);
  }

  // Note: TensorFlow may pad one more on the right/bottom side
  // TODO: maybe it is better to also truncate the left/top to
  // utilize the more centered features. We need to benchmark
  // based on the model accuracy.
  padding_size[0] = (output_height - 1) * strides[0] +
      k_extent_height - input_shape[2];
  padding_size[1] = (output_width - 1) * strides[1] +
      k_extent_width - input_shape[3];

  output_shape[0] = input_shape[0];
  output_shape[1] = output_channels;
  output_shape[2] = output_height;
  output_shape[3] = output_width;
}
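
To make the arithmetic above concrete, here is a minimal, hypothetical check (not part of the committed files; the shapes mirror the MAX_k2x2s2x2 unit test added later in this commit): a 4x5 input pooled with a 2x2 window, stride 2 and SAME padding.

#include <cassert>

void ExampleCalcPaddingAndOutputSize() {
  mace::index_t input_shape[4] = {1, 1, 4, 5};   // NCHW
  mace::index_t filter_shape[4] = {1, 1, 2, 2};  // OIHW; pooling only uses kernel h/w
  int dilations[2] = {1, 1};
  int strides[2] = {2, 2};
  mace::index_t output_shape[4];
  int padding_size[2];
  mace::kernels::CalcPaddingAndOutputSize(input_shape, filter_shape, dilations,
                                          strides, mace::SAME, output_shape,
                                          padding_size);
  // SAME: o = (i - 1) / s + 1, so 4x5 -> 2x3; pad = (o - 1) * s + k - i = {0, 1}.
  assert(output_shape[2] == 2 && output_shape[3] == 3);
  assert(padding_size[0] == 0 && padding_size[1] == 1);
}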
void ConstructInputWithPadding(const float *input,
const index_t *input_shape,
const int *paddings,
Tensor *output_tensor) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
std::vector<index_t> output_shape({batch,
channels,
paddings[0] + height,
paddings[1] + width});
const index_t output_width = output_shape[3];
const int padded_top = paddings[0] / 2;
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
float *output_ptr = output_tensor->mutable_data<float>();
memset(output_ptr, 0, output_tensor->size() * sizeof(float));
// Skip the padded top rows
output_ptr += padded_top * output_width;
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
for (int k = 0; k < height; ++k) {
memcpy(output_ptr + padded_left, input, width * sizeof(float));
input += width;
output_ptr += output_width;
}
// Skip the padded bottom in this channel and top in the next channel
output_ptr += paddings[0] * output_width;
}
}
}
} // namespace kernels
} // namespace mace
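
ConstructInputWithPadding is consumed by the NEON pooling dispatcher further down when __COPY_MAKE_PADDING is defined. A minimal sketch of that call pattern (the wrapper name is illustrative and paddings would normally come from CalcPaddingAndOutputSize):

void RunWithPaddedInput(const float *input, const mace::index_t *input_shape,
                        const int *paddings) {
  mace::Tensor padded_input;
  // Materialize a zero-padded copy so the NEON kernel needs no boundary checks.
  mace::kernels::ConstructInputWithPadding(input, input_shape, paddings,
                                           &padded_input);
  const float *padded = padded_input.data<float>();
  const mace::index_t *padded_shape = padded_input.shape().data();
  // padded/padded_shape can now be handed to one of the *Padded NEON kernels.
  (void) padded;
  (void) padded_shape;
}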
@@ -11,20 +11,24 @@ namespace mace {

enum Padding {
  VALID = 0,  // No padding
  SAME = 1,   // Pads with half the filter size (rounded down) on both sides
  FULL = 2,   // Pads with one less than the filter size on both sides
};

namespace kernels {

void CalcPaddingAndOutputSize(const index_t *input_shape,   // NCHW
                              const index_t *filter_shape,  // OIHW
                              const int *dilations,
                              const int *strides,
                              Padding padding,
                              index_t *output_shape,
                              int *padding_size);

void ConstructInputWithPadding(const float *input,
                               const index_t *input_shape,
                               const int *paddings,
                               Tensor *output_tensor);

} // namespace kernels
} // namespace mace
......
@@ -3,97 +3,61 @@
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/conv_pool_2d_util.h"

namespace mace {
namespace kernels {

extern void Conv2dNeonK1x1S1(const float *input, const index_t *input_shape,
                             const float *filter, const float *bias,
                             float *output, const index_t *output_shape);

extern void Conv2dNeonK3x3S1(const float *input, const index_t *input_shape,
                             const float *filter, const float *bias,
                             float *output, const index_t *output_shape);

extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape,
                             const float *filter, const float *bias,
                             float *output, const index_t *output_shape);

template<>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(
    const float *input,           // NCHW
    const index_t *input_shape,
    const float *filter,          // c_out, c_in, kernel_h, kernel_w
    const index_t *filter_shape,
    const float *bias,            // c_out
    float *output,                // NCHW
    const index_t *output_shape) {
  typedef void (*Conv2dNeonFunction)(const float *input,   // NCHW
                                     const index_t *input_shape,
                                     const float *filter,  // c_out, c_in, kernel_h, kernel_w
                                     const float *bias,    // c_out
                                     float *output,        // NCHW
                                     const index_t *output_shape);
  // Selection matrix: kernel_size x stride_size
  static const Conv2dNeonFunction selector[5][2] = {
      {Conv2dNeonK1x1S1, nullptr},
      {nullptr, nullptr},
      {Conv2dNeonK3x3S1, nullptr},
      {nullptr, nullptr},
      {Conv2dNeonK5x5S1, nullptr}
  };
  // not implemented yet
  index_t kernel_h = filter_shape[2];
@@ -104,13 +68,13 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input, // N
      selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
    LOG(WARNING) << "NEON conv2d kernel not implemented, using slow version";
    Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
        input,
        input_shape,
        filter,
        filter_shape,
        bias,
        output,
        output_shape);
    return;
  }
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <float.h>
#include <limits>
#include <arm_neon.h>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingMaxNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
float *outptr = output + output_offset;
const float *r0, *r1;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
if (padding_left > 0) {
*outptr = std::max(r0[0], r1[0]);
++r0;
++r1;
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
// Use a separate counter so num_vectors can still advance the column index below
for (int n = num_vectors; n > 0; --n) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t max0 = vmaxq_f32(r00, r10);
float32x4_t max1 = vmaxq_f32(r01, r11);
float32x4_t max_result = vpmaxq_f32(max0, max1);
vst1q_f32(outptr, max_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
w += num_vectors << 2;
for (; w < out_width; ++w) {
float max = std::numeric_limits<float>::lowest();
for (int kh = 0; kh < 2; ++kh) {
for (int kw = 0; kw < 2; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height &&
inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]);
}
}
}
*outptr = max;
++outptr;
}
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
// assume the input has already been padded
void PoolingMaxNeonK2x2S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = img0 + in_width;
for (int h = 0; h < out_height; ++h) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t max0 = vmaxq_f32(r00, r10);
float32x4_t max1 = vmaxq_f32(r01, r11);
float32x4_t max_result = vpmaxq_f32(max0, max1);
vst1q_f32(outptr, max_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; remain > 0; --remain) {
float max0 = std::max(r0[0], r0[1]);
float max1 = std::max(r1[0], r1[1]);
*outptr = std::max(max0, max1);
r0 += 2;
r1 += 2;
outptr++;
}
r0 += in_width;
r1 += in_width;
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
} // namespace kernels
} // namespace mace
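
The vectorized loops in the two kernels above hinge on the pairwise maximum: a standalone sketch of the per-eight-column step follows (an assumption here is an AArch64 target, where the q-form vpmaxq_f32 intrinsic exists; the helper name is illustrative, not part of the commit).

#include <arm_neon.h>

// out[0..3] = maxima of the four non-overlapping 2x2 windows covered by
// columns c..c+7 of two consecutive input rows r0 and r1 (stride 2, no padding).
static inline void Max2x2Stride2Block(const float *r0, const float *r1,
                                      float *out) {
  float32x4_t a = vmaxq_f32(vld1q_f32(r0), vld1q_f32(r1));          // cols c..c+3
  float32x4_t b = vmaxq_f32(vld1q_f32(r0 + 4), vld1q_f32(r1 + 4));  // cols c+4..c+7
  // vpmaxq_f32 folds adjacent lanes: max(c, c+1), max(c+2, c+3), ...
  vst1q_f32(out, vpmaxq_f32(a, b));
}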
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <float.h>
#include <limits>
#include <arm_neon.h>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingMaxNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
float *outptr = output + output_offset;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
const float *r0, *r1, *r2;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
r2 = r1 + in_width;
if (padding_left > 0) {
if (padding_left == 1) {
float max0 = std::max(r0[0], r0[1]);
float max1 = std::max(r1[0], r1[1]);
float max2 = std::max(r2[0], r2[1]);
*outptr = std::max(std::max(max0, max1), max2);
++r0;
++r1;
++r2;  // advance all three rows past the partially padded first window
} else { // padding_left == 2
float max_tmp = std::max(r0[0], r1[0]);
*outptr = std::max(max_tmp, r2[0]);
}
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
if (num_vectors > 0) {  // r0/r1/r2 are only valid when they were set above
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
// Use a separate counter so num_vectors can still advance the column index below
for (int n = num_vectors; n > 0; --n) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]);
float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]);
float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
max0 = vmaxq_f32(max0, row02);
max1 = vmaxq_f32(max1, row12);
max2 = vmaxq_f32(max2, row22);
float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2);
vst1q_f32(outptr, max_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
w += num_vectors << 2;
}
for (; w < out_width; ++w) {
float max = std::numeric_limits<float>::lowest();
for (int kh = 0; kh < 3; ++kh) {
for (int kw = 0; kw < 3; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height &&
inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]);
}
}
}
*outptr = max;
++outptr;
}
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
// assume the input has already been padded
void PoolingMaxNeonK3x3S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = r0 + in_width;
const float *r2 = r1 + in_width;
for (int h = 0; h < out_height; h++) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
for (; num_vectors > 0; num_vectors--) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]);
float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]);
float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
max0 = vmaxq_f32(max0, row02);
max1 = vmaxq_f32(max1, row12);
max2 = vmaxq_f32(max2, row22);
float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2);
vst1q_f32(outptr, max_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; remain > 0; remain--) {
float max0 = std::max(std::max(r0[0], r0[1]), r0[2]);
float max1 = std::max(std::max(r1[0], r1[1]), r1[2]);
float max2 = std::max(std::max(r2[0], r2[1]), r2[2]);
*outptr = std::max(std::max(max0, max1), max2);
r0 += 2;
r1 += 2;
r2 += 2;
outptr++;
}
r0 += 1 + in_width;
r1 += 1 + in_width;
r2 += 1 + in_width;
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
} // namespace kernels
} // namespace mace
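
The 3x3 kernels above gather the three columns of each stride-2 window from de-interleaving loads plus a one-lane shift; below is a standalone per-row sketch of that trick (the helper name is illustrative; vld2q_f32, vextq_f32, vmaxq_f32 and vst1q_f32 are standard NEON intrinsics).

#include <arm_neon.h>

// out[w] = max(row[2w], row[2w + 1], row[2w + 2]) for w = 0..3; the kernels then
// take a further max across three such row results to form the 3x3 window maxima.
static inline void RowMax3Stride2Block(const float *row, float *out) {
  float32x4x2_t cur = vld2q_f32(row);       // val[0]: even columns, val[1]: odd columns
  float32x4x2_t next = vld2q_f32(row + 8);  // the next group of eight columns
  float32x4_t m = vmaxq_f32(cur.val[0], cur.val[1]);
  float32x4_t third = vextq_f32(cur.val[0], next.val[0], 1);  // columns 2w + 2
  vst1q_f32(out, vmaxq_f32(m, third));
}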
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include "mace/kernels/pooling.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace kernels {
extern void PoolingMaxNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
extern void PoolingMaxNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
#ifdef __COPY_MAKE_PADDING
extern void PoolingMaxNeonK2x2S2x2Padded(const float* input,
const index_t* in_shape,
float* output,
const index_t* out_shape);
extern void PoolingMaxNeonK3x3S2x2Padded(const float* input,
const index_t* in_shape,
float* output,
const index_t* out_shape);
#endif
template<>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape,
float *output,
const index_t *output_shape) {
if (kernels_[0] == 2 && kernels_[1] == 2 &&
strides_[0] == 2 && strides_[1] == 2 &&
pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING
Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
#else
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, paddings_);
#endif
} else if (kernels_[0] == 3 && kernels_[1] == 3 &&
strides_[0] == 2 && strides_[1] == 2 &&
pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING
Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
#else
PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape, paddings_);
#endif
} else {  // not implemented yet
PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_,
paddings_, dilations_)(
input,
input_shape,
output,
output_shape
);
}
}
} // namespace kernels
} // namespace mace
\ No newline at end of file
@@ -19,33 +19,33 @@ namespace kernels {

template<DeviceType D, typename T>
class PoolingFunctor {
 public:
  PoolingFunctor(const PoolingType pooling_type,
                 const int *kernels,
                 const int *strides,
                 const int *paddings,
                 const int *dilations)
      : pooling_type_(pooling_type),
        kernels_(kernels),
        strides_(strides),
        paddings_(paddings),
        dilations_(dilations) {}

  void operator()(const T *input,
                  const index_t *input_shape,
                  T *output,
                  const index_t *output_shape) {
    index_t batch = output_shape[0];
    index_t channels = output_shape[1];
    index_t height = output_shape[2];
    index_t width = output_shape[3];

    index_t input_channels = input_shape[1];
    index_t input_height = input_shape[2];
    index_t input_width = input_shape[3];

    int kernel_h = kernels_[0];
    int kernel_w = kernels_[1];
    int stride_h = strides_[0];
    int stride_w = strides_[1];
@@ -61,20 +61,20 @@ public:
    for (int n = 0; n < batch; ++n) {
      for (int c = 0; c < channels; ++c) {
        index_t out_offset = n * channels * height * width +
            c * height * width;
        index_t in_offset = n * input_channels * input_height * input_width +
            c * input_height * input_width;
        for (int h = 0; h < height; ++h) {
          for (int w = 0; w < width; ++w) {
            T sum_or_max = 0;
            switch (pooling_type_) {
              case AVG:
                break;
              case MAX:
                sum_or_max = std::numeric_limits<T>::lowest();
                break;
              default:
                MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
            }
            for (int kh = 0; kh < kernel_h; ++kh) {
              for (int kw = 0; kw < kernel_w; ++kw) {
@@ -83,10 +83,9 @@ public:
                if (inh >= 0 && inh < input_height &&
                    inw >= 0 && inw < input_width) {
                  index_t input_offset = in_offset +
                      inh * input_width + inw;
                  switch (pooling_type_) {
                    case AVG:
                      sum_or_max += input[input_offset];
                      break;
                    case MAX:
                      sum_or_max = std::max(sum_or_max, input[input_offset]);
@@ -99,14 +98,14 @@
                }
              }
            }
            switch (pooling_type_) {
              case AVG:
                output[out_offset] = sum_or_max / (kernel_h * kernel_w);
                break;
              case MAX:
                output[out_offset] = sum_or_max;
                break;
              default:
                MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
            }
            out_offset += 1;
          }
@@ -115,14 +114,20 @@ public:
      }
    }

 private:
  const PoolingType pooling_type_;
  const int *kernels_;
  const int *strides_;
  const int *paddings_;
  const int *dilations_;
};

template<>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
    const float *input,
    const index_t *input_shape,
    float *output,
    const index_t *output_shape);

} // namespace kernels
} // namespace mace
......
@@ -9,4 +9,8 @@ namespace mace {

REGISTER_CPU_OPERATOR(Pooling, PoolingOp<DeviceType::CPU, float>);

#if __ARM_NEON
REGISTER_NEON_OPERATOR(Pooling, PoolingOp<DeviceType::NEON, float>);
#endif // __ARM_NEON

} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/testing/test_benchmark.h"
#include "mace/core/operator.h"
#include "mace/kernels/pooling.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
using namespace mace::kernels;
template<DeviceType D>
static void Pooling(int iters, int batch, int channels, int height,
int width, int kernel, int stride, Padding padding,
PoolingType pooling_type) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntArg("pooling_type", pooling_type);
net.AddIntsArg("kernels", {kernel, kernel});
net.AddIntsArg("strides", {stride, stride});
net.AddIntArg("padding", padding);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<float>("Input", {batch, channels, height, width});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
#define BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE) \
static void BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot * (sizeof(float)));\
Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, PoolingType::PO); \
} \
BENCHMARK(BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE)
#define BM_POOLING(N, C, H, W, K, S, PA, PO) \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON);
BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX);
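
For reference, a rough expansion of the first instantiation above (the NEON variant only differs in the DEVICE token):

static void BM_POOLING_1_3_129_129_K2S2_SAME_MAX_CPU(int iters) {
  const int64_t tot = static_cast<int64_t>(iters) * 1 * 3 * 129 * 129;
  mace::testing::ItemsProcessed(tot);
  mace::testing::BytesProcessed(tot * (sizeof(float)));
  Pooling<CPU>(iters, 1, 3, 129, 129, 2, 2, Padding::SAME, PoolingType::MAX);
}
BENCHMARK(BM_POOLING_1_3_129_129_K2S2_SAME_MAX_CPU);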
@@ -148,3 +148,67 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntArg("pooling_type", PoolingType::MAX);
net.AddIntsArg("kernels", {2, 2});
net.AddIntsArg("strides", {2, 2});
net.AddIntArg("padding", Padding::SAME);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 5},
{0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14,
15, 16, 17, 18, 19});
// Run
net.RunOp(DeviceType::NEON);
// Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3},
{6, 8, 9,
16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntArg("pooling_type", PoolingType::MAX);
net.AddIntsArg("kernels", {3, 3});
net.AddIntsArg("strides", {2, 2});
net.AddIntArg("padding", Padding::SAME);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 5},
{0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14,
15, 16, 17, 18, 19});
// Run
net.RunOp(DeviceType::NEON);
// Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3},
{11, 13, 14,
16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
}
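
As a hand check of the 3x3 case: SAME padding on the 4x5 input yields pad_top = 0 and pad_left = 1, so the first output is the maximum over rows 0-2 and columns 0-1 of the input, i.e. max{0, 1, 5, 6, 10, 11} = 11, which matches the expected tensor above.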