Commit 6defe6ef authored by wuchenghui

add max pooling 2x2 3x3

Parent 999aa5db
@@ -7,69 +7,103 @@
namespace mace {
namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
/*
* Convolution/pooling arithmetic:
* o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
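// Worked example of the arithmetic above (a sketch, assuming d = 1):
// i = 5, k = 3, s = 2, SAME  => o = (5 - 1) / 2 + 1 = 3,
//   total padding = (3 - 1) * 2 + 3 - 5 = 2 (split as top = 1, bottom = 1);
// i = 5, k = 3, s = 2, VALID => o = (5 - 3) / 2 + 1 = 2, padding = 0.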
padding_size[0] = 0;
padding_size[1] = 0;
index_t output_height, output_width;
index_t kernel_height = filter_shape[2];
index_t kernel_width = filter_shape[3];
index_t output_channels = filter_shape[0];
index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1;
index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;
switch (padding) {
case VALID:
output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more on the right/bottom side
// TODO: maybe it's better to also truncate the left/top to
// utilize the more centered features. We need to benchmark
// based on the model accuracy.
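// Example of the resulting split (as ConstructInputWithPadding divides it):
// input width 5, k = 2, s = 2, SAME => output width 3,
// total padding = (3 - 1) * 2 + 2 - 5 = 1, i.e. left = 1 / 2 = 0, right = 1,
// so the extra column lands on the right/bottom, matching TensorFlow.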
padding_size[0] = (output_height - 1) * strides[0] +
k_extent_height - input_shape[2];
padding_size[1] = (output_width - 1) * strides[1] +
k_extent_width - input_shape[3];
output_shape[0] = input_shape[0];
output_shape[1] = output_channels;
output_shape[2] = output_height;
output_shape[3] = output_width;
}
void ConstructInputWithPadding(const float *input,
const index_t *input_shape,
const int *paddings,
Tensor *output_tensor) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
std::vector<index_t> output_shape({batch,
channels,
paddings[0] + height,
paddings[1] + width});
const index_t output_width = output_shape[3];
const int padded_top = paddings[0] / 2;
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
float *output_ptr = output_tensor->mutable_data<float>();
memset(output_ptr, 0, output_tensor->size() * sizeof(float));
// Skip the padded top rows
output_ptr += padded_top * output_width;
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
for (int k = 0; k < height; ++k) {
memcpy(output_ptr + padded_left, input, width * sizeof(float));
input += width;
output_ptr += output_width;
}
// Skip the padded bottom in this channel and top in the next channel
output_ptr += paddings[0] * output_width;
}
}
}
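// Layout sketch: for one 2x3 channel with paddings {2, 2}, the padded plane
// is 4x5; zeros fill row 0, row 3, and columns 0 and 4, and the source rows
// are copied starting at (padded_top, padded_left) = (1, 1):
//   0 0 0 0 0
//   0 a b c 0
//   0 d e f 0
//   0 0 0 0 0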
} // namespace kernels
} // namespace mace
@@ -11,20 +11,24 @@ namespace mace {
enum Padding {
VALID = 0, // No padding
SAME = 1, // Pads with half the filter size (rounded down) on both sides
FULL = 2, // Pads with one less than the filter size on both sides
};
namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size);
void ConstructInputWithPadding(const float *input,
const index_t *input_shape,
const int *paddings,
Tensor *output_tensor);
} // namespace kernels
} // namespace mace
@@ -3,97 +3,61 @@
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace kernels {
extern void Conv2dNeonK1x1S1(const float *input, const index_t *input_shape,
const float *filter, const float *bias,
float *output, const index_t *output_shape);
extern void Conv2dNeonK3x3S1(const float *input, const index_t *input_shape,
const float *filter, const float *bias,
float *output, const index_t *output_shape);
extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape,
const float *filter, const float *bias,
float *output, const index_t *output_shape);
template<>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
typedef void (*Conv2dNeonFunction)(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape);
// Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = {
{
Conv2dNeonK1x1S1,
nullptr
},
{
nullptr,
nullptr
},
{
Conv2dNeonK3x3S1,
nullptr
},
{
nullptr,
nullptr
},
{
Conv2dNeonK5x5S1,
nullptr
}
};
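// Dispatch sketch: the lookup below is selector[kernel_h - 1][strides_[0] - 1],
// e.g. a 3x3 kernel with stride 1 selects Conv2dNeonK3x3S1, while any nullptr
// entry (stride 2, or 2x2/4x4 kernels) falls back to the CPU implementation.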
// not implemented yet
index_t kernel_h = filter_shape[2];
@@ -104,13 +68,13 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // N
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "NEON conv2d kernel not implemented, using slow version";
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input,
input_shape,
filter,
filter_shape,
bias,
output,
output_shape
);
return;
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <algorithm>  // std::max
#include <limits>
#include <arm_neon.h>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingMaxNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
// Per-iteration offsets; a shared running offset would race under OpenMP.
const index_t input_offset = (b * channels + c) * in_image_size;
const index_t output_offset = (b * channels + c) * out_image_size;
float *outptr = output + output_offset;
const float *r0, *r1;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
if (padding_left > 0) {
*outptr = std::max(r0[0], r1[0]);
++r0;
++r1;
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
w += num_vectors << 2;  // advance before the loop consumes num_vectors
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t max0 = vmaxq_f32(r00, r10);
float32x4_t max1 = vmaxq_f32(r01, r11);
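// vpmaxq_f32 takes pairwise maxima: lanes {max(m0,m1), max(m2,m3),
// max(m4,m5), max(m6,m7)}, turning 8 input columns into 4 stride-2 outputs.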
float32x4_t max_result = vpmaxq_f32(max0, max1);
vst1q_f32(outptr, max_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; w < out_width; ++w) {
float max = std::numeric_limits<float>::lowest();
for (int kh = 0; kh < 2; ++kh) {
for (int kw = 0; kw < 2; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height &&
inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]);
}
}
}
*outptr = max;
++outptr;
}
}
}
}
}
// assume the input has already been padded
void PoolingMaxNeonK2x2S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
// Per-iteration offsets; a shared running offset would race under OpenMP.
const index_t input_offset = (b * channels + c) * in_image_size;
const index_t output_offset = (b * channels + c) * out_image_size;
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = img0 + in_width;
for (int h = 0; h < out_height; ++h) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t max0 = vmaxq_f32(r00, r10);
float32x4_t max1 = vmaxq_f32(r01, r11);
float32x4_t max_result = vpmaxq_f32(max0, max1);
vst1q_f32(outptr, max_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; remain > 0; --remain) {
float max0 = std::max(r0[0], r0[1]);
float max1 = std::max(r1[0], r1[1]);
*outptr = std::max(max0, max1);
r0 += 2;
r1 += 2;
outptr++;
}
r0 += in_width;
r1 += in_width;
}
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <algorithm>  // std::max
#include <limits>
#include <arm_neon.h>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingMaxNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
// Per-iteration offsets; a shared running offset would race under OpenMP.
const index_t input_offset = (b * channels + c) * in_image_size;
const index_t output_offset = (b * channels + c) * out_image_size;
float *outptr = output + output_offset;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
const float *r0, *r1, *r2;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
r2 = r1 + in_width;
if (padding_left > 0) {
if (padding_left == 1) {
float max0 = std::max(r0[0], r0[1]);
float max1 = std::max(r1[0], r1[1]);
float max2 = std::max(r2[0], r2[1]);
*outptr = std::max(std::max(max0, max1), max2);
++r0;
++r1;
++r2;  // all three row pointers must advance together for the vector loop
} else { // padding_left == 2
float max_tmp = std::max(r0[0], r1[0]);
*outptr = std::max(max_tmp, r2[0]);
}
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
float32x4x2_t row0, row1, row2;
if (num_vectors > 0) {  // r0/r1/r2 are unset on fully padded rows
row0 = vld2q_f32(r0);
row1 = vld2q_f32(r1);
row2 = vld2q_f32(r2);
}
w += num_vectors << 2;  // advance before the loop consumes num_vectors
for (; num_vectors > 0; --num_vectors) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
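// vld2q_f32 deinterleaves: val[0] holds even columns, val[1] odd columns,
// and vextq_f32 shifts in column 2w + 2 from the next block, so each lane
// maxes the three columns 2w, 2w + 1, 2w + 2 of a stride-2 window.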
float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]);
float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]);
float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
max0 = vmaxq_f32(max0, row02);
max1 = vmaxq_f32(max1, row12);
max2 = vmaxq_f32(max2, row22);
float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2);
vst1q_f32(outptr, max_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; w < out_width; ++w) {
float max = std::numeric_limits<float>::lowest();
for (int kh = 0; kh < 3; ++kh) {
for (int kw = 0; kw < 3; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height &&
inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]);
}
}
}
*outptr = max;
++outptr;
}
}
}
}
}
// assume the input has already been padded
void PoolingMaxNeonK3x3S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
// Per-iteration offsets; a shared running offset would race under OpenMP.
const index_t input_offset = (b * channels + c) * in_image_size;
const index_t output_offset = (b * channels + c) * out_image_size;
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = r0 + in_width;
const float *r2 = r1 + in_width;
for (int h = 0; h < out_height; h++) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
for (; num_vectors > 0; num_vectors--) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]);
float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]);
float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
max0 = vmaxq_f32(max0, row02);
max1 = vmaxq_f32(max1, row12);
max2 = vmaxq_f32(max2, row22);
float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2);
vst1q_f32(outptr, max_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; remain > 0; remain--) {
float max0 = std::max(std::max(r0[0], r0[1]), r0[2]);
float max1 = std::max(std::max(r1[0], r1[1]), r1[2]);
float max2 = std::max(std::max(r2[0], r2[1]), r2[2]);
*outptr = std::max(std::max(max0, max1), max2);
r0 += 2;
r1 += 2;
r2 += 2;
outptr++;
}
r0 += 1 + in_width;
r1 += 1 + in_width;
r2 += 1 + in_width;
}
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include "mace/kernels/pooling.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace kernels {
extern void PoolingMaxNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
extern void PoolingMaxNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
#ifdef __COPY_MAKE_PADDING
extern void PoolingMaxNeonK2x2S2x2Padded(const float* input,
const index_t* in_shape,
float* output,
const index_t* out_shape);
extern void PoolingMaxNeonK3x3S2x2Padded(const float* input,
const index_t* in_shape,
float* output,
const index_t* out_shape);
#endif
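// Build-mode note: with __COPY_MAKE_PADDING defined, the input is first
// materialized into a zero-padded buffer, and from their pointer arithmetic
// the *Padded kernels appear to assume a padded width of exactly
// 2 * out_width (2x2) or 2 * out_width + 1 (3x3); otherwise the kernels
// above handle borders in-place via the paddings argument.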
template<>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape,
float *output,
const index_t *output_shape) {
if (kernels_[0] == 2 && kernels_[1] == 2 &&
strides_[0] == 2 && strides_[1] == 2 &&
pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING
Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
#else
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, paddings_);
#endif
} else if (kernels_[0] == 3 && kernels_[1] == 3 &&
strides_[0] == 2 && strides_[1] == 2 &&
pooling_type_ == MAX) {
#ifdef __COPY_MAKE_PADDING
Tensor padded_input;
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
#else
PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape, paddings_);
#endif
} else { // not implemented yet
PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_,
paddings_, dilations_)(
input,
input_shape,
output,
output_shape
);
}
}
} // namespace kernels
} // namespace mace
\ No newline at end of file
@@ -19,33 +19,33 @@ namespace kernels {
template<DeviceType D, typename T>
class PoolingFunctor {
public:
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const int *paddings,
const int *dilations)
: pooling_type_(pooling_type),
kernels_(kernels),
strides_(strides),
paddings_(paddings),
dilations_(dilations) {}
void operator()(const T *input,
const index_t *input_shape,
T *output,
const index_t *output_shape) {
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
int kernel_h = kernels_[0];
int kernel_w = kernels_[1];
int stride_h = strides_[0];
int stride_w = strides_[1];
@@ -61,20 +61,20 @@ public:
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
index_t out_offset = n * channels * height * width +
c * height * width;
index_t in_offset = n * input_channels * input_height * input_width +
c * input_height * input_width;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
T sum_or_max = 0;
switch (pooling_type_) {
case AVG:
break;
case MAX:
sum_or_max = std::numeric_limits<T>::lowest();
break;
default:
MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
}
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
@@ -83,10 +83,9 @@ public:
if (inh >= 0 && inh < input_height &&
inw >= 0 && inw < input_width) {
index_t input_offset = in_offset +
inh * input_width + inw;
switch (pooling_type_) {
case AVG:
sum_or_max += input[input_offset];
break;
case MAX:
sum_or_max = std::max(sum_or_max, input[input_offset]);
@@ -99,14 +98,14 @@ public:
}
}
switch (pooling_type_) {
case AVG:
output[out_offset] = sum_or_max / (kernel_h * kernel_w);
break;
case MAX:
output[out_offset] = sum_or_max;
break;
default:
MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_);
}
out_offset += 1;
}
@@ -115,14 +114,20 @@ public:
}
}
private:
const PoolingType pooling_type_;
const int *kernels_;
const int *strides_;
const int *paddings_;
const int *dilations_;
};
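// Usage sketch (hypothetical values; shapes are NCHW and the kernel, stride,
// padding and dilation arrays are {height, width} pairs):
//   int kernels[2] = {2, 2}, strides[2] = {2, 2};
//   int paddings[2] = {0, 0}, dilations[2] = {1, 1};
//   PoolingFunctor<DeviceType::CPU, float> pooling(
//       MAX, kernels, strides, paddings, dilations);
//   pooling(input, input_shape, output, output_shape);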
template<>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape,
float *output,
const index_t *output_shape);
} // namespace kernels
} // namespace mace
@@ -9,4 +9,8 @@ namespace mace {
REGISTER_CPU_OPERATOR(Pooling, PoolingOp<DeviceType::CPU, float>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(Pooling, PoolingOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/testing/test_benchmark.h"
#include "mace/core/operator.h"
#include "mace/kernels/pooling.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
using namespace mace::kernels;
template<DeviceType D>
static void Pooling(int iters, int batch, int channels, int height,
int width, int kernel, int stride, Padding padding,
PoolingType pooling_type) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntArg("pooling_type", pooling_type);
net.AddIntsArg("kernels", {kernel, kernel});
net.AddIntsArg("strides", {stride, stride});
net.AddIntArg("padding", padding);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<float>("Input", {batch, channels, height, width});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
#define BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE) \
static void BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot * (sizeof(float)));\
Pooling<DEVICE>(iters, N, C, H, W, KE, STRIDE, Padding::PA, PoolingType::PO); \
} \
BENCHMARK(BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE)
#define BM_POOLING(N, C, H, W, K, S, PA, PO) \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON);
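// For instance, BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX) below registers
// BM_POOLING_1_3_129_129_K2S2_SAME_MAX_CPU and its NEON counterpart.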
BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX);
@@ -148,3 +148,67 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntArg("pooling_type", PoolingType::MAX);
net.AddIntsArg("kernels", {2, 2});
net.AddIntsArg("strides", {2, 2});
net.AddIntArg("padding", Padding::SAME);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 5},
{0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14,
15, 16, 17, 18, 19});
// Run
net.RunOp(DeviceType::NEON);
// Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3},
{6, 8, 9,
16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, MAX_k3x3s2x2) {
// Construct graph
auto& net = test_net();
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntArg("pooling_type", PoolingType::MAX);
net.AddIntsArg("kernels", {3, 3});
net.AddIntsArg("strides", {2, 2});
net.AddIntArg("padding", Padding::SAME);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddInputFromArray<float>("Input", {1, 1, 4, 5},
{0, 1, 2, 3, 4,
5, 6, 7, 8, 9,
10, 11, 12, 13, 14,
15, 16, 17, 18, 19});
// Run
net.RunOp(DeviceType::NEON);
// Check
Tensor expected = CreateTensor<float>({1, 1, 2, 3},
{11, 13, 14,
16, 18, 19});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
}