diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc
index fc71b187fdce52c4cb09ad220e05ff1d4c89fb07..8d45861a1b17ab2e1c59b217723aa6d30d962d63 100644
--- a/mace/kernels/neon/conv_2d_neon.cc
+++ b/mace/kernels/neon/conv_2d_neon.cc
@@ -4,16 +4,42 @@
 
 #include
 #include "mace/kernels/conv_2d.h"
-#include "mace/kernels/neon/conv_2d_neon_base.h"
+#include "mace/kernels/neon/conv_2d_neon_3x3.h"
 
 namespace mace {
 namespace kernels {
 
 static inline void ConstructInputWithPadding(const float* input,
                                              const index_t* input_shape,
-                                             const int* padding,
-                                             std::unique_ptr<float[]>& output,
-                                             index_t* output_shape) {
+                                             const int* paddings,
+                                             Tensor& output_tensor,
+                                             std::vector<index_t>& output_shape) {
+  index_t batch = input_shape[0];
+  index_t channels = input_shape[1];
+  index_t height = input_shape[2];
+  index_t width = input_shape[3];
+  output_shape[0] = batch;
+  output_shape[1] = channels;
+  output_shape[2] = paddings[0] + height;
+  output_shape[3] = paddings[1] + width;
+  index_t output_width = output_shape[3];
+  int padded_left = paddings[1] / 2;
+
+  output_tensor.Resize(output_shape);
+  float* output_ptr = output_tensor.mutable_data();
+  memset(output_ptr, 0, output_tensor.size() * sizeof(float));
+  output_ptr += paddings[0] / 2 * output_width;
+
+  for (; batch > 0; --batch) {
+    for (; channels > 0; --channels) {
+      for(; height > 0; --height) {
+        memcpy(output_ptr+padded_left, input, width*sizeof(float));
+        input += width;
+        output_ptr += output_width;
+      }
+      output_ptr += paddings[0] * output_width;
+    }
+  }
 }
 
 template<>
@@ -25,18 +51,39 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // N
                                                         float* output, // NCHW
                                                         const index_t* output_shape) {
 
-  static const bool selector[5][4] = {
-      {true, false, false, false},
-      {false, false, false, false},
-      {true, true, false, false},
-      {false, false, false, false},
-      {true, false, false, false},
+  typedef void (*Conv2dNeonFunction)(const float* input, // NCHW
+                                     const index_t* input_shape,
+                                     const float* filter, // c_out, c_in, kernel_h, kernel_w
+                                     const float* bias, // c_out
+                                     float* output, // NCHW
+                                     const index_t* output_shape);
+  static const Conv2dNeonFunction selector[5][2] = {
+      {
+          nullptr,
+          nullptr
+      },
+      {
+          nullptr,
+          nullptr
+      },
+      {
+          Conv2dNeonK3x3S1,
+          nullptr
+      },
+      {
+          nullptr,
+          nullptr
+      },
+      {
+          nullptr,
+          nullptr
+      }
   };
   // not implement yet
   if (paddings_[0] != paddings_[1] || paddings_[0] > 5 ||
       strides_[0] != strides_[1] || strides_[0] > 4 ||
       dilations_[0] != 1 || dilations_[1] != 1 ||
-      !selector[paddings_[0]-1, strides_[0]-1]) {
+      selector[paddings_[0]-1][strides_[0]-1] == nullptr) {
    Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
        input,
        input_shape,
@@ -47,12 +94,13 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // N
        output_shape
    );
  }
-  std::unique_ptr<float[]> padded_input;
-  index_t padded_input_shape[4];
+  Tensor padded_input;
+  std::vector<index_t> padded_input_shape(4);
  ConstructInputWithPadding(input, input_shape, paddings_, padded_input, padded_input_shape);
-  Conv2dNeon(
-      padded_input.get(),
-      padded_input_shape,
+  auto conv2d_neon_func = selector[paddings_[0] - 1][strides_[0] - 1];
+  conv2d_neon_func(
+      padded_input.data(),
+      padded_input_shape.data(),
      filter,
      bias,
      output,
diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.h
similarity index 66%
rename from mace/kernels/neon/conv_2d_neon_3x3.cc
rename to mace/kernels/neon/conv_2d_neon_3x3.h
index bc88dc494a459ab1ec53c7f1ac8c2c9173e65318..9916e3e03dd6bf4139aa32dbc487c7447119f425 100644
--- a/mace/kernels/neon/conv_2d_neon_3x3.cc
+++ b/mace/kernels/neon/conv_2d_neon_3x3.h
@@ -1,15 +1,16 @@
 //
 // Copyright (c) 2017 XiaoMi All rights reserved.
 //
+#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_
+#define MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_
 
 #include
-#include "mace/kernels/neon/conv_2d_neon_base.h"
+#include "mace/core/common.h"
 
 namespace mace {
 namespace kernels {
 
-template<>
-void Conv2dNeon<3, 3, 1, 1>(const float* input, // NCHW
+void Conv2dNeonK3x3S1(const float* input, // NCHW
                            const index_t* input_shape,
                            const float* filter, // c_out, c_in, kernel_h, kernel_w
                            const float* bias, // c_out
@@ -20,3 +21,5 @@ void Conv2dNeon<3, 3, 1, 1>(const float* input, // NCHW
 
 } // namespace kernels
 } // namespace mace
+
+#endif // MACE_KERNELS_NEON_CONV_2D_NEON_3X3_H_
diff --git a/mace/kernels/neon/conv_2d_neon_base.h b/mace/kernels/neon/conv_2d_neon_base.h
deleted file mode 100644
index 3e93b51a8ba5020a0bb56aee232d189cb7382de7..0000000000000000000000000000000000000000
--- a/mace/kernels/neon/conv_2d_neon_base.h
+++ /dev/null
@@ -1,24 +0,0 @@
-//
-// Copyright (c) 2017 XiaoMi All rights reserved.
-//
-
-#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
-#define MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
-
-#include "mace/core/common.h"
-
-namespace mace {
-namespace kernels {
-
-template
-inline void Conv2dNeon(const float* input, // NCHW
-                       const index_t* input_shape,
-                       const float* filter, // c_out, c_in, kernel_h, kernel_w
-                       const float* bias, // c_out
-                       float* output, // NCHW
-                       const index_t* output_shape);
-
-} // namespace kernels
-} // namespace mace
-
-#endif // MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
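
Note (not part of the patch): the snippet below is a minimal standalone sketch of the NCHW zero-padding scheme that the new ConstructInputWithPadding builds before handing the buffer to a NEON kernel. It uses plain std::vector in place of mace::Tensor, and PadNCHW plus the local index_t typedef are illustrative stand-ins rather than MACE API.

// Standalone sketch: center an NCHW input inside a zero-filled buffer that is
// paddings[0] rows taller and paddings[1] columns wider per channel.
#include <cstdint>
#include <cstring>
#include <vector>

typedef int64_t index_t;  // assumed index type, mirroring mace/core/common.h

std::vector<float> PadNCHW(const float* input,
                           const index_t input_shape[4],   // N, C, H, W
                           const int paddings[2],          // total pad: H, W
                           std::vector<index_t>* padded_shape) {
  const index_t batch = input_shape[0];
  const index_t channels = input_shape[1];
  const index_t height = input_shape[2];
  const index_t width = input_shape[3];
  const index_t padded_height = height + paddings[0];
  const index_t padded_width = width + paddings[1];
  *padded_shape = {batch, channels, padded_height, padded_width};

  // Zero-initialize the whole padded buffer, then copy the input rows in.
  std::vector<float> padded(batch * channels * padded_height * padded_width, 0.0f);
  const int pad_top = paddings[0] / 2;
  const int pad_left = paddings[1] / 2;

  // One counter per loop so every image and every channel gets copied.
  for (index_t n = 0; n < batch; ++n) {
    for (index_t c = 0; c < channels; ++c) {
      float* dst = padded.data() +
                   ((n * channels + c) * padded_height + pad_top) * padded_width +
                   pad_left;
      const float* src = input + (n * channels + c) * height * width;
      for (index_t h = 0; h < height; ++h) {
        std::memcpy(dst, src, width * sizeof(float));  // copy one input row
        src += width;
        dst += padded_width;
      }
    }
  }
  return padded;
}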
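
Also not part of the patch: a small sketch of the function-pointer selector pattern the dispatch code uses, with a fallback when no specialized kernel is registered. Conv2dKernelFn, Conv2dGeneric, Conv2dSpecialized3x3S1, and RunConv2d are placeholder names; only the shape of the [padding][stride] table mirrors the selector in the patch.

#include <cstddef>

typedef void (*Conv2dKernelFn)(const float* input, float* output);

// Placeholder kernels; real implementations would live elsewhere.
void Conv2dGeneric(const float* input, float* output) { (void)input; (void)output; }
void Conv2dSpecialized3x3S1(const float* input, float* output) { (void)input; (void)output; }

void RunConv2d(int padding, int stride, const float* input, float* output) {
  // Indexed by (padding - 1, stride - 1); empty slots mean "no fast path".
  static const Conv2dKernelFn kSelector[5][2] = {
      {nullptr, nullptr},
      {nullptr, nullptr},
      {Conv2dSpecialized3x3S1, nullptr},
      {nullptr, nullptr},
      {nullptr, nullptr},
  };
  const bool out_of_range = padding < 1 || padding > 5 || stride < 1 || stride > 2;
  if (out_of_range || kSelector[padding - 1][stride - 1] == nullptr) {
    Conv2dGeneric(input, output);  // no specialization registered
    return;                        // stop here so the fast path is not entered
  }
  kSelector[padding - 1][stride - 1](input, output);
}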