diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index 4ecbdd6a7ebd2e4fd4896587b44bba316259fe46..c8bc2012bda48dd0a5955c783d799b3c08726af6 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -71,6 +71,51 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW output_shape[3] = output_width; } +void CalPaddingSize(const index_t *input_shape, // NCHW + const index_t *filter_shape, // OIHW + const int *dilations, + const int *strides, + Padding padding, + int *padding_size) { + + MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, + "Invalid dilations, must >= 1"); + MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && + (dilations[1] == 1 || strides[1] == 1), + "If dilations > 1, strides should be 1"); + MACE_CHECK_NOTNULL(padding_size); + + index_t output_height, output_width; + index_t k_extent_height = (filter_shape[2] - 1) * dilations[0] + 1; + index_t k_extent_width = (filter_shape[3] - 1) * dilations[1] + 1; + + switch (padding) { + case VALID: + output_height = (input_shape[2] - k_extent_height) / strides[0] + 1; + output_width = (input_shape[3] - k_extent_width) / strides[1] + 1; + break; + case SAME: + output_height = (input_shape[2] - 1) / strides[0] + 1; + output_width = (input_shape[3] - 1) / strides[1] + 1; + break; + case FULL: + output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; + output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; + break; + default: + MACE_CHECK(false, "Unsupported padding type: ", padding); + } + + // Note: TensorFlow may padded one more on the right/bottom side + // TODO may be it's better to also truncate the left/top to + // utilize the more centered features. We need to benchmark + // based on the model accuracy. + padding_size[0] = + (output_height - 1) * strides[0] + k_extent_height - input_shape[2]; + padding_size[1] = + (output_width - 1) * strides[1] + k_extent_width - input_shape[3]; +} + void ConstructInputWithPadding(const float *input, const index_t *input_shape, const int *paddings, diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index a164d55e35498949b7862c2209cf0ae15099329a..26f2ab37695088307d851da05aad4f99092294f2 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -25,6 +25,13 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW index_t *output_shape, int *padding_size); +void CalPaddingSize(const index_t *input_shape, // NCHW + const index_t *filter_shape, // OIHW + const int *dilations, + const int *strides, + Padding padding, + int *padding_size); + void ConstructInputWithPadding(const float *input, const index_t *input_shape, const int *paddings, diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h new file mode 100644 index 0000000000000000000000000000000000000000..472733af3656db1969e278ee9743b2510c2980ea --- /dev/null +++ b/mace/kernels/depthwise_conv2d.h @@ -0,0 +1,129 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_DEPTHWISE_CONV_H_ +#define MACE_KERNELS_DEPTHWISE_CONV_H_ + +#include "mace/proto/mace.pb.h" +#include "mace/core/common.h" +#include "mace/kernels/conv_pool_2d_util.h" + +namespace mace { +namespace kernels { + +template +class DepthwiseConv2dFunctor { + public: + DepthwiseConv2dFunctor(const index_t *input_shape, + const index_t *filter_shape, + const int *strides, + const Padding padding, + const int *dilations) : + strides_(strides), + paddings_(2, 0), + dilations_(dilations) { + CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data()); + } + DepthwiseConv2dFunctor(const int *strides, + const std::vector &paddings, + const int *dilations) : + strides_(strides), + paddings_(paddings), + dilations_(dilations) {} + + void operator()(const T *input, // NCHW + const index_t *input_shape, + const T *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, + const T *bias, // c_out + T *output, // NCHW + const index_t *output_shape) { + + MACE_CHECK_NOTNULL(output); + + index_t batch = output_shape[0]; + index_t channels = output_shape[1]; + index_t height = output_shape[2]; + index_t width = output_shape[3]; + + index_t input_batch = input_shape[0]; + index_t input_channels = input_shape[1]; + index_t input_height = input_shape[2]; + index_t input_width = input_shape[3]; + + index_t kernel_h = filter_shape[2]; + index_t kernel_w = filter_shape[3]; + + int stride_h = strides_[0]; + int stride_w = strides_[1]; + + int dilation_h = dilations_[0]; + int dilation_w = dilations_[1]; + + MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); + + // The left-upper most offset of the padded input + int padded_h_start = 0 - paddings_[0] / 2; + int padded_w_start = 0 - paddings_[1] / 2; + index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2; + index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2; + + index_t kernel_size = filter_shape[1] * kernel_h * kernel_w; + index_t multiplier = channels / input_channels; + +#pragma omp parallel for collapse(2) + for (int n = 0; n < batch; ++n) { + for (int c = 0; c < channels; ++c) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + index_t offset = n * channels * height * width + + c * height * width + h * width + w; + T sum = 0; + const T *filter_ptr = filter + c * kernel_size; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int inh = padded_h_start + h * stride_h + dilation_h * kh; + int inw = padded_w_start + w * stride_w + dilation_w * kw; + if (inh < 0 || inh >= input_height || inw < 0 || + inw >= input_width) { + MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop && + inw >= padded_w_start && inw < padded_w_stop, + "Out of range read from input: ", inh, ", ", + inw); + // else padding with 0: + // sum += 0; + } else { + index_t input_offset = + n * input_channels * input_height * input_width + + (c / multiplier) * input_height * input_width + inh * input_width + + inw; + sum += input[input_offset] * *filter_ptr; + } + ++filter_ptr; + } + } + output[offset] = sum + bias[c]; + } + } + } + } + } + private: + const int *strides_; // [stride_h, stride_w] + std::vector paddings_; // [padding_h, padding_w] + const int *dilations_; // [dilation_h, dilation_w] +}; + +template<> +void DepthwiseConv2dFunctor::operator()(const float *input, + const index_t *input_shape, + const float *filter, + const index_t *filter_shape, + const float *bias, + float *output, + const index_t *output_shape); +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_DEPTHWISE_CONV_H_ diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc index 0e01def90294623492f5a88a41843a2a4f05363c..c530e496e5459124808a5c4500f01d84100995a3 100644 --- a/mace/kernels/neon/conv_2d_neon.cc +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -11,6 +11,7 @@ namespace kernels { extern void Conv2dNeonK1x1S1(const float *input, const index_t *input_shape, const float *filter, + const index_t *filter_shape, const float *bias, float *output, const index_t *output_shape); @@ -18,6 +19,7 @@ extern void Conv2dNeonK1x1S1(const float *input, extern void Conv2dNeonK3x3S1(const float *input, const index_t *input_shape, const float *filter, + const index_t *filter_shape, const float *bias, float *output, const index_t *output_shape); @@ -25,6 +27,7 @@ extern void Conv2dNeonK3x3S1(const float *input, extern void Conv2dNeonK3x3S2(const float *input, const index_t *input_shape, const float *filter, + const index_t *filter_shape, const float *bias, float *output, const index_t *output_shape); @@ -32,6 +35,7 @@ extern void Conv2dNeonK3x3S2(const float *input, extern void Conv2dNeonK5x5S1(const float *input, const index_t *input_shape, const float *filter, + const index_t *filter_shape, const float *bias, float *output, const index_t *output_shape); @@ -48,6 +52,7 @@ void Conv2dFunctor::operator()(const float *input, const float *input, const index_t *input_shape, const float *filter, + const index_t *filter_shape, const float *bias, float *output, const index_t *output_shape); @@ -81,7 +86,7 @@ void Conv2dFunctor::operator()(const float *input, input_shape = padded_input.shape().data(); } auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; - conv2d_neon_func(input, input_shape, filter, bias, output, output_shape); + conv2d_neon_func(input, input_shape, filter, nullptr, bias, output, output_shape); } } // namespace kernels diff --git a/mace/kernels/neon/conv_2d_neon_1x1.cc b/mace/kernels/neon/conv_2d_neon_1x1.cc index 59cd101bc970a3186bddfbe60efacb08a1d287f5..a82505e79296e7f362139643bd700584d6d89caa 100644 --- a/mace/kernels/neon/conv_2d_neon_1x1.cc +++ b/mace/kernels/neon/conv_2d_neon_1x1.cc @@ -8,12 +8,13 @@ namespace mace { namespace kernels { -void Conv2dNeonK1x1S1(const float* input, // NCHW - const index_t* input_shape, - const float* filter, // c_out, c_in, kernel_h, kernel_w - const float* bias, // c_out - float* output, // NCHW - const index_t* output_shape) { +void Conv2dNeonK1x1S1(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape) { const index_t batch = output_shape[0]; const index_t channels = output_shape[1]; const index_t height = output_shape[2]; @@ -25,7 +26,7 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW const index_t input_width = input_shape[3]; MACE_CHECK(input_batch == batch && input_height == height && - input_width == width); + input_width == width); const index_t total_pixels = height * width; // Process 4 * 2 = 8 pixels for each innermost loop @@ -35,17 +36,17 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW // benchmark omp collapsed(2) for (index_t n = 0; n < batch; ++n) { - const float* filter_ptr = filter; + const float *filter_ptr = filter; #pragma omp parallel for for (index_t c = 0; c < channels; ++c) { // TODO Will GCC opt these out? - float* channel_output_start = + float *channel_output_start = output + n * channels * height * width + c * height * width; - const float* input_ptr = + const float *input_ptr = input + n * input_channels * input_height * input_width; // Fill with bias - float* output_ptr = channel_output_start; + float *output_ptr = channel_output_start; for (index_t ptr = 0; ptr < total_pixels; ++ptr) { output_ptr[ptr] = bias[c]; // TODO can we avoid this? } @@ -53,15 +54,15 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW index_t inc = 0; // Process 4 input channels in batch for (; inc + 3 < input_channels; inc += 4) { - float* output_ptr = channel_output_start; + float *output_ptr = channel_output_start; // The begining of each input feature map channel MACE_ASSERT(input_ptr == - input + n * input_channels * input_height * input_width + - inc * input_height * input_width); + input + n * input_channels * input_height * input_width + + inc * input_height * input_width); - const float* input_ptr1 = input_ptr + total_pixels; - const float* input_ptr2 = input_ptr1 + total_pixels; - const float* input_ptr3 = input_ptr2 + total_pixels; + const float *input_ptr1 = input_ptr + total_pixels; + const float *input_ptr2 = input_ptr1 + total_pixels; + const float *input_ptr3 = input_ptr2 + total_pixels; // filter is in c_out, c_in, 1, 1 order MACE_ASSERT(filter_ptr == filter + c * input_channels + inc); @@ -139,10 +140,10 @@ void Conv2dNeonK1x1S1(const float* input, // NCHW } // Process the remaining channels for (; inc < input_channels; ++inc) { - float* output_ptr = channel_output_start; + float *output_ptr = channel_output_start; MACE_ASSERT(input_ptr == - input + n * input_channels * input_height * input_width + - inc * input_height * input_width); + input + n * input_channels * input_height * input_width + + inc * input_height * input_width); MACE_ASSERT(filter_ptr == filter + c * input_channels + inc); const float k0 = filter_ptr[0]; diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc index 0cf589b929831375fab427d6f8e66b8728ea190c..6b62cb5937f84c7169e1da05883e9eaf40da701c 100644 --- a/mace/kernels/neon/conv_2d_neon_3x3.cc +++ b/mace/kernels/neon/conv_2d_neon_3x3.cc @@ -17,30 +17,35 @@ namespace kernels { int input_channels = input_shape[1]; \ int input_height = input_shape[2]; \ int input_width = input_shape[3]; \ - int kernel_h = 3; \ - int kernel_w = 3; \ + int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); \ + int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1]; \ for (int b = 0; b < output_batch; ++b) { \ - float* output_ptr_base = output + b * output_channels * output_height * output_width; \ + float *output_ptr_base = output + b * output_channels * output_height * output_width; \ for (int oc = 0; oc < output_channels; ++oc) { \ - const float* filter_ptr = filter + oc * input_channels * kernel_h * kernel_w; \ - const float* input_ptr = input + b * input_channels * input_height * input_width; \ - float* output_ptr = output_ptr_base + oc * output_height * output_width; \ + const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize; \ + const float *input_ptr = input + b * input_channels * input_height * input_width; \ + if (filter_shape != nullptr) { \ + input_ptr += (oc / multiplier) * input_height * input_width; \ + } \ + float *output_ptr = output_ptr_base + oc * output_height * output_width; \ std::fill(output_ptr, output_ptr + output_height * output_width, bias[oc]); \ - for (int ic = 0; ic < input_channels; ++ic) { \ + for (int ic = 0; ic < filter_in_channels; ++ic) { \ float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)}; #define KERNEL_TAIL_CODE \ - filter_ptr += 9; \ + filter_ptr += kFilterSize; \ input_ptr += input_height * input_width; \ } \ } \ } static const int kRegisterSize = 4; +static const int kFilterSize = 9; void Conv2dNeonK3x3S1(const float *input, // NCHW const index_t *input_shape, const float *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, const float *bias, // c_out float *output, // NCHW const index_t *output_shape) { @@ -213,6 +218,7 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW void Conv2dNeonK3x3S2(const float *input, // NCHW const index_t *input_shape, const float *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, const float *bias, // c_out float *output, // NCHW const index_t *output_shape) { @@ -287,7 +293,6 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW KERNEL_TAIL_CODE } - #undef KERNEL_HEAD_CODE #undef KERNEL_TAIL_CODE diff --git a/mace/kernels/neon/conv_2d_neon_5x5.cc b/mace/kernels/neon/conv_2d_neon_5x5.cc index 0e926eb21d4e590301eebcebe4e80354b4146a49..02c5ced2a3177af71544c6ccaf324cc133f686cf 100644 --- a/mace/kernels/neon/conv_2d_neon_5x5.cc +++ b/mace/kernels/neon/conv_2d_neon_5x5.cc @@ -10,12 +10,13 @@ namespace mace { namespace kernels { -void Conv2dNeonK5x5S1(const float* input, // NCHW - const index_t* input_shape, - const float* filter, // c_out, c_in, kernel_h, kernel_w - const float* bias, // c_out - float* output, // NCHW - const index_t* output_shape) { +void Conv2dNeonK5x5S1(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape) { const index_t batch = output_shape[0]; const index_t channels = output_shape[1]; const index_t height = output_shape[2]; @@ -39,9 +40,9 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW #pragma omp parallel for collapse(2) for (index_t n = 0; n < batch; ++n) { for (index_t c = 0; c < channels; ++c) { - float* output_ptr = output + n * output_total_pixels_per_batch + - c * output_total_pixels_per_channel; - const float* input_ptr = input + n * input_total_pixels_per_batch; + float *output_ptr = output + n * output_total_pixels_per_batch + + c * output_total_pixels_per_channel; + const float *input_ptr = input + n * input_total_pixels_per_batch; // Fill with bias for (index_t i = 0; i < output_total_pixels_per_channel; ++i) { @@ -49,24 +50,24 @@ void Conv2dNeonK5x5S1(const float* input, // NCHW } for (index_t inc = 0; inc < input_channels; ++inc) { - float* outptr = output_ptr; - float* outptr2 = outptr + width; - - const float* inptr = input_ptr + inc * input_total_pixels_per_channel; - const float* filter_ptr = filter + c * patch_size + inc * 25; - - const float* r0 = inptr; - const float* r1 = inptr + input_width; - const float* r2 = inptr + input_width * 2; - const float* r3 = inptr + input_width * 3; - const float* r4 = inptr + input_width * 4; - const float* r5 = inptr + input_width * 5; - - const float* k0 = filter_ptr; - const float* k1 = filter_ptr + 5; - const float* k2 = filter_ptr + 10; - const float* k3 = filter_ptr + 15; - const float* k4 = filter_ptr + 20; + float *outptr = output_ptr; + float *outptr2 = outptr + width; + + const float *inptr = input_ptr + inc * input_total_pixels_per_channel; + const float *filter_ptr = filter + c * patch_size + inc * 25; + + const float *r0 = inptr; + const float *r1 = inptr + input_width; + const float *r2 = inptr + input_width * 2; + const float *r3 = inptr + input_width * 3; + const float *r4 = inptr + input_width * 4; + const float *r5 = inptr + input_width * 5; + + const float *k0 = filter_ptr; + const float *k1 = filter_ptr + 5; + const float *k2 = filter_ptr + 10; + const float *k3 = filter_ptr + 15; + const float *k4 = filter_ptr + 20; float32x4_t _k0123 = vld1q_f32(filter_ptr); float32x4_t _k4567 = vld1q_f32(filter_ptr + 4); diff --git a/mace/kernels/neon/depthwise_conv_neon.cc b/mace/kernels/neon/depthwise_conv_neon.cc new file mode 100644 index 0000000000000000000000000000000000000000..eda2325d8b371218f2dcedefd34c124e3b75a9e9 --- /dev/null +++ b/mace/kernels/neon/depthwise_conv_neon.cc @@ -0,0 +1,77 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/depthwise_conv2d.h" +#include "mace/kernels/conv_2d.h" + +namespace mace { +namespace kernels { + +extern void Conv2dNeonK3x3S1(const float *input, + const index_t *input_shape, + const float *filter, + const index_t *filter_shape, + const float *bias, + float *output, + const index_t *output_shape); + +extern void Conv2dNeonK3x3S2(const float *input, + const index_t *input_shape, + const float *filter, + const index_t *filter_shape, + const float *bias, + float *output, + const index_t *output_shape); + +template<> +void DepthwiseConv2dFunctor::operator()(const float *input, // NCHW + const index_t *input_shape, + const float *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, + const float *bias, // c_out + float *output, // NCHW + const index_t *output_shape) { + typedef void (*Conv2dNeonFunction)( + const float *input, + const index_t *input_shape, + const float *filter, + const index_t *filter_shape, + const float *bias, + float *output, + const index_t *output_shape); + // Selection matrix: kernel_size x stride_size + static const Conv2dNeonFunction selector[5][2] = { + {nullptr, nullptr}, + {nullptr, nullptr}, + {Conv2dNeonK3x3S1, Conv2dNeonK3x3S2}, + {nullptr, nullptr}, + {nullptr, nullptr}}; + // not implement yet + index_t kernel_h = filter_shape[2]; + index_t kernel_w = filter_shape[3]; + if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] || + strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 || + selector[kernel_h - 1][strides_[0] - 1] == nullptr) { + LOG(WARNING) << "Depthwise-Conv2d NEON kernel with " + << "filter" << kernel_h << "x" << kernel_w << "," + << " stride " << strides_[0] << "x" << strides_[1] + << " is not implemented yet, using slow version"; + DepthwiseConv2dFunctor(strides_, paddings_, dilations_)( + input, input_shape, filter, filter_shape, bias, output, output_shape); + return; + } + + // Keep this alive during kernel execution + Tensor padded_input; + if (paddings_[0] > 0 || paddings_[1] > 0) { + ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input); + input = padded_input.data(); + input_shape = padded_input.shape().data(); + } + auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1]; + conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output, output_shape); +} + +} // namespace kernels +} // namespace mace \ No newline at end of file diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc index 33c609565b8dcc500577ad66d083632ce2dd02c4..90542e7be4ab829aefad0c451b9d520b1e2b8103 100644 --- a/mace/ops/conv_2d.cc +++ b/mace/ops/conv_2d.cc @@ -3,7 +3,6 @@ // #include "mace/ops/conv_2d.h" -#include "mace/proto/mace.pb.h" namespace mace { diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index ad3206b0045db7be28452fcfc602ffc5da9082ff..3ac6689cd8d8a6f5198c56a70623ed50a7d6e0b7 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -13,17 +13,17 @@ namespace mace { -template +template class Conv2dOp : public ConvPool2dOpBase { public: - Conv2dOp(const OperatorDef& op_def, Workspace* ws) - : ConvPool2dOpBase(op_def, ws){}; + Conv2dOp(const OperatorDef &op_def, Workspace *ws) + : ConvPool2dOpBase(op_def, ws) {}; bool Run() override { - const Tensor* input = this->Input(INPUT); - const Tensor* filter = this->Input(FILTER); - const Tensor* bias = this->Input(BIAS); - Tensor* output = this->Output(OUTPUT); + const Tensor *input = this->Input(INPUT); + const Tensor *filter = this->Input(FILTER); + const Tensor *bias = this->Input(BIAS); + Tensor *output = this->Output(OUTPUT); std::vector output_shape(4); std::vector paddings(2); diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 4164ba0c9faa1189249d1037aa0204d962a08336..bbf6b608245a918c780e0b8d4ff7f375d7887f94 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -56,7 +56,7 @@ static void Conv2d(int iters, #define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \ static void \ - BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE( \ + BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \ int iters) { \ const int64_t tot = static_cast(iters) * N * C * H * W; \ mace::testing::ItemsProcessed(tot); \ @@ -64,7 +64,7 @@ static void Conv2d(int iters, Conv2d(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \ } \ BENCHMARK( \ - BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_OC##_##TYPE##_##DEVICE) + BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) #define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \ @@ -74,9 +74,11 @@ BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); +BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float); +BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float); BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float); diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 96880a02ef0fc41857aba2bf8a78698a91e4240f..c5b8751430bda21491212e9c867d28d092d785e2 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -173,10 +173,10 @@ TEST_F(Conv2dOpTest, ConvNxNS12) { // generate random input index_t batch = 1 + rand() % 10; - index_t input_channels = 1 + rand() % 50; - index_t height = 11 + rand() % 100; - index_t width = 11 + rand() % 100; - index_t output_channels = 1 + rand() % 50; + index_t input_channels = 1 + rand() % 10; + index_t height = 107; + index_t width = 113; + index_t output_channels = 1 + rand() % 10; // Construct graph auto& net = test_net(); OpDefBuilder("Conv2d", "Conv2dTest") diff --git a/mace/ops/conv_pool_2d_base.h b/mace/ops/conv_pool_2d_base.h index a572b71efbbe0ee1af1c2a253ee14f66226ec54f..9b1838a36c2ef4f68217616569edcee32d3a6f9e 100644 --- a/mace/ops/conv_pool_2d_base.h +++ b/mace/ops/conv_pool_2d_base.h @@ -18,8 +18,49 @@ class ConvPool2dOpBase : public Operator { strides_(OperatorBase::GetRepeatedArgument("strides")), padding_(static_cast(OperatorBase::GetSingleArgument( "padding", static_cast(SAME)))), - dilations_(OperatorBase::GetRepeatedArgument("dilations")) {} + dilations_(OperatorBase::GetRepeatedArgument("dilations", {1, 1})) {} + void CalOutputSize(const index_t *input_shape, // NCHW + const index_t *filter_shape, // OIHW + index_t *output_shape) { + + MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0, + "Invalid dilations, must >= 1"); + MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) && + (dilations_[1] == 1 || strides_[1] == 1), + "If dilations > 1, strides should be 1"); + MACE_CHECK_NOTNULL(output_shape); + /* + * Convlution/pooling arithmetic: + * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 + * For details, see https://arxiv.org/pdf/1603.07285.pdf or + * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html + */ + + index_t output_height, output_width; + + switch (padding_) { + case VALID: + output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1; + output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1; + break; + case SAME: + output_height = (input_shape[2] - 1) / strides_[0] + 1; + output_width = (input_shape[3] - 1) / strides_[1] + 1; + break; + case FULL: + output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1; + output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1; + break; + default: + MACE_CHECK(false, "Unsupported padding type: ", padding_); + } + + output_shape[0] = input_shape[0]; + output_shape[1] = filter_shape[0]; + output_shape[2] = output_height; + output_shape[3] = output_width; + } protected: std::vector strides_; Padding padding_; diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc new file mode 100644 index 0000000000000000000000000000000000000000..320842e1e59dda9afd059b6144d8a07fffeff36d --- /dev/null +++ b/mace/ops/depthwise_conv2d.cc @@ -0,0 +1,15 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/depthwise_conv2d.h" + +namespace mace { + +REGISTER_CPU_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp); + +#if __ARM_NEON +REGISTER_NEON_OPERATOR(DepthwiseConv2d, DepthwiseConv2dOp); +#endif // __ARM_NEON + +} // namespace mace diff --git a/mace/ops/depthwise_conv2d.h b/mace/ops/depthwise_conv2d.h new file mode 100644 index 0000000000000000000000000000000000000000..cc220f3c5f5848bf5e989adc466c585153eb55d7 --- /dev/null +++ b/mace/ops/depthwise_conv2d.h @@ -0,0 +1,57 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_DEPTHWISE_CONV_H_ +#define MACE_OPS_DEPTHWISE_CONV_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/conv_2d.h" +#include "mace/ops/conv_pool_2d_base.h" +#include "mace/kernels/depthwise_conv2d.h" + +namespace mace { + +template +class DepthwiseConv2dOp : public ConvPool2dOpBase { + public: + DepthwiseConv2dOp(const OperatorDef &op_def, Workspace *ws) + : ConvPool2dOpBase(op_def, ws), + functor_(this->Input(INPUT)->shape().data(), + this->Input(FILTER)->shape().data(), + this->strides_.data(), this->padding_, this->dilations_.data()) {}; + + bool Run() override { + const Tensor *input = this->Input(INPUT); + const Tensor *filter = this->Input(FILTER); + const Tensor *bias = this->Input(BIAS); + Tensor *output = this->Output(OUTPUT); + + // resize filter shape. + std::vector filter_shape(filter->shape().begin(), filter->shape().end()); + filter_shape[0] *= filter_shape[1]; + filter_shape[1] = 1; + std::vector output_shape(4); + this->CalOutputSize(input->shape().data(), filter_shape.data(), output_shape.data()); + output->Resize(output_shape); + + functor_(input->data(), input->shape().data(), filter->data(), + filter_shape.data(), bias->data(), output->mutable_data(), + output->shape().data()); + + return true; + } + + private: + kernels::DepthwiseConv2dFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT, FILTER, BIAS); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif // MACE_OPS_DEPTHWISE_CONV_H_ diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2afe81171b55c2d578caeec8e6a6930ccf241cc --- /dev/null +++ b/mace/ops/depthwise_conv2d_test.cc @@ -0,0 +1,99 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/conv_2d.h" +#include "mace/ops/ops_test_util.h" + +using namespace mace; + +class DepthwiseConv2dOpTest : public OpsTestBase {}; + +TEST_F(DepthwiseConv2dOpTest, Simple_VALID) { + // Construct graph + auto& net = test_net(); + OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntsArg("strides", {1, 1}); + net.AddIntArg("padding", Padding::VALID); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddInputFromArray( + "Input", {1, 2, 2, 3}, + {1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12}); + net.AddInputFromArray( + "Filter", {2, 2, 2, 2}, + {1.0f, 5.0f, 9.0f, 13.0f, + 2.0f, 6.0f, 10.0f, 14.0f, + 3.0f, 7.0f, 11.0f, 15.0f, + 4.0f, 8.0f, 12.0f, 16.0f}); + net.AddInputFromArray("Bias", {4}, {.1f, .2f, .3f, .4f}); + + // Run + net.RunOp(); + + // Check + auto expected = CreateTensor({1, 4, 1, 2}, + {196.1f, 252.1f, 216.2f, 280.2f, + 272.3f, 344.3f, 296.4f, 376.4f}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) { + testing::internal::LogToStderr(); + auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, + Padding type) { + srand(time(NULL)); + + // generate random input + index_t batch = 2 + rand() % 10; + index_t input_channels = 3 + rand() % 10; + index_t height = 107; + index_t width = 113; + index_t multiplier = 3 + rand() % 10; + // Construct graph + auto& net = test_net(); + OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntsArg("strides", {stride_h, stride_w}); + net.AddIntArg("padding", type); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddRandomInput("Input", {batch, input_channels, height, width}); + net.AddRandomInput( + "Filter", {multiplier, input_channels, kernel_h, kernel_w}); + net.AddRandomInput("Bias", {multiplier * input_channels}); + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run NEON + net.RunOp(DeviceType::NEON); + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-3); + }; + + for (int kernel_size : {3}) { + for (int stride : {1, 2}) { + func(kernel_size, kernel_size, stride, stride, VALID); + func(kernel_size, kernel_size, stride, stride, SAME); + } + } +} diff --git a/mace/ops/depthwise_conv_2d_benchmark.cc b/mace/ops/depthwise_conv_2d_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..f535ea17273d028d01c4e56e8f7f32275c73eb33 --- /dev/null +++ b/mace/ops/depthwise_conv_2d_benchmark.cc @@ -0,0 +1,85 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/conv_2d.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { + +template +static void DepthwiseConv2d(int iters, + int batch, + int channels, + int height, + int width, + int kernel_h, + int kernel_w, + int stride, + Padding padding, + int output_channels) { + mace::testing::StopTiming(); + + OpsTestNet net; + OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .Finalize(net.operator_def()); + + // Add args + net.AddIntsArg("strides", {stride, stride}); + net.AddIntArg("padding", padding); + net.AddIntsArg("dilations", {1, 1}); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Filter", + {output_channels, channels, kernel_h, kernel_w}); + net.AddRandomInput("Bias", {output_channels}); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } +} + +#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \ + static void \ + BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot*(sizeof(TYPE))); \ + DepthwiseConv2d(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \ + } \ + BENCHMARK( \ + BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) + +#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \ + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); + +BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); +BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); +BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); +BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float); +BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float); +BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 3, float); +BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float); +BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float); +BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float); +BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float); +BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float); +BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 3, float); +} // namespace mace