diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..6f62ce176f910dcf32d6c3104f0005242f83375e
--- /dev/null
+++ b/mace/kernels/conv_2d.h
@@ -0,0 +1,135 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_CONV_2D_H_
+#define MACE_KERNELS_CONV_2D_H_
+
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+/*
+ * Direct (loop-nest) 2-D convolution over NCHW data.
+ *
+ * The functor stores the three parameter pointers without copying them; each
+ * must reference two ints ([h, w]) that outlive every call to operator().
+ */
+template
+class Conv2dFunctor {
+ public:
+  Conv2dFunctor(const int* strides,
+                const int* paddings,
+                const int* dilations) :
+    strides_(strides),
+    paddings_(paddings),
+    dilations_(dilations) {}
+
+  /*
+   * Writes output = conv(input, filter) + bias for every output element.
+   *
+   * input_shape and output_shape are read as [N, C, H, W]; filter_shape as
+   * [kernel_h, kernel_w, c_in, c_out]; bias is indexed by output channel.
+   * Taps that land in the zero-padding border contribute nothing to the sum.
+   */
+  void operator()(const T* input, // NCHW
+                  const index_t* input_shape,
+                  const T* filter, // kernel_h, kernel_w, c_in, c_out
+                  const index_t* filter_shape,
+                  const T* bias, // c_out
+                  T* output, // NCHW
+                  const index_t* output_shape) {
+    MACE_CHECK_NOTNULL(output);
+
+    index_t batch = output_shape[0];
+    index_t channels = output_shape[1];
+    index_t height = output_shape[2];
+    index_t width = output_shape[3];
+
+    index_t input_batch = input_shape[0];
+    index_t input_channels = input_shape[1];
+    index_t input_height = input_shape[2];
+    index_t input_width = input_shape[3];
+
+    int kernel_h = filter_shape[0];
+    int kernel_w = filter_shape[1];
+
+    int stride_h = strides_[0];
+    int stride_w = strides_[1];
+
+    int dilation_h = dilations_[0];
+    int dilation_w = dilations_[1];
+
+    MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
+
+    // Bounds of the zero-padded input, in unpadded input coordinates.
+    int padded_h_start = 0 - paddings_[0];
+    int padded_w_start = 0 - paddings_[1];
+    int padded_h_stop = input_height + paddings_[0];
+    int padded_w_stop = input_width + paddings_[1];
+
+    for (int n = 0; n < batch; ++n) {
+      #pragma omp parallel for
+      for (int c = 0; c < channels; ++c) {
+        for (int h = 0; h < height; ++h) {
+          for (int w = 0; w < width; ++w) {
+            index_t offset = n * channels * height * width +
+                             c * height * width +
+                             h * width + w;
+            T sum = 0;
+            for (int inc = 0; inc < input_channels; ++inc) {
+              for (int kh = 0; kh < kernel_h; ++kh) {
+                for (int kw = 0; kw < kernel_w; ++kw) {
+                  /*
+                   * TODO The tensorflow filter order is HWCiCo.
+                   * We should consider other order for different
+                   * implementation to optimize memory access.
+                   */
+                  // The product mixes in index_t extents, so keep the result
+                  // in index_t rather than converting it to int.
+                  index_t filter_offset =
+                      kh * kernel_w * input_channels * channels +
+                      kw * input_channels * channels +
+                      inc * channels + c;
+
+                  int inh = padded_h_start + h * stride_h + dilation_h * kh;
+                  int inw = padded_w_start + w * stride_w + dilation_w * kw;
+                  if (inh < 0 || inh >= input_height ||
+                      inw < 0 || inw >= input_width) {
+                    MACE_CHECK(inh >= padded_h_start &&
+                               inh < padded_h_stop &&
+                               inw >= padded_w_start &&
+                               inw < padded_w_stop,
+                               "Out of range read from input: ",
+                               inh, ", ", inw);
+                    // Tap lies in the padding border: it adds 0 to sum.
+                  } else {
+                    index_t input_offset =
+                        n * input_channels * input_height * input_width +
+                        inc * input_height * input_width +
+                        inh * input_width + inw;
+                    sum += input[input_offset] * filter[filter_offset];
+                  }
+                }
+              }
+            }
+            // Store once, after all input channels are accumulated; this
+            // also writes the bias-only value when input_channels is 0.
+            output[offset] = sum + bias[c];
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  const int* strides_;    // [stride_h, stride_w]
+  const int* paddings_;   // [padding_h, padding_w]
+  const int* dilations_;  // [dilation_h, dilation_w]
+};
+
+} // namespace kernels
+} // namespace mace
+
+#endif // MACE_KERNELS_CONV_2D_H_
diff --git a/mace/ops/BUILD b/mace/ops/BUILD
index 3c09daa38d9e9af7f2bc5bd8e139ed61ec8dab81..0acce0243c91f5d0ff3d1e234ada9695529b107d 100644
--- a/mace/ops/BUILD
+++ b/mace/ops/BUILD
@@ -83,3 +83,15 @@ cc_test(
         "@gtest//:gtest_main",
     ],
 )
+
+cc_test(
+    name = "conv_2d_test",
+    srcs = ["conv_2d_test.cc",],
+    deps = [
+        ":ops",
+        ":test",
+        "@gtest//:gtest_main",
+    ],
+    copts = ['-std=c++11'],
+    linkstatic = 1,
+)
diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc
index 533fa0c46c87bdb6479ea865928027ac2401d463..a236856b0ed8a27ebdae42c6f12ad25b5f9bfffa 100644
--- a/mace/ops/conv_2d.cc
+++ b/mace/ops/conv_2d.cc
@@ -7,25 +7,10 @@
 
 namespace mace {
 
-template <>
-bool Conv2dOp::Run() {
-  const Tensor* input = Input(INPUT);
-  const Tensor* filter = Input(FILTER);
-  const Tensor* bias = Input(BIAS);
-  Tensor* output = Output(OUTPUT);
-
-
-  // Test
-  VLOG(0) << "conv_2d([" << kernels_[0] << ", " << kernels_[1] << "], )";
-  const float* input_data = input->data();
-  for (int i = 0; i < 6; ++i) {
-    VLOG(0) << input_data[i];
-  }
-
-  return true;
-}
-
-
 REGISTER_CPU_OPERATOR(Conv2d, Conv2dOp);
 
-}
+#if __ARM_NEON
+REGISTER_NEON_OPERATOR(Conv2d, Conv2dOp);
+#endif // __ARM_NEON
+
+} // namespace mace
diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h
index 61b7f88b67f1011c8fe2382ef3180250b45e5ba9..5707d82685fbf2d82b2bb763f0cdf12c2e8dce24 100644
--- a/mace/ops/conv_2d.h
+++ b/mace/ops/conv_2d.h
@@ -5,28 +5,43 @@
 #ifndef MACE_OPS_CONV_2D_H_
 #define MACE_OPS_CONV_2D_H_
 
+#include
+
 #include "mace/core/operator.h"
+#include "mace/kernels/conv_2d.h"
+#include "mace/ops/conv_pool_2d_base.h"
 
 namespace mace {
 
-template
-class Conv2dOp : public Operator {
+template
+class Conv2dOp : public ConvPool2dOpBase {
  public:
-  Conv2dOp(const OperatorDef &operator_def, Workspace *ws)
-      : Operator(operator_def, ws),
-        kernels_(OperatorBase::GetRepeatedArgument("kernels")),
-        strides_(OperatorBase::GetRepeatedArgument("strides")),
-        paddings_(OperatorBase::GetRepeatedArgument("paddings")),
-        dilations_(OperatorBase::GetRepeatedArgument("dilations")) {}
-
-  bool Run() override;
-
- private:
-  vector kernels_;
-  vector strides_;
-  vector paddings_;
-  vector dilations_;
-
+  Conv2dOp(const OperatorDef &op_def, Workspace *ws)
+    : ConvPool2dOpBase(op_def, ws) {};
+
+  bool Run() override {  // derive padding/output shape, resize, then convolve
+    const Tensor* input = this->Input(INPUT);
+    const Tensor* filter = this->Input(FILTER);
+    const Tensor* bias = this->Input(BIAS);
+    Tensor* output = this->Output(OUTPUT);
+
+    std::vector output_shape;  // filled in by CalcPaddingAndOutputSize
+    std::vector paddings;  // filled in by CalcPaddingAndOutputSize
+    this->CalcPaddingAndOutputSize(input, filter, &output_shape, &paddings);
+    output->Resize(output_shape);  // size the output before the kernel writes
+
+    auto conv2d = kernels::Conv2dFunctor(this->strides_.data(),
+                                         paddings.data(),
+                                         this->dilations_.data());
+    conv2d(input->data(), input->shape().data(),
+           filter->data(), filter->shape().data(),
+           bias->data(), output->mutable_data(),
+           output->shape().data());
+
+    return true;  // this implementation has no failure return path
+  }
+
+ protected:
   OP_INPUT_TAGS(INPUT, FILTER, BIAS);
   OP_OUTPUT_TAGS(OUTPUT);
 };
diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ebde143f4d49ac7624c98252495a8445781ae7e2
--- /dev/null
+++ b/mace/ops/conv_2d_test.cc
@@ -0,0 +1,144 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+#include "mace/ops/conv_2d.h"
+
+using namespace mace;
+
+class Conv2dOpTest : public OpsTestBase {};
+
+TEST_F(Conv2dOpTest, Simple_VALID) {  // all-ones 3x3 kernel, no padding
+  // Construct graph
+  OpDefBuilder("Conv2d", "Conv2dTest")
+        .Input("Input")
+        .Input("Filter")
+        .Input("Bias")
+        .Output("Output")
+        .Finalize(operator_def());
+
+  // Add args
+  AddIntsArg("strides", {1, 1});  // [stride_h, stride_w]
+  AddIntArg("padding", static_cast(Conv2dOp::Padding::VALID));
+  AddIntsArg("dilations", {1, 1});  // [dilation_h, dilation_w]
+
+  // Add input data
+  AddInputFromArray("Input", {1, 2, 3, 3},  // NCHW
+    {1, 1, 1,
+     1, 1, 1,
+     1, 1, 1,
+     1, 1, 1,
+     1, 1, 1,
+     1, 1, 1});
+  AddInputFromArray("Filter", {3, 3, 2, 1},  // kernel_h, kernel_w, c_in, c_out
+    {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+     1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+     1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
+  AddInputFromArray("Bias", {1}, {0.1f});  // one value per output channel
+
+  // Run
+  RunOp();
+
+  // Check
+  Tensor expected = CreateTensor({1, 1, 1, 1}, {18.1f});  // 18 taps of 1*1 + 0.1 bias
+
+  ExpectTensorNear(expected, *GetOutput("Output"), 0.001);
+}
+
+TEST_F(Conv2dOpTest, Simple_SAME) {  // same data as above, SAME padding
+  // Construct graph
+  OpDefBuilder("Conv2d", "Conv2dTest")
+        .Input("Input")
+        .Input("Filter")
+        .Input("Bias")
+        .Output("Output")
+        .Finalize(operator_def());
+
+  // Add args
+  AddIntsArg("strides", {1, 1});  // [stride_h, stride_w]
+  AddIntArg("padding", static_cast(Conv2dOp::Padding::SAME));
+  AddIntsArg("dilations", {1, 1});  // [dilation_h, dilation_w]
+
+  // Add input data
+  AddInputFromArray("Input", {1, 2, 3, 3},  // NCHW
+    {1, 1, 1,
+     1, 1, 1,
+     1, 1, 1,
+     1, 1, 1,
+     1, 1, 1,
+     1, 1, 1});
+  AddInputFromArray("Filter", {3, 3, 2, 1},  // kernel_h, kernel_w, c_in, c_out
+    {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+     1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+     1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
+  AddInputFromArray("Bias", {1}, {0.1f});  // one value per output channel
+
+  // Run
+  RunOp();
+
+  // Check
+  Tensor expected = CreateTensor({1, 1, 3, 3},  // corners/edges see fewer non-padding taps
+    { 8.1f, 12.1f,  8.1f,
+     12.1f, 18.1f, 12.1f,
+      8.1f, 12.1f,  8.1f});
+
+  ExpectTensorNear(expected, *GetOutput("Output"), 0.001);
+}
+
+TEST_F(Conv2dOpTest, Combined) {  // stride 2, SAME padding, two output channels
+  // Construct graph
+  OpDefBuilder("Conv2d", "Conv2dTest")
+        .Input("Input")
+        .Input("Filter")
+        .Input("Bias")
+        .Output("Output")
+        .Finalize(operator_def());
+
+  // Add args
+  AddIntsArg("strides", {2, 2});  // [stride_h, stride_w]
+  AddIntArg("padding", static_cast(Conv2dOp::Padding::SAME));
+  AddIntsArg("dilations", {1, 1});  // [dilation_h, dilation_w]
+
+  // Add input data
+  AddInputFromArray("Input", {1, 2, 5, 5},  // NCHW
+    {1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1,
+     1, 1, 1, 1, 1});
+  AddInputFromArray("Filter", {3, 3, 2, 2},  // c_out is innermost: 1.0 -> out 0, 0.5 -> out 1
+    {1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f,
+     1.0f, 0.5f, 1.0f, 0.5f});
+  AddInputFromArray("Bias", {2}, {0.1f, 0.2f});  // one value per output channel
+
+  // Run
+  RunOp();
+
+  // Check
+  Tensor expected = CreateTensor({1, 2, 3, 3},  // first 9 values: channel 0, last 9: channel 1
+    { 8.1f, 12.1f,  8.1f,
+     12.1f, 18.1f, 12.1f,
+      8.1f, 12.1f,  8.1f,
+      4.2f,  6.2f,  4.2f,
+      6.2f,  9.2f,  6.2f,
+      4.2f,  6.2f,  4.2f});
+
+
+  ExpectTensorNear(expected, *GetOutput("Output"), 0.001);
+}
+
+// TODO we need more tests
diff --git a/mace/ops/conv_pool_2d_base.h b/mace/ops/conv_pool_2d_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..e4db6e90f8d3d1665b40c8352eb80bb9f15c3212
---
/dev/null
+++ b/mace/ops/conv_pool_2d_base.h
@@ -0,0 +1,104 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_CONV_POOL_2D_BASE_H_
+#define MACE_OPS_CONV_POOL_2D_BASE_H_
+
+#include "mace/core/operator.h"
+
+namespace mace {
+
+/*
+ * Shared base for 2-D convolution/pooling operators: reads the "strides",
+ * "padding" and "dilations" arguments in its constructor and derives the
+ * padding sizes and NCHW output shape from them.
+ */
+template
+class ConvPool2dOpBase : public Operator {
+ public:
+  ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws)
+    : Operator(op_def, ws),
+    strides_(OperatorBase::GetRepeatedArgument("strides")),
+    padding_(static_cast(
+                 OperatorBase::GetSingleArgument("padding",
+                                                 static_cast(SAME)))),
+    dilations_(OperatorBase::GetRepeatedArgument("dilations")) {}
+
+  /*
+   * Computes the symmetric padding and the output shape for this op.
+   *
+   * input is read as NCHW and filter as HWIO. On return *padding_size holds
+   * [pad_h, pad_w] and *output_shape holds [N, C_out, H_out, W_out].
+   * SAME and FULL padding are derived from the raw kernel size only, so
+   * with SAME the output matches the input size only for an odd kernel with
+   * stride 1 and dilation 1.
+   */
+  void CalcPaddingAndOutputSize(const Tensor* input,
+                                const Tensor* filter,
+                                std::vector* output_shape,
+                                std::vector* padding_size) {
+    // Both arguments are indexed as [h, w] below; reject short or missing
+    // lists instead of reading past the end of the vector.
+    MACE_CHECK(strides_.size() >= 2 && dilations_.size() >= 2,
+               "strides and dilations must each provide [h, w]");
+    // A zero stride would divide by zero in the output size formula.
+    MACE_CHECK(strides_[0] > 0 && strides_[1] > 0,
+               "Invalid strides, must >= 1");
+    MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0,
+               "Invalid dilations, must >= 1");
+    /*
+     * Convolution/pooling arithmetic:
+     * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
+     * For details, see https://arxiv.org/pdf/1603.07285.pdf or
+     * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
+     */
+    auto& input_shape = input->shape();
+    auto& filter_shape = filter->shape(); // HWIO
+    int kernel_h = filter_shape[0];
+    int kernel_w = filter_shape[1];
+    int output_channel = filter_shape[3];
+    MACE_CHECK(input_shape[1] == filter_shape[2],
+               input_shape[1], " != ", filter_shape[2]);
+
+    *padding_size = {0, 0};
+    switch (padding_) {
+      case VALID:
+        break;
+      case SAME:
+        (*padding_size)[0] = kernel_h / 2;
+        (*padding_size)[1] = kernel_w / 2;
+        break;
+      case FULL:
+        (*padding_size)[0] = kernel_h - 1;
+        (*padding_size)[1] = kernel_w - 1;
+        break;
+      default:
+        MACE_CHECK(false, "Unsupported padding type: ", padding_);
+    }
+    *output_shape = std::vector(4); // NCHW
+    (*output_shape)[0] = input_shape[0];
+    (*output_shape)[1] = output_channel;
+    (*output_shape)[2] = (input_shape[2] + 2 * (*padding_size)[0] - kernel_h -
+                          (kernel_h - 1) * (dilations_[0] - 1)) /
+                         strides_[0] + 1;
+    (*output_shape)[3] = (input_shape[3] + 2 * (*padding_size)[1] - kernel_w -
+                          (kernel_w - 1) * (dilations_[1] - 1)) /
+                         strides_[1] + 1;
+  }
+
+  enum Padding {
+    VALID = 0, // No padding
+    SAME = 1,  // Pads with half the filter size (rounded down) on both sides
+    FULL = 2,  // Pads with one less than the filter size on both sides
+  };
+
+ protected:
+  std::vector strides_;    // indexed as [stride_h, stride_w]
+  Padding padding_;        // defaults to SAME when the argument is absent
+  std::vector dilations_;  // indexed as [dilation_h, dilation_w]
+};
+
+} // namespace mace
+
+#endif // MACE_OPS_CONV_POOL_2D_BASE_H_
diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h
index 0e96943c60085014bc01c65323882ebc0480249e..f61d4b19af319c154d2d99f5240138678372fce9 100644
--- a/mace/ops/ops_test_util.h
+++ b/mace/ops/ops_test_util.h
@@ -5,6 +5,8 @@
 #ifndef MACE_OPS_TEST_UTIL_H_
 #define MACE_OPS_TEST_UTIL_H_
 
+#include
+
 #include "gtest/gtest.h"
 #include "mace/core/common.h"
 #include "mace/core/tensor.h"
@@ -50,6 +52,54 @@ class OpsTestBase : public ::testing::Test {
     memcpy(input_data, data.data(), data.size() * sizeof(T));
   }
 
+  // Appends a single-int argument called `name` to the op definition.
+  void AddIntArg(const char* name, const int value) {
+    auto arg = op_def_.add_arg();
+    arg->set_name(name);
+    arg->set_i(value);
+  }
+
+  // Appends a single-float argument called `name` to the op definition.
+  void AddFloatArg(const char* name, const float value) {
+    auto arg = op_def_.add_arg();
+    arg->set_name(name);
+    arg->set_f(value);
+  }
+
+  // Appends a single-string argument called `name` to the op definition.
+  void AddStringArg(const char* name, const char* value) {
+    auto arg = op_def_.add_arg();
+    arg->set_name(name);
+    arg->set_s(value);
+  }
+
+  // Appends a repeated-int argument called `name`, one entry per value.
+  void AddIntsArg(const char* name, const std::vector& values) {
+    auto arg = op_def_.add_arg();
+    arg->set_name(name);
+    for (auto value : values) {
+      arg->add_ints(value);
+    }
+  }
+
+  // Appends a repeated-float argument called `name`, one entry per value.
+  void AddFloatsArg(const char* name, const std::vector& values) {
+    auto arg = op_def_.add_arg();
+    arg->set_name(name);
+    for (auto value : values) {
+      arg->add_floats(value);
+    }
+  }
+
+  // Appends a repeated-string argument called `name`, one entry per value.
+  void AddStringsArg(const char* name, const std::vector& values) {
+    auto arg = op_def_.add_arg();
+    arg->set_name(name);
+    for (auto value : values) {
+      arg->add_strings(value);
+    }
+  }
+
   OperatorDef* operator_def() { return &op_def_; }
 
   bool RunOp() {