diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 6f62ce176f910dcf32d6c3104f0005242f83375e..f20c427b26e8e2d3718bd6184087e5fc20f21157 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -51,13 +51,13 @@ class Conv2dFunctor { MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); // The left-upper most offset of the padded input - int padded_h_start = 0 - paddings_[0]; - int padded_w_start = 0 - paddings_[1]; - int padded_h_stop = input_height + paddings_[0]; - int padded_w_stop = input_width + paddings_[1]; + int padded_h_start = 0 - paddings_[0] / 2; + int padded_w_start = 0 - paddings_[1] / 2; + int padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2; + int padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2; +#pragma omp parallel for collpse(2) for (int n = 0; n < batch; ++n) { - #pragma omp parallel for for (int c = 0; c < channels; ++c) { for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..8a35919489cd124edc6919ad4fa998d5ee91b754 --- /dev/null +++ b/mace/kernels/pooling.h @@ -0,0 +1,132 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_POOLING_H +#define MACE_KERNELS_POOLING_H + +#include +#include "mace/core/tensor.h" + +namespace mace { + +enum PoolingType { + AVG = 1, // avg_pool + MAX = 2, // max_pool +}; + +namespace kernels { + +template +class PoolingFunctor { +public: + PoolingFunctor(const PoolingType pooling_type, + const int* kernels, + const int* strides, + const int* paddings, + const int* dilations) + : pooling_type_(pooling_type), + kernels_(kernels), + strides_(strides), + paddings_(paddings), + dilations_(dilations) {} + + void operator()(const T* input, + const index_t* input_shape, + T* output, + const index_t* output_shape) { + index_t batch = output_shape[0]; + index_t channels = output_shape[1]; + index_t height = output_shape[2]; + index_t width = output_shape[3]; + + index_t input_channels = input_shape[1]; + index_t input_height = input_shape[2]; + index_t input_width = input_shape[3]; + + int kernel_h = kernels_[0]; + int kernel_w = kernels_[1]; + + int stride_h = strides_[0]; + int stride_w = strides_[1]; + + int dilation_h = dilations_[0]; + int dilation_w = dilations_[1]; + + // The left-upper most offset of the padded input + int padded_h_start = 0 - paddings_[0] / 2; + int padded_w_start = 0 - paddings_[1] / 2; + int padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2; + int padded_w_stop = input_width + paddings_[1] - paddings_[0] / 2; + +#pragma omp parallel for collpse(2) + for (int n = 0; n < batch; ++n) { + for (int c = 0; c < channels; ++c) { + index_t out_offset = n * channels * height * width + + c * height * width; + index_t in_offset = n * input_channels * input_height * input_width + + c * input_height * input_width; + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + T sum_or_max = 0; + switch (pooling_type_) { + case AVG: + break; + case MAX: + sum_or_max = std::numeric_limits::lowest(); + break; + default: + MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_); + } + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int inh = padded_h_start + h * stride_h + dilation_h * kh; + int inw = padded_w_start + w * stride_w + dilation_w * kw; + if (inh >= 0 && inh < input_height && + inw >= 0 && inw < input_width) { + index_t input_offset = in_offset + + inh * input_width + inw; + switch (pooling_type_) { + case AVG: + sum_or_max += input[input_offset]; + break; + case MAX: + sum_or_max = std::max(sum_or_max, input[input_offset]); + break; + default: + MACE_CHECK(false, "Unsupported pooling type: ", + pooling_type_); + } + } + } + } + switch (pooling_type_) { + case AVG: + output[out_offset] = sum_or_max / (kernel_h * kernel_w); + break; + case MAX: + output[out_offset] = sum_or_max; + break; + default: + MACE_CHECK(false, "Unsupported pooling type: ", pooling_type_); + } + out_offset += 1; + } + } + } + } + } + +private: + const PoolingType pooling_type_; + const int* kernels_; + const int* strides_; + const int* paddings_; + const int* dilations_; +}; + + +} // namespace kernels +} // namespace mace + +#endif //MACE_KERNELS_POOLING_H diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 1049ca12d512ae3df801e6a233236f3db786b4dc..2adc5fbe3cb3bf85e574dfb2929fb7371dca9b87 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -41,9 +41,16 @@ cc_library( ) cc_test( - name = "batch_norm_test", - srcs = ["batch_norm_test.cc"], + name = "ops_test", + srcs = glob( + ["*_test.cc"], + ), copts = ["-std=c++11"], + linkopts = if_android([ + "-pie", + "-llog", + "-latomic", + ]), linkstatic = 1, deps = [ ":ops", diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 5707d82685fbf2d82b2bb763f0cdf12c2e8dce24..7ce4e69a020e123c1ea0a9f3249dc979c9a958f6 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -16,7 +16,7 @@ namespace mace { template class Conv2dOp : public ConvPool2dOpBase { public: - Conv2dOp(const OperatorDef &op_def, Workspace *ws) + Conv2dOp(const OperatorDef& op_def, Workspace* ws) : ConvPool2dOpBase(op_def, ws) {}; bool Run() override { @@ -27,7 +27,10 @@ class Conv2dOp : public ConvPool2dOpBase { std::vector output_shape; std::vector paddings; - this->CalcPaddingAndOutputSize(input, filter, &output_shape, &paddings); + this->CalcPaddingAndOutputSize(input->shape().data(), + filter->shape().data(), + &output_shape, + &paddings); output->Resize(output_shape); auto conv2d = kernels::Conv2dFunctor(this->strides_.data(), diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index ebde143f4d49ac7624c98252495a8445781ae7e2..6ab3de8da6c349477341154f7f31043c28f276ca 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -21,7 +21,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) { // Add args AddIntsArg("strides", {1, 1}); - AddIntArg("padding", static_cast(Conv2dOp::Padding::VALID)); + AddIntArg("padding", Padding::VALID); AddIntsArg("dilations", {1, 1}); // Add input data @@ -58,7 +58,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) { // Add args AddIntsArg("strides", {1, 1}); - AddIntArg("padding", static_cast(Conv2dOp::Padding::SAME)); + AddIntArg("padding", Padding::SAME); AddIntsArg("dilations", {1, 1}); // Add input data @@ -98,7 +98,7 @@ TEST_F(Conv2dOpTest, Combined) { // Add args AddIntsArg("strides", {2, 2}); - AddIntArg("padding", static_cast(Conv2dOp::Padding::SAME)); + AddIntArg("padding", Padding::SAME); AddIntsArg("dilations", {1, 1}); // Add input data diff --git a/mace/ops/conv_pool_2d_base.h b/mace/ops/conv_pool_2d_base.h index e4db6e90f8d3d1665b40c8352eb80bb9f15c3212..36939e31c63cb116b776e4784f011672c57b3fa2 100644 --- a/mace/ops/conv_pool_2d_base.h +++ b/mace/ops/conv_pool_2d_base.h @@ -9,10 +9,16 @@ namespace mace { +enum Padding { + VALID = 0, // No padding + SAME = 1, // Pads with half the filter size (rounded down) on both sides + FULL = 2, // Pads with one less than the filter size on both sides +}; + template class ConvPool2dOpBase : public Operator { public: - ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws) + ConvPool2dOpBase(const OperatorDef& op_def, Workspace* ws) : Operator(op_def, ws), strides_(OperatorBase::GetRepeatedArgument("strides")), padding_(static_cast( @@ -20,58 +26,65 @@ class ConvPool2dOpBase : public Operator { static_cast(SAME)))), dilations_(OperatorBase::GetRepeatedArgument("dilations")) {} - void CalcPaddingAndOutputSize(const Tensor* input, - const Tensor* filter, + void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW + const index_t* filter_shape, // HWIO std::vector* output_shape, std::vector* padding_size) { MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0, - "Invalid dilations, must >= 1"); + "Invalid dilations, must >= 1"); + MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) && + (dilations_[1] == 1 || strides_[1] == 1), + "If dilations > 1, strides should be 1"); /* - * Convlution/pooling arithmetic: - * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 - * For details, see https://arxiv.org/pdf/1603.07285.pdf or - * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html - */ - auto& input_shape = input->shape(); - auto& filter_shape = filter->shape(); // HWIO - int kernel_h = filter_shape[0]; - int kernel_w = filter_shape[1]; - int output_channel = filter_shape[3]; - MACE_CHECK(input_shape[1] == filter_shape[2], - input_shape[1], " != ", filter_shape[2]); - + * Convlution/pooling arithmetic: + * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 + * For details, see https://arxiv.org/pdf/1603.07285.pdf or + * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html + */ *padding_size = {0, 0}; + + index_t output_height, output_width; + index_t kernel_height = filter_shape[0]; + index_t kernel_width = filter_shape[1]; + index_t output_channels = filter_shape[3]; + + int k_extent_height = (kernel_height - 1) * dilations_[0] + 1; + int k_extent_width = (kernel_width - 1) * dilations_[1] + 1; + switch (padding_) { case VALID: + output_height = (input_shape[2] - k_extent_height) / strides_[0] + 1; + output_width = (input_shape[3] - k_extent_width) / strides_[1] + 1; break; case SAME: - (*padding_size)[0] = kernel_h / 2; - (*padding_size)[1] = kernel_w / 2; + output_height = (input_shape[2] - 1) / strides_[0] + 1; + output_width = (input_shape[3] - 1) / strides_[1] + 1; break; case FULL: - (*padding_size)[0] = kernel_h - 1; - (*padding_size)[1] = kernel_w - 1; + output_height = (input_shape[2] + k_extent_height - 2) / strides_[0] + 1; + output_width = (input_shape[3] + k_extent_width - 2) / strides_[1] + 1; break; default: - MACE_CHECK(false, "Unsupported padding type: ", padding_); + MACE_CHECK(false, "Unsupported padding type: ", this->padding_); } + + // Note: TensorFlow may padded one more on the right/bottom side + // TODO may be it's better to also truncate the left/top to + // utilize the more centered features. We need to benchmark + // based on the model accuracy. + + (*padding_size)[0] = (output_height - 1) * strides_[0] + + k_extent_height - input_shape[2]; + (*padding_size)[1] = (output_width - 1) * strides_[1] + + k_extent_width - input_shape[3]; + *output_shape = std::vector(4); // NCHW (*output_shape)[0] = input_shape[0]; - (*output_shape)[1] = output_channel; - (*output_shape)[2] = (input_shape[2] + 2 * (*padding_size)[0] - kernel_h - - (kernel_h - 1) * (dilations_[0] - 1)) / - strides_[0] + 1; - (*output_shape)[3] = (input_shape[3] + 2 * (*padding_size)[1] - kernel_w - - (kernel_w - 1) * (dilations_[1] - 1)) / - strides_[1] + 1; + (*output_shape)[1] = output_channels; + (*output_shape)[2] = output_height; + (*output_shape)[3] = output_width; } - enum Padding { - VALID = 0, // No padding - SAME = 1, // Pads with half the filter size (rounded down) on both sides - FULL = 2, // Pads with one less than the filter size on both sides - }; - protected: std::vector strides_; Padding padding_; diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..915035c858bb66d5f1aa7f61f4717187b2f11b5e --- /dev/null +++ b/mace/ops/pooling.cc @@ -0,0 +1,14 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + + +#include "mace/ops/pooling.h" +#include "mace/proto/mace.pb.h" +#include "mace/kernels/pooling.h" + +namespace mace { + +REGISTER_CPU_OPERATOR(Pooling, PoolingOp); + +} // namespace mace diff --git a/mace/ops/pooling.h b/mace/ops/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..bc62a075a3b6fe75b40caa57d31cb7f29a6ad9a7 --- /dev/null +++ b/mace/ops/pooling.h @@ -0,0 +1,62 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_POOLING_H_ +#define MACE_OPS_POOLING_H_ + +#include "mace/core/operator.h" +#include "mace/ops/conv_pool_2d_base.h" +#include "mace/kernels/pooling.h" + +namespace mace { + +template +class PoolingOp : public ConvPool2dOpBase { +public: + PoolingOp(const OperatorDef& op_def, Workspace* ws) + : ConvPool2dOpBase(op_def, ws), + kernels_(OperatorBase::GetRepeatedArgument("kernels")), + pooling_type_(static_cast( + OperatorBase::GetSingleArgument( + "pooling_type", static_cast(AVG)))) {}; + + bool Run() override{ + const Tensor* input = this->Input(INPUT); + Tensor* output = this->Output(OUTPUT); + std::vector in_shape = input->shape(); + + std::vector output_shape; + std::vector paddings; + std::vector filter_shape = std::vector(4); + filter_shape[0] = kernels_[0]; + filter_shape[1] = kernels_[1]; + filter_shape[2] = in_shape[0]; + filter_shape[3] = in_shape[1]; + this->CalcPaddingAndOutputSize(in_shape.data(), filter_shape.data(), + &output_shape, &paddings); + output->Resize(output_shape); + + auto pooling_func = kernels::PoolingFunctor(pooling_type_, + kernels_.data(), + this->strides_.data(), + paddings.data(), + this->dilations_.data()); + pooling_func(input->data(), + in_shape.data(), + output->mutable_data(), + output->shape().data()); + return true; + }; + +protected: + PoolingType pooling_type_; + std::vector kernels_; + + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace mace + +#endif //MACE_OPS_POOLING_H_ diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2484677a3b1b4bddc6ab4c72839a528eb0714fa --- /dev/null +++ b/mace/ops/pooling_test.cc @@ -0,0 +1,147 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "gtest/gtest.h" + +#include "mace/core/operator.h" +#include "mace/core/net.h" +#include "mace/ops/ops_test_util.h" +#include "mace/ops/conv_pool_2d_base.h" +#include "mace/kernels/pooling.h" + +using namespace mace; + +class PoolingOpTest : public OpsTestBase {}; + +TEST_F(PoolingOpTest, MAX_VALID) { + // Construct graph + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(operator_def()); + + // Add args + AddIntsArg("kernels", {2, 2}); + AddIntsArg("strides", {2, 2}); + AddIntArg("padding", Padding::VALID); + AddIntsArg("dilations", {1, 1}); + AddIntArg("pooling_type", PoolingType::MAX); + + // Add input data + AddInputFromArray("Input", {1, 2, 4, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + 24, 25, 26, 27, + 28, 29, 30, 31}); + + // Run + RunOp(); + + // Check + Tensor expected = CreateTensor({1, 2, 2, 2}, + {5, 7, 13, 15, 21, 23, 29, 31}); + + ExpectTensorNear(expected, *GetOutput("Output"), 0.001); +} + + +TEST_F(PoolingOpTest, AVG_VALID) { + // Construct graph + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(operator_def()); + + // Add args + AddIntsArg("kernels", {2, 2}); + AddIntsArg("strides", {2, 2}); + AddIntArg("padding", Padding::VALID); + AddIntsArg("dilations", {1, 1}); + AddIntArg("pooling_type", PoolingType::AVG); + + // Add input data + AddInputFromArray("Input", {1, 2, 4, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + 24, 25, 26, 27, + 28, 29, 30, 31}); + + // Run + RunOp(); + + // Check + Tensor expected = CreateTensor({1, 2, 2, 2}, + {2.5, 4.5, 10.5, 12.5, 18.5, 20.5, 26.5, 28.5}); + + ExpectTensorNear(expected, *GetOutput("Output"), 0.001); +} + +TEST_F(PoolingOpTest, MAX_SAME) { + // Construct graph + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(operator_def()); + + // Add args + AddIntsArg("kernels", {2, 2}); + AddIntsArg("strides", {2, 2}); + AddIntArg("padding", Padding::SAME); + AddIntsArg("dilations", {1, 1}); + AddIntArg("pooling_type", PoolingType::MAX); + + // Add input data + AddInputFromArray("Input", {1, 1, 3, 3}, + {0, 1, 2, + 3, 4, 5, + 6, 7, 8}); + + // Run + RunOp(); + + // Check + Tensor expected = CreateTensor({1, 1, 2, 2}, + {4, 5, 7, 8}); + + ExpectTensorNear(expected, *GetOutput("Output"), 0.001); +} + +TEST_F(PoolingOpTest, MAX_VALID_DILATION) { + // Construct graph + OpDefBuilder("Pooling", "PoolingTest") + .Input("Input") + .Output("Output") + .Finalize(operator_def()); + + // Add args + AddIntsArg("kernels", {2, 2}); + AddIntsArg("strides", {1, 1}); + AddIntArg("padding", Padding::VALID); + AddIntsArg("dilations", {2, 2}); + AddIntArg("pooling_type", PoolingType::MAX); + + // Add input data + AddInputFromArray("Input", {1, 1, 4, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15}); + + // Run + RunOp(); + + // Check + Tensor expected = CreateTensor({1, 1, 2, 2}, + {10, 11, 14, 15}); + + ExpectTensorNear(expected, *GetOutput("Output"), 0.001); +}