diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index fe6f4e685883accfad0cca2d3c1cc3fa1055aa40..be5a742af6931e69ecdab3deb1bee79660e4f4f8 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -7,69 +7,103 @@ namespace mace { namespace kernels { -void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW - const index_t* filter_shape, // OIHW - const int* dilations, - const int* strides, +void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW + const index_t *filter_shape, // OIHW + const int *dilations, + const int *strides, Padding padding, - index_t* output_shape, - int* padding_size) { - MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, - "Invalid dilations, must >= 1"); - MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), - "If dilations > 1, strides should be 1"); - MACE_CHECK_NOTNULL(output_shape); - MACE_CHECK_NOTNULL(padding_size); - /* - * Convlution/pooling arithmetic: - * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 - * For details, see https://arxiv.org/pdf/1603.07285.pdf or - * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html - */ - padding_size[0] = 0; - padding_size[1] = 0; + index_t *output_shape, + int *padding_size) { + MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, + "Invalid dilations, must >= 1"); + MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && + (dilations[1] == 1 || strides[1] == 1), + "If dilations > 1, strides should be 1"); + MACE_CHECK_NOTNULL(output_shape); + MACE_CHECK_NOTNULL(padding_size); + /* + * Convlution/pooling arithmetic: + * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 + * For details, see https://arxiv.org/pdf/1603.07285.pdf or + * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html + */ + padding_size[0] = 0; + padding_size[1] = 0; - index_t output_height, output_width; - index_t kernel_height = filter_shape[2]; - index_t kernel_width = filter_shape[3]; - index_t output_channels = filter_shape[0]; + index_t output_height, output_width; + index_t kernel_height = filter_shape[2]; + index_t kernel_width = filter_shape[3]; + index_t output_channels = filter_shape[0]; - index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1; - index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1; + index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1; + index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1; - switch (padding) { - case VALID: - output_height = (input_shape[2] - k_extent_height) / strides[0] + 1; - output_width = (input_shape[3] - k_extent_width) / strides[1] + 1; - break; - case SAME: - output_height = (input_shape[2] - 1) / strides[0] + 1; - output_width = (input_shape[3] - 1) / strides[1] + 1; - break; - case FULL: - output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; - output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; - break; - default: - MACE_CHECK(false, "Unsupported padding type: ", padding); - } + switch (padding) { + case VALID: + output_height = (input_shape[2] - k_extent_height) / strides[0] + 1; + output_width = (input_shape[3] - k_extent_width) / strides[1] + 1; + break; + case SAME:output_height = (input_shape[2] - 1) / strides[0] + 1; + output_width = (input_shape[3] - 1) / strides[1] + 1; + break; + case FULL: + output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; + output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; + break; + default:MACE_CHECK(false, "Unsupported padding type: ", padding); + } - // Note: TensorFlow may padded one more on the right/bottom side - // TODO may be it's better to also truncate the left/top to - // utilize the more centered features. We need to benchmark - // based on the model accuracy. + // Note: TensorFlow may padded one more on the right/bottom side + // TODO may be it's better to also truncate the left/top to + // utilize the more centered features. We need to benchmark + // based on the model accuracy. - padding_size[0] = (output_height - 1) * strides[0] + - k_extent_height - input_shape[2]; - padding_size[1] = (output_width - 1) * strides[1] + - k_extent_width - input_shape[3]; + padding_size[0] = (output_height - 1) * strides[0] + + k_extent_height - input_shape[2]; + padding_size[1] = (output_width - 1) * strides[1] + + k_extent_width - input_shape[3]; - output_shape[0] = input_shape[0]; - output_shape[1] = output_channels; - output_shape[2] = output_height; - output_shape[3] = output_width; - } + output_shape[0] = input_shape[0]; + output_shape[1] = output_channels; + output_shape[2] = output_height; + output_shape[3] = output_width; +} + +void ConstructInputWithPadding(const float *input, + const index_t *input_shape, + const int *paddings, + Tensor *output_tensor) { + index_t batch = input_shape[0]; + index_t channels = input_shape[1]; + index_t height = input_shape[2]; + index_t width = input_shape[3]; + + std::vector output_shape({batch, + channels, + paddings[0] + height, + paddings[1] + width}); + const index_t output_width = output_shape[3]; + const int padded_top = paddings[0] / 2; + const int padded_left = paddings[1] / 2; + + output_tensor->Resize(output_shape); + float *output_ptr = output_tensor->mutable_data(); + memset(output_ptr, 0, output_tensor->size() * sizeof(float)); + + // Skip the padded top rows + output_ptr += padded_top * output_width; + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + for (int k = 0; k < height; ++k) { + memcpy(output_ptr + padded_left, input, width * sizeof(float)); + input += width; + output_ptr += output_width; + } + // Skip the padded bottom in this channel and top in the next channel + output_ptr += paddings[0] * output_width; + } + } +} } // namespace kernels } // namespace mace diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index 3cca8a79ed5cfbc9a774f4b8bccfce594f1a5a4e..c6b9f0908ab85d36eed17d5702b5ea3504781d19 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -11,20 +11,24 @@ namespace mace { enum Padding { VALID = 0, // No padding - SAME = 1, // Pads with half the filter size (rounded down) on both sides - FULL = 2, // Pads with one less than the filter size on both sides + SAME = 1, // Pads with half the filter size (rounded down) on both sides + FULL = 2, // Pads with one less than the filter size on both sides }; namespace kernels { -void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW - const index_t* filter_shape, // OIHW - const int* dilations, - const int* strides, +void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW + const index_t *filter_shape, // OIHW + const int *dilations, + const int *strides, Padding padding, - index_t* output_shape, - int* padding_size); + index_t *output_shape, + int *padding_size); +void ConstructInputWithPadding(const float *input, + const index_t *input_shape, + const int *paddings, + Tensor *output_tensor); } // namespace kernels } // namespace mace diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc index 15550308b1af228345f0a0c3fa4cfdca39c478ab..75b22e9a93a0b0e8eaee97e1f0a904071accccdd 100644 --- a/mace/kernels/neon/conv_2d_neon.cc +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -3,97 +3,61 @@ // #include "mace/kernels/conv_2d.h" +#include "mace/kernels/conv_pool_2d_util.h" namespace mace { namespace kernels { -static inline void ConstructInputWithPadding(const float* input, - const index_t* input_shape, - const int* paddings, - Tensor* output_tensor) { - index_t batch = input_shape[0]; - index_t channels = input_shape[1]; - index_t height = input_shape[2]; - index_t width = input_shape[3]; +extern void Conv2dNeonK1x1S1(const float *input, const index_t *input_shape, + const float *filter, const float *bias, + float *output, const index_t *output_shape); - std::vector output_shape({batch, - channels, - paddings[0] + height, - paddings[1] + width}); +extern void Conv2dNeonK3x3S1(const float *input, const index_t *input_shape, + const float *filter, const float *bias, + float *output, const index_t *output_shape); - const index_t output_width = output_shape[3]; - const int padded_top = paddings[0] / 2; - const int padded_left = paddings[1] / 2; - - output_tensor->Resize(output_shape); - float* output_ptr = output_tensor->mutable_data(); - memset(output_ptr, 0, output_tensor->size() * sizeof(float)); - - // Skip the padded top rows - output_ptr += padded_top * output_width; - for (int i = 0; i < batch; ++i) { - for (int j = 0; j < channels; ++j) { - for (int k = 0; k < height; ++k) { - memcpy(output_ptr + padded_left, input, width * sizeof(float)); - input += width; - output_ptr += output_width; - } - // Skip the padded bottom in this channel and top in the next channel - output_ptr += paddings[0] * output_width; - } - } -} - - -extern void Conv2dNeonK1x1S1(const float* input, const index_t* input_shape, - const float* filter, const float* bias, - float* output, const index_t* output_shape); - -extern void Conv2dNeonK3x3S1(const float* input, const index_t* input_shape, - const float* filter, const float* bias, - float* output, const index_t* output_shape); - -extern void Conv2dNeonK5x5S1(const float* input, const index_t* input_shape, - const float* filter, const float* bias, - float* output, const index_t* output_shape); +extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape, + const float *filter, const float *bias, + float *output, const index_t *output_shape); template<> -void Conv2dFunctor::operator()(const float* input, // NCHW - const index_t* input_shape, - const float* filter, // c_out, c_in, kernel_h, kernel_w - const index_t* filter_shape, - const float* bias, // c_out - float* output, // NCHW - const index_t* output_shape) { +void Conv2dFunctor::operator()(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape) { - typedef void (*Conv2dNeonFunction)(const float* input, // NCHW - const index_t* input_shape, - const float* filter, // c_out, c_in, kernel_h, kernel_w - const float* bias, // c_out - float* output, // NCHW - const index_t* output_shape); + typedef void (*Conv2dNeonFunction)(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape); // Selection matrix: kernel_size x stride_size static const Conv2dNeonFunction selector[5][2] = { - { - Conv2dNeonK1x1S1, - nullptr - }, - { - nullptr, - nullptr - }, - { - Conv2dNeonK3x3S1, - nullptr - }, - { - nullptr, - nullptr - }, - { - Conv2dNeonK5x5S1, - nullptr - } + { + Conv2dNeonK1x1S1, + nullptr + }, + { + nullptr, + nullptr + }, + { + Conv2dNeonK3x3S1, + nullptr + }, + { + nullptr, + nullptr + }, + { + Conv2dNeonK5x5S1, + nullptr + } }; // not implement yet index_t kernel_h = filter_shape[2]; @@ -104,13 +68,13 @@ void Conv2dFunctor::operator()(const float* input, // N selector[kernel_h - 1][strides_[0] - 1] == nullptr) { LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion"; Conv2dFunctor(strides_, paddings_, dilations_)( - input, - input_shape, - filter, - filter_shape, - bias, - output, - output_shape + input, + input_shape, + filter, + filter_shape, + bias, + output, + output_shape ); return; } diff --git a/mace/kernels/neon/max_pooling_neon_2x2.cc b/mace/kernels/neon/max_pooling_neon_2x2.cc new file mode 100644 index 0000000000000000000000000000000000000000..088ea467e1ab187a210efc72c6af82c47732d234 --- /dev/null +++ b/mace/kernels/neon/max_pooling_neon_2x2.cc @@ -0,0 +1,173 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include +#include + +#include "mace/core/common.h" + +namespace mace { +namespace kernels { + +void PoolingMaxNeonK2x2S2x2(const float *input, + const index_t *in_shape, + float *output, + const index_t *out_shape, + const int *paddings) { + index_t batch = in_shape[0]; + index_t channels = in_shape[1]; + index_t in_height = in_shape[2]; + index_t in_width = in_shape[3]; + + index_t out_height = out_shape[2]; + index_t out_width = out_shape[3]; + + int padding_top = paddings[0] / 2; + int padding_bottom = paddings[0] - padding_top; + int padding_left = paddings[1] / 2; + int padding_right = paddings[1] - padding_left; + + int in_image_size = in_height * in_width; + int out_image_size = out_height * out_width; + index_t input_offset = 0; + index_t output_offset = 0; + +#pragma omp parallel for collapse(2) + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels; ++c) { + float *outptr = output + output_offset; + const float *r0, *r1; + + for (int h = 0; h < out_height; ++h) { + int w = 0; + int num_vectors = 0; + if (!((h == 0 && padding_top > 0) || + (h == out_height - 1 && padding_bottom > 0))) { + r0 = input + input_offset + (h * 2 - padding_top) * in_width; + r1 = r0 + in_width; + if (padding_left > 0) { + *outptr = std::max(r0[0], r1[0]); + ++r0; + ++r1; + ++outptr; + ++w; + } + if (padding_right > 0) { + num_vectors = (out_width - w - 1) >> 2; + } else { + num_vectors = (out_width - w) >> 2; + } + } + + for (; num_vectors > 0; --num_vectors) { + float32x4_t r00 = vld1q_f32(r0); + float32x4_t r10 = vld1q_f32(r1); + float32x4_t r01 = vld1q_f32(r0 + 4); + float32x4_t r11 = vld1q_f32(r1 + 4); + + float32x4_t max0 = vmaxq_f32(r00, r10); + float32x4_t max1 = vmaxq_f32(r01, r11); + + float32x4_t max_result = vpmaxq_f32(max0, max1); + + vst1q_f32(outptr, max_result); + + r0 += 8; + r1 += 8; + outptr += 4; + } + + w += num_vectors << 2; + for (; w < out_width; ++w) { + float max = std::numeric_limits::lowest(); + for (int kh = 0; kh < 2; ++kh) { + for (int kw = 0; kw < 2; ++kw) { + int inh = h * 2 - padding_top + kh; + int inw = w * 2 - padding_left + kw; + if (inh >= 0 && inh < in_height && + inw >= 0 && inw < in_width) { + max = std::max(max, input[input_offset + inh * in_width + inw]); + } + } + } + + *outptr = max; + ++outptr; + } + } + input_offset += in_image_size; + output_offset += out_image_size; + } + } +} + +// assume the input has already been padded +void PoolingMaxNeonK2x2S2x2Padded(const float *input, + const index_t *in_shape, + float *output, + const index_t *out_shape) { + index_t batch = in_shape[0]; + index_t channels = in_shape[1]; + index_t in_height = in_shape[2]; + index_t in_width = in_shape[3]; + + index_t out_height = out_shape[2]; + index_t out_width = out_shape[3]; + + int in_image_size = in_height * in_width; + int out_image_size = out_height * out_width; + index_t input_offset = 0; + index_t output_offset = 0; + +#pragma omp parallel for collapse(2) + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels; ++c) { + const float *img0 = input + input_offset; + float *outptr = output + output_offset; + + const float *r0 = img0; + const float *r1 = img0 + in_width; + + for (int h = 0; h < out_height; ++h) { + int num_vectors = out_width >> 2; + int remain = out_width - (num_vectors << 2); + + for (; num_vectors > 0; --num_vectors) { + float32x4_t r00 = vld1q_f32(r0); + float32x4_t r10 = vld1q_f32(r1); + float32x4_t r01 = vld1q_f32(r0 + 4); + float32x4_t r11 = vld1q_f32(r1 + 4); + + float32x4_t max0 = vmaxq_f32(r00, r10); + float32x4_t max1 = vmaxq_f32(r01, r11); + + float32x4_t max_result = vpmaxq_f32(max0, max1); + + vst1q_f32(outptr, max_result); + r0 += 8; + r1 += 8; + outptr += 4; + } + + for (; remain > 0; --remain) { + float max0 = std::max(r0[0], r0[1]); + float max1 = std::max(r1[0], r1[1]); + *outptr = std::max(max0, max1); + + r0 += 2; + r1 += 2; + outptr++; + } + r0 += in_width; + r1 += in_width; + } + input_offset += in_image_size; + output_offset += out_image_size; + } + } +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/neon/max_pooling_neon_3x3.cc b/mace/kernels/neon/max_pooling_neon_3x3.cc new file mode 100644 index 0000000000000000000000000000000000000000..045ce7b08e451ac6ebf06124036d094fc3206534 --- /dev/null +++ b/mace/kernels/neon/max_pooling_neon_3x3.cc @@ -0,0 +1,222 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include +#include + +#include "mace/core/common.h" + +namespace mace { +namespace kernels { + +void PoolingMaxNeonK3x3S2x2(const float *input, + const index_t *in_shape, + float *output, + const index_t *out_shape, + const int *paddings) { + index_t batch = in_shape[0]; + index_t channels = in_shape[1]; + index_t in_height = in_shape[2]; + index_t in_width = in_shape[3]; + + index_t out_height = out_shape[2]; + index_t out_width = out_shape[3]; + + int padding_top = paddings[0] / 2; + int padding_bottom = paddings[0] - padding_top; + int padding_left = paddings[1] / 2; + int padding_right = paddings[1] - padding_left; + + int in_image_size = in_height * in_width; + int out_image_size = out_height * out_width; + index_t input_offset = 0; + index_t output_offset = 0; + +#pragma omp parallel for collapse(2) + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels; ++c) { + float *outptr = output + output_offset; + + for (int h = 0; h < out_height; ++h) { + int w = 0; + int num_vectors = 0; + const float *r0, *r1, *r2; + if (!((h == 0 && padding_top > 0) || + (h == out_height - 1 && padding_bottom > 0))) { + r0 = input + input_offset + (h * 2 - padding_top) * in_width; + r1 = r0 + in_width; + r2 = r1 + in_width; + + if (padding_left > 0) { + if (padding_left == 1) { + float max0 = std::max(r0[0], r0[1]); + float max1 = std::max(r1[0], r1[1]); + float max2 = std::max(r2[0], r2[1]); + *outptr = std::max(std::max(max0, max1), max2); + ++r0; + ++r1; + } else { // padding_left == 2 + float max_tmp = std::max(r0[0], r1[0]); + *outptr = std::max(max_tmp, r2[0]); + } + ++outptr; + ++w; + } + if (padding_right > 0) { + num_vectors = (out_width - w - 1) >> 2; + } else { + num_vectors = (out_width - w) >> 2; + } + } + + float32x4x2_t row0 = vld2q_f32(r0); + float32x4x2_t row1 = vld2q_f32(r1); + float32x4x2_t row2 = vld2q_f32(r2); + for (; num_vectors > 0; --num_vectors) { + float32x4x2_t row0_next = vld2q_f32(r0 + 8); + float32x4x2_t row1_next = vld2q_f32(r1 + 8); + float32x4x2_t row2_next = vld2q_f32(r2 + 8); + + float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]); + float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]); + float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]); + + float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1); + float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1); + float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1); + + max0 = vmaxq_f32(max0, row02); + max1 = vmaxq_f32(max1, row12); + max2 = vmaxq_f32(max2, row22); + + float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2); + + vst1q_f32(outptr, max_result); + + row0 = row0_next; + row1 = row1_next; + row2 = row2_next; + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + } + + w += num_vectors << 2; + for (; w < out_width; ++w) { + float max = std::numeric_limits::lowest(); + for (int kh = 0; kh < 3; ++kh) { + for (int kw = 0; kw < 3; ++kw) { + int inh = h * 2 - padding_top + kh; + int inw = w * 2 - padding_left + kw; + if (inh >= 0 && inh < in_height && + inw >= 0 && inw < in_width) { + max = std::max(max, input[input_offset + inh * in_width + inw]); + } + } + } + + *outptr = max; + ++outptr; + } + } + input_offset += in_image_size; + output_offset += out_image_size; + } + } +} + +// assume the input has already been padded +void PoolingMaxNeonK3x3S2x2Padded(const float *input, + const index_t *in_shape, + float *output, + const index_t *out_shape) { + index_t batch = in_shape[0]; + index_t channels = in_shape[1]; + index_t in_height = in_shape[2]; + index_t in_width = in_shape[3]; + + index_t out_height = out_shape[2]; + index_t out_width = out_shape[3]; + + int in_image_size = in_height * in_width; + int out_image_size = out_height * out_width; + index_t input_offset = 0; + index_t output_offset = 0; + +#pragma omp parallel for collapse(2) + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels; ++c) { + const float *img0 = input + input_offset; + float *outptr = output + output_offset; + + const float *r0 = img0; + const float *r1 = r0 + in_width; + const float *r2 = r1 + in_width; + + for (int h = 0; h < out_height; h++) { + int num_vectors = out_width >> 2; + int remain = out_width - (num_vectors << 2); + + float32x4x2_t row0 = vld2q_f32(r0); + float32x4x2_t row1 = vld2q_f32(r1); + float32x4x2_t row2 = vld2q_f32(r2); + for (; num_vectors > 0; num_vectors--) { + float32x4x2_t row0_next = vld2q_f32(r0 + 8); + float32x4x2_t row1_next = vld2q_f32(r1 + 8); + float32x4x2_t row2_next = vld2q_f32(r2 + 8); + + float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]); + float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]); + float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]); + + float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1); + float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1); + float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1); + + max0 = vmaxq_f32(max0, row02); + max1 = vmaxq_f32(max1, row12); + max2 = vmaxq_f32(max2, row22); + + float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2); + + vst1q_f32(outptr, max_result); + + row0 = row0_next; + row1 = row1_next; + row2 = row2_next; + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + } + + for (; remain > 0; remain--) { + float max0 = std::max(std::max(r0[0], r0[1]), r0[2]); + float max1 = std::max(std::max(r1[0], r1[1]), r1[2]); + float max2 = std::max(std::max(r2[0], r2[1]), r2[2]); + + *outptr = std::max(std::max(max0, max1), max2); + + r0 += 2; + r1 += 2; + r2 += 2; + outptr++; + } + + r0 += 1 + in_width; + r1 += 1 + in_width; + r2 += 1 + in_width; + } + input_offset += in_image_size; + output_offset += out_image_size; + } + } +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/neon/pooling_neon.cc b/mace/kernels/neon/pooling_neon.cc new file mode 100644 index 0000000000000000000000000000000000000000..33d763413c3c879e0d286d77ba96c405f9859df2 --- /dev/null +++ b/mace/kernels/neon/pooling_neon.cc @@ -0,0 +1,77 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include "mace/kernels/pooling.h" +#include "mace/kernels/conv_pool_2d_util.h" + +namespace mace { +namespace kernels { + +extern void PoolingMaxNeonK2x2S2x2(const float *input, + const index_t *in_shape, + float *output, + const index_t *out_shape, + const int *paddings); + +extern void PoolingMaxNeonK3x3S2x2(const float *input, + const index_t *in_shape, + float *output, + const index_t *out_shape, + const int *paddings); + +#ifdef __COPY_MAKE_PADDING +extern void PoolingMaxNeonK2x2S2x2Padded(const float* input, + const index_t* in_shape, + float* output, + const index_t* out_shape); +extern void PoolingMaxNeonK3x3S2x2Padded(const float* input, + const index_t* in_shape, + float* output, + const index_t* out_shape); +#endif + +template<> +void PoolingFunctor::operator()( + const float *input, + const index_t *input_shape, + float *output, + const index_t *output_shape) { + if (kernels_[0] == 2 && kernels_[1] == 2 && + strides_[0] == 2 && strides_[1] == 2 && + pooling_type_ == MAX) { +#ifdef __COPY_MAKE_PADDING + Tensor padded_input; + ConstructInputWithPadding(input, input_shape, paddings_, &padded_input); + input = padded_input.data(); + input_shape = padded_input.shape().data(); + PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape); +#else + PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape, paddings_); +#endif + } else if (kernels_[0] == 3 && kernels_[1] == 3 && + strides_[0] == 2 && strides_[1] == 2 && + pooling_type_ == MAX) { +#ifdef __COPY_MAKE_PADDING + Tensor padded_input; + ConstructInputWithPadding(input, input_shape, paddings_, &padded_input); + input = padded_input.data(); + input_shape = padded_input.shape().data(); + PoolingMaxNeonK3x3S2x2V2Padded(input, input_shape, output, output_shape); +#else + PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape, paddings_); +#endif + } else { // not implement yet + PoolingFunctor(pooling_type_, kernels_, strides_, + paddings_, dilations_)( + input, + input_shape, + output, + output_shape + ); + } +} + +} // namespace kernels +} // namespace mace \ No newline at end of file diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index 52956159dea0b25a63d0d90f02f6e2c2bf2672b7..b8a1bdd75972bee260a401290e531925b6d665e1 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -19,33 +19,33 @@ namespace kernels { template class PoolingFunctor { -public: + public: PoolingFunctor(const PoolingType pooling_type, - const int* kernels, - const int* strides, - const int* paddings, - const int* dilations) - : pooling_type_(pooling_type), - kernels_(kernels), - strides_(strides), - paddings_(paddings), - dilations_(dilations) {} - - void operator()(const T* input, - const index_t* input_shape, - T* output, - const index_t* output_shape) { - index_t batch = output_shape[0]; + const int *kernels, + const int *strides, + const int *paddings, + const int *dilations) + : pooling_type_(pooling_type), + kernels_(kernels), + strides_(strides), + paddings_(paddings), + dilations_(dilations) {} + + void operator()(const T *input, + const index_t *input_shape, + T *output, + const index_t *output_shape) { + index_t batch = output_shape[0]; index_t channels = output_shape[1]; - index_t height = output_shape[2]; - index_t width = output_shape[3]; + index_t height = output_shape[2]; + index_t width = output_shape[3]; index_t input_channels = input_shape[1]; - index_t input_height = input_shape[2]; - index_t input_width = input_shape[3]; + index_t input_height = input_shape[2]; + index_t input_width = input_shape[3]; int kernel_h = kernels_[0]; - int kernel_w = kernels_[1]; + int kernel_w = kernels_[1]; int stride_h = strides_[0]; int stride_w = strides_[1]; @@ -61,20 +61,20 @@ public: for (int n = 0; n < batch; ++n) { for (int c = 0; c < channels; ++c) { index_t out_offset = n * channels * height * width + - c * height * width; + c * height * width; index_t in_offset = n * input_channels * input_height * input_width + - c * input_height * input_width; + c * input_height * input_width; for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { T sum_or_max = 0; switch (pooling_type_) { - case AVG: - break; - case MAX: - sum_or_max = std::numeric_limits::lowest(); + case AVG:break; + case MAX:sum_or_max = std::numeric_limits::lowest(); break; default: - MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_); + MACE_CHECK(false, + "Unsupported pooling type: ", + pooling_type_); } for (int kh = 0; kh < kernel_h; ++kh) { for (int kw = 0; kw < kernel_w; ++kw) { @@ -83,10 +83,9 @@ public: if (inh >= 0 && inh < input_height && inw >= 0 && inw < input_width) { index_t input_offset = in_offset + - inh * input_width + inw; + inh * input_width + inw; switch (pooling_type_) { - case AVG: - sum_or_max += input[input_offset]; + case AVG:sum_or_max += input[input_offset]; break; case MAX: sum_or_max = std::max(sum_or_max, input[input_offset]); @@ -99,14 +98,14 @@ public: } } switch (pooling_type_) { - case AVG: - output[out_offset] = sum_or_max / (kernel_h * kernel_w); + case AVG:output[out_offset] = sum_or_max / (kernel_h * kernel_w); break; - case MAX: - output[out_offset] = sum_or_max; + case MAX:output[out_offset] = sum_or_max; break; default: - MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_); + MACE_CHECK(false, + "Unsupported pooling type: ", + pooling_type_); } out_offset += 1; } @@ -115,14 +114,20 @@ public: } } -private: + private: const PoolingType pooling_type_; - const int* kernels_; - const int* strides_; - const int* paddings_; - const int* dilations_; + const int *kernels_; + const int *strides_; + const int *paddings_; + const int *dilations_; }; +template<> +void PoolingFunctor::operator()( + const float *input, + const index_t *input_shape, + float *output, + const index_t *output_shape); } // namespace kernels } // namespace mace diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index 58ced63e941e7d1f0f602d4de2191aa40fcb2d0e..cab59685270fe4b2efb452085fe7f24aa357e4f6 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -9,4 +9,8 @@ namespace mace { REGISTER_CPU_OPERATOR(Pooling, PoolingOp); +#if __ARM_NEON +REGISTER_NEON_OPERATOR(Pooling, PoolingOp); +#endif // __ARM_NEON + } // namespace mace diff --git a/mace/ops/pooling_benchmark.cc b/mace/ops/pooling_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..ccdcb206957356734eb4bd29c562e01699931e75 --- /dev/null +++ b/mace/ops/pooling_benchmark.cc @@ -0,0 +1,65 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/testing/test_benchmark.h" +#include "mace/core/operator.h" +#include "mace/kernels/pooling.h" +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/ops/ops_test_util.h" + +using namespace mace; +using namespace mace::kernels; + +template +static void Pooling(int iters, int batch, int channels, int height, + int width, int kernel, int stride, Padding padding, + PoolingType pooling_type) { + + mace::testing::StopTiming(); + + OpsTestNet net; + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntArg("pooling_type", pooling_type); + net.AddIntsArg("kernels", {kernel, kernel}); + net.AddIntsArg("strides", {stride, stride}); + net.AddIntArg("padding", padding); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } +} + +#define BM_POOLING_MACRO(N, C, H, W, KE, STRIDE, PA, PO, DEVICE) \ + static void BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot * (sizeof(float)));\ + Pooling(iters, N, C, H, W, KE, STRIDE, Padding::PA, PoolingType::PO); \ + } \ + BENCHMARK(BM_POOLING_##N##_##C##_##H##_##W##_K##KE##S##STRIDE##_##PA##_##PO##_##DEVICE) + +#define BM_POOLING(N, C, H, W, K, S, PA, PO) \ + BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \ + BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON); + +BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX); +BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX); +BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX); +BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX); diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index 9b7cc0cf3e24b07908708f100aa6906a9d6a8b2e..7ff8e351e69da6cb98ed2730edcbeed77eebb095 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -148,3 +148,67 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } + +TEST_F(PoolingOpTest, MAX_k2x2s2x2) { + // Construct graph + auto& net = test_net(); + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntArg("pooling_type", PoolingType::MAX); + net.AddIntsArg("kernels", {2, 2}); + net.AddIntsArg("strides", {2, 2}); + net.AddIntArg("padding", Padding::SAME); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddInputFromArray("Input", {1, 1, 4, 5}, + {0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19}); + // Run + net.RunOp(DeviceType::NEON); + + // Check + Tensor expected = CreateTensor({1, 1, 2, 3}, + {6, 8, 9, + 16, 18, 19}); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 0.001); +} + +TEST_F(PoolingOpTest, MAX_k3x3s2x2) { + // Construct graph + auto& net = test_net(); + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntArg("pooling_type", PoolingType::MAX); + net.AddIntsArg("kernels", {3, 3}); + net.AddIntsArg("strides", {2, 2}); + net.AddIntArg("padding", Padding::SAME); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddInputFromArray("Input", {1, 1, 4, 5}, + {0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19}); + // Run + net.RunOp(DeviceType::NEON); + + // Check + Tensor expected = CreateTensor({1, 1, 2, 3}, + {11, 13, 14, + 16, 18, 19}); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 0.001); +}