From c3e881847a856ad122b322ecf35b9b38b84ec363 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 13 Sep 2017 10:51:52 +0800 Subject: [PATCH] Add conv2d neon api. --- mace/kernels/conv_2d.h | 9 ++++ mace/kernels/neon/conv_2d_neon.cc | 64 +++++++++++++++++++++++++++ mace/kernels/neon/conv_2d_neon_3x3.cc | 22 +++++++++ mace/kernels/neon/conv_2d_neon_base.h | 24 ++++++++++ 4 files changed, 119 insertions(+) create mode 100644 mace/kernels/neon/conv_2d_neon.cc create mode 100644 mace/kernels/neon/conv_2d_neon_3x3.cc create mode 100644 mace/kernels/neon/conv_2d_neon_base.h diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 2e1a145a..a102ccef 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -108,6 +108,15 @@ class Conv2dFunctor { const int* dilations_; // [dilation_h, dilation_w] }; +template<> +void Conv2dFunctor::operator()(const float* input, // NCHW + const index_t* input_shape, + const float* filter, // c_out, c_in, kernel_h, kernel_w + const index_t* filter_shape, + const float* bias, // c_out + float* output, // NCHW + const index_t* output_shape); + } // namespace kernels } // namespace mace diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc new file mode 100644 index 00000000..c9678fd8 --- /dev/null +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -0,0 +1,64 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+//
+
+#include
+#include "mace/kernels/conv_2d.h"
+#include "mace/kernels/neon/conv_2d_neon_base.h"
+
+namespace mace {
+namespace kernels {
+
+// Stub: the body is still empty, so `output` and `output_shape` are left
+// untouched. Presumably this will build a padded copy of `input` -- TODO.
+static inline void ConstructInputWithPadding(const float* input, const index_t* input_shape,
+                                             const int* padding,
+                                             std::unique_ptr& output,
+                                             index_t* output_shape) {
+}
+
+// Explicit specialization of the conv2d functor for the NEON kernels.
+// Configurations the selector table rejects are delegated to a Conv2dFunctor
+// built from the same strides/paddings/dilations; the rest use Conv2dNeon.
+template<>
+void Conv2dFunctor::operator()(const float* input, // NCHW
+                               const index_t* input_shape,
+                               const float* filter, // c_out, c_in, kernel_h, kernel_w
+                               const index_t* filter_shape,
+                               const float* bias, // c_out
+                               float* output, // NCHW
+                               const index_t* output_shape) {
+  // selector[paddings_[0] - 1][strides_[0] - 1] is true when the NEON path is taken.
+  // NOTE(review): the 5x4 layout looks like kernel size x stride -- confirm that
+  // indexing the rows by padding is really what was intended.
+  static const bool selector[5][4] = {
+    {true, false, false, false},
+    {false, false, false, false},
+    {true, true, false, false},
+    {false, false, false, false},
+    {true, false, false, false},
+  };
+  // Not implemented yet for anything else. The range checks come first so the
+  // selector lookup on the last line of the condition stays in bounds.
+  if (paddings_[0] != paddings_[1] || paddings_[0] < 1 || paddings_[0] > 5 ||
+      strides_[0] != strides_[1] || strides_[0] < 1 || strides_[0] > 4 ||
+      dilations_[0] != 1 || dilations_[1] != 1 ||
+      !selector[paddings_[0] - 1][strides_[0] - 1]) {
+    Conv2dFunctor(strides_, paddings_, dilations_)(
+      input, input_shape, filter, filter_shape,
+      bias, output, output_shape);
+    // The fallback has produced the output; do not run the NEON path as well.
+    return;
+  }
+
+  // NOTE(review): ConstructInputWithPadding is still an empty stub, so
+  // padded_input is null and padded_input_shape is uninitialized here.
+  std::unique_ptr padded_input;
+  index_t padded_input_shape[4];
+  ConstructInputWithPadding(input, input_shape, paddings_, padded_input, padded_input_shape);
+  Conv2dNeon(
+    padded_input.get(), padded_input_shape, filter, bias, output, output_shape);
+}
+
+} // namespace kernels
+} // namespace mace
\ No newline at end of file
diff --git a/mace/kernels/neon/conv_2d_neon_3x3.cc b/mace/kernels/neon/conv_2d_neon_3x3.cc
new file mode 100644
index 00000000..4397076d
--- /dev/null
+++ b/mace/kernels/neon/conv_2d_neon_3x3.cc
@@ -0,0 +1,22 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include
+#include "mace/kernels/neon/conv_2d_neon_base.h"
+
+namespace mace {
+namespace kernels {
+
+template<>
+void Conv2dNeon<3, 3, 1, 1>(const float* input,  // NCHW
+                            const index_t* input_shape,
+                            const float* filter,  // c_out, c_in, kernel_h, kernel_w
+                            const float* bias,  // c_out
+                            float* output,  // NCHW
+                            const index_t* output_shape) {
+  // Empty body: this specialization currently performs no work.
+}
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/neon/conv_2d_neon_base.h b/mace/kernels/neon/conv_2d_neon_base.h
new file mode 100644
index 00000000..4e0b5024
--- /dev/null
+++ b/mace/kernels/neon/conv_2d_neon_base.h
@@ -0,0 +1,24 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
+#define MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
+
+#include "mace/core/common.h"
+
+namespace mace {
+namespace kernels {
+// Declaration only: this header provides no definition of Conv2dNeon.
+template
+inline void Conv2dNeon(const float* input,  // NCHW
+                       const index_t* input_shape,
+                       const float* filter,  // c_out, c_in, kernel_h, kernel_w
+                       const float* bias,  // c_out
+                       float* output,  // NCHW
+                       const index_t* output_shape);
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_NEON_CONV_2D_NEON_BASE_H_
-- 
GitLab