diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index a63a4319b5632197304dcbf3fbc239cd4ac66b06..536b28ad2b86899358f13bc5f380017efdd00ff0 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -6,23 +6,39 @@ #define MACE_KERNELS_CONV_2D_H_ #include "mace/core/tensor.h" +#include "mace/kernels/conv_pool_2d_util.h" namespace mace { namespace kernels { -template +template class Conv2dFunctor { public: - Conv2dFunctor(const int* strides, const int* paddings, const int* dilations) - : strides_(strides), paddings_(paddings), dilations_(dilations) {} - - void operator()(const T* input, // NCHW - const index_t* input_shape, - const T* filter, // c_out, c_in, kernel_h, kernel_w - const index_t* filter_shape, - const T* bias, // c_out - T* output, // NCHW - const index_t* output_shape) { + Conv2dFunctor(const index_t *input_shape, + const index_t *filter_shape, + const int *strides, + const Padding padding, + const int *dilations) : + strides_(strides), + paddings_(2, 0), + dilations_(dilations) { + CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data()); + } + + Conv2dFunctor(const int *strides, + const std::vector &paddings, + const int *dilations) : + strides_(strides), + paddings_(paddings), + dilations_(dilations) {} + + void operator()(const T *input, // NCHW + const index_t *input_shape, + const T *filter, // c_out, c_in, kernel_h, kernel_w + const index_t *filter_shape, + const T *bias, // c_out + T *output, // NCHW + const index_t *output_shape) { MACE_CHECK_NOTNULL(output); index_t batch = output_shape[0]; @@ -60,9 +76,9 @@ class Conv2dFunctor { for (int h = 0; h < height; ++h) { for (int w = 0; w < width; ++w) { index_t offset = n * channels * height * width + - c * height * width + h * width + w; + c * height * width + h * width + w; T sum = 0; - const T* filter_ptr = filter + c * kernel_size; + const T *filter_ptr = filter + c * kernel_size; for (int inc = 0; inc < input_channels; ++inc) { for (int kh = 0; kh < kernel_h; ++kh) { for (int kw = 0; kw < kernel_w; ++kw) { @@ -71,7 +87,7 @@ class Conv2dFunctor { if (inh < 0 || inh >= input_height || inw < 0 || inw >= input_width) { MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop && - inw >= padded_w_start && inw < padded_w_stop, + inw >= padded_w_start && inw < padded_w_stop, "Out of range read from input: ", inh, ", ", inw); // else padding with 0: @@ -79,8 +95,8 @@ class Conv2dFunctor { } else { index_t input_offset = n * input_channels * input_height * input_width + - inc * input_height * input_width + inh * input_width + - inw; + inc * input_height * input_width + inh * input_width + + inw; sum += input[input_offset] * *filter_ptr; } ++filter_ptr; @@ -95,20 +111,20 @@ class Conv2dFunctor { } private: - const int* strides_; // [stride_h, stride_w] - const int* paddings_; // [padding_h, padding_w] - const int* dilations_; // [dilation_h, dilation_w] + const int *strides_; // [stride_h, stride_w] + std::vector paddings_; // [padding_h, padding_w] + const int *dilations_; // [dilation_h, dilation_w] }; -template <> +template<> void Conv2dFunctor::operator()( - const float* input, - const index_t* input_shape, - const float* filter, - const index_t* filter_shape, - const float* bias, - float* output, - const index_t* output_shape); + const float *input, + const index_t *input_shape, + const float *filter, + const index_t *filter_shape, + const float *bias, + float *output, + const index_t *output_shape); } // namespace kernels } // namespace mace diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc index ca7b0f1a169fcb5e91f711be1ba7f24c1af3ce58..0b121a70dc9b3b99892e086b525830364c7040a8 100644 --- a/mace/kernels/neon/batch_norm_neon.cc +++ b/mace/kernels/neon/batch_norm_neon.cc @@ -44,8 +44,7 @@ void BatchNormFunctor::operator()( for (index_t j = 0; j < count; ++j) { float32x4_t input_f = vld1q_f32(input_sample_ptr); - float32x4_t output_f = new_offset_f; - output_f = vfmaq_f32(output_f, input_f, new_scale_f); + float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f); vst1q_f32(output_sample_ptr, output_f); input_sample_ptr += 4; output_sample_ptr += 4; diff --git a/mace/kernels/neon/conv_2d_neon.cc b/mace/kernels/neon/conv_2d_neon.cc index c530e496e5459124808a5c4500f01d84100995a3..29bccaca933f9354488739558cc2726075b88564 100644 --- a/mace/kernels/neon/conv_2d_neon.cc +++ b/mace/kernels/neon/conv_2d_neon.cc @@ -81,7 +81,7 @@ void Conv2dFunctor::operator()(const float *input, // Keep this alive during kernel execution Tensor padded_input; if (paddings_[0] > 0 || paddings_[1] > 0) { - ConstructInputWithPadding(input, input_shape, paddings_, &padded_input); + ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input); input = padded_input.data(); input_shape = padded_input.shape().data(); } diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 3ac6689cd8d8a6f5198c56a70623ed50a7d6e0b7..5c15ca83951fef68e0b5b6d3f94a26a79afc5ed4 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -17,7 +17,12 @@ template class Conv2dOp : public ConvPool2dOpBase { public: Conv2dOp(const OperatorDef &op_def, Workspace *ws) - : ConvPool2dOpBase(op_def, ws) {}; + : ConvPool2dOpBase(op_def, ws), + functor_(this->Input(INPUT)->shape().data(), + this->Input(FILTER)->shape().data(), + this->strides_.data(), + this->padding_, + this->dilations_.data()) {} bool Run() override { const Tensor *input = this->Input(INPUT); @@ -27,21 +32,19 @@ class Conv2dOp : public ConvPool2dOpBase { std::vector output_shape(4); std::vector paddings(2); - kernels::CalcPaddingAndOutputSize( - input->shape().data(), filter->shape().data(), this->dilations_.data(), - this->strides_.data(), this->padding_, output_shape.data(), - paddings.data()); + this->CalOutputSize(input->shape().data(), filter->shape().data(), output_shape.data()); output->Resize(output_shape); - auto conv2d = kernels::Conv2dFunctor( - this->strides_.data(), paddings.data(), this->dilations_.data()); - conv2d(input->data(), input->shape().data(), filter->data(), - filter->shape().data(), bias->data(), output->mutable_data(), - output->shape().data()); + functor_(input->data(), input->shape().data(), filter->data(), + filter->shape().data(), bias->data(), output->mutable_data(), + output->shape().data()); return true; } + private: + kernels::Conv2dFunctor functor_; + protected: OP_INPUT_TAGS(INPUT, FILTER, BIAS); OP_OUTPUT_TAGS(OUTPUT);