diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
index 2e1a145a3aec2c37e77eecb7f853db4fd99f75e0..a102ccef4075ec133158725b42a83ceb3b5a4411 100644
--- a/mace/kernels/conv_2d.h
+++ b/mace/kernels/conv_2d.h
@@ -108,6 +108,19 @@ class Conv2dFunctor {
   const int* dilations_; // [dilation_h, dilation_w]
 };
 
+// Explicit specialization of the functor's call operator. It is only declared
+// here; the definition is in mace/kernels/neon/conv_2d_neon.cc.
+// NOTE(review): the template argument list after `Conv2dFunctor` is not
+// visible in the reviewed text - confirm it against the class template.
+template<>
+void Conv2dFunctor::operator()(const float* input, // NCHW
+                               const index_t* input_shape,
+                               const float* filter, // c_out, c_in, kernel_h, kernel_w
+                               const index_t* filter_shape,
+                               const float* bias, // c_out
+                               float* output, // NCHW
+                               const index_t* output_shape);
+
 } // namespace kernels
 } // namespace mace
 
diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c9678fd8e2e2b6e44320c5182aae7f3d0a2f1a20
--- /dev/null
+++ b/mace/kernels/neon/conv_2d_neon.cc
@@ -0,0 +1,88 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include <memory>
+#include "mace/kernels/conv_2d.h"
+#include "mace/kernels/neon/conv_2d_neon_base.h"
+
+namespace mace {
+namespace kernels {
+
+// Placeholder for building the padded copy of `input` that the NEON path
+// passes to Conv2dNeon.
+// TODO: not implemented yet. The body is empty, so `output` and
+// `output_shape` are left exactly as the caller passed them.
+//   input, input_shape - source tensor and its shape
+//   padding            - paddings_ as passed by operator()
+//   output             - buffer meant to own the padded data
+//   output_shape       - 4-element array meant to receive the padded shape
+// NOTE(review): the unique_ptr element type is not visible in the reviewed
+// text; float[] is assumed because .get() is passed on as const float*.
+static inline void ConstructInputWithPadding(const float* input,
+                                             const index_t* input_shape,
+                                             const int* padding,
+                                             std::unique_ptr<float[]>& output,
+                                             index_t* output_shape) {
+}
+
+// NEON entry point of the convolution functor. Unsupported configurations are
+// forwarded to a Conv2dFunctor built from the same strides, paddings and
+// dilations; everything else goes through ConstructInputWithPadding and
+// Conv2dNeon.
+// NOTE(review): the template argument lists after `Conv2dFunctor` and
+// `Conv2dNeon` are not visible in the reviewed text - confirm them.
+template<>
+void Conv2dFunctor::operator()(const float* input, // NCHW
+                               const index_t* input_shape,
+                               const float* filter, // c_out, c_in, kernel_h, kernel_w
+                               const index_t* filter_shape,
+                               const float* bias, // c_out
+                               float* output, // NCHW
+                               const index_t* output_shape) {
+  // Looked up below as selector[paddings_[0] - 1][strides_[0] - 1]; true
+  // selects the NEON path.
+  // NOTE(review): indexing the first dimension by padding looks suspicious
+  // (kernel size from filter_shape may have been intended) - confirm.
+  static const bool selector[5][4] = {
+    {true, false, false, false},
+    {false, false, false, false},
+    {true, true, false, false},
+    {false, false, false, false},
+    {true, false, false, false},
+  };
+  // Not implemented yet for other configurations: use the fallback.
+  // The range checks, lower bounds included, must precede the selector
+  // lookup: they keep both indices inside the table.
+  if (paddings_[0] != paddings_[1] || paddings_[0] < 1 || paddings_[0] > 5 ||
+      strides_[0] != strides_[1] || strides_[0] < 1 || strides_[0] > 4 ||
+      dilations_[0] != 1 || dilations_[1] != 1 ||
+      !selector[paddings_[0] - 1][strides_[0] - 1]) {
+    Conv2dFunctor(strides_, paddings_, dilations_)(
+      input,
+      input_shape,
+      filter,
+      filter_shape,
+      bias,
+      output,
+      output_shape
+    );
+    // The fallback handled this request; do not fall through to the NEON path.
+    return;
+  }
+  std::unique_ptr<float[]> padded_input;
+  index_t padded_input_shape[4] = {0, 0, 0, 0};
+  ConstructInputWithPadding(input, input_shape, paddings_, padded_input,
+                            padded_input_shape);
+  Conv2dNeon(
+    padded_input.get(),
+    padded_input_shape,
+    filter,
+    bias,
+    output,
+    output_shape
+  );
+}
+
+} // namespace kernels
+} // namespace mace
diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4397076d162f5a61c277c4bacb65a7e68512c4fe
--- /dev/null
+++ b/mace/kernels/neon/conv_2d_neon_3x3.cc
@@ -0,0 +1,25 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+// NOTE(review): the header name of the following #include is not visible in
+// the reviewed text - restore it before merging.
+#include
+#include "mace/kernels/neon/conv_2d_neon_base.h"
+
+namespace mace {
+namespace kernels {
+
+// Specialization of Conv2dNeon for the template arguments <3, 3, 1, 1>.
+// TODO: not implemented yet - the body is empty, so `output` is never written.
+template<>
+void Conv2dNeon<3, 3, 1, 1>(const float* input, // NCHW
+                            const index_t* input_shape,
+                            const float* filter, // c_out, c_in, kernel_h, kernel_w
+                            const float* bias, // c_out
+                            float* output, // NCHW
+                            const index_t* output_shape) {
+}
+
+} // namespace kernels
+} // namespace mace
diff --git a/mace/kernels/neon/conv_2d_neon_base.h b/mace/kernels/neon/conv_2d_neon_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e0b5024d77d3032bd6ff6cc5ad3b8a389458936
--- /dev/null
+++ b/mace/kernels/neon/conv_2d_neon_base.h
@@ -0,0 +1,30 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
+#define MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
+
+#include "mace/core/common.h"
+
+namespace mace {
+namespace kernels {
+
+// Primary template of the NEON convolution kernels. It is only declared here;
+// definitions come from explicit specializations such as
+// Conv2dNeon<3, 3, 1, 1> in conv_2d_neon_3x3.cc, so it is not marked inline.
+// NOTE(review): the template parameter list is not visible in the reviewed
+// text. Four integral parameters are implied by Conv2dNeon<3, 3, 1, 1>; their
+// names and type below are assumptions - confirm.
+template <int kernel_h, int kernel_w, int stride_h, int stride_w>
+void Conv2dNeon(const float* input, // NCHW
+                const index_t* input_shape,
+                const float* filter, // c_out, c_in, kernel_h, kernel_w
+                const float* bias, // c_out
+                float* output, // NCHW
+                const index_t* output_shape);
+
+} // namespace kernels
+} // namespace mace
+
+#endif // MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_