diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index a2e4b7f813ce22e553b57c4810e7f0e7876bcdfa..803bc2407fa23cbc57d0114155df36d1d633dd52 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -15,11 +15,21 @@ namespace kernels { template class DepthwiseConv2dFunctor { public: + DepthwiseConv2dFunctor(const index_t* input_shape, + const index_t* filter_shape, + const int* strides, + const Padding padding, + const int* dilations) : + strides_(strides), + paddings_(2, 0), + dilations_(dilations) { + CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data()); + } DepthwiseConv2dFunctor(const int* strides, - Padding paddings, - const int* dilations) : + const std::vector& paddings, + const int* dilations) : strides_(strides), - padding_(paddings), + paddings_(paddings), dilations_(dilations) {} void operator()(const T* input, // NCHW @@ -53,13 +63,11 @@ class DepthwiseConv2dFunctor { MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); - vector paddings_size(2, 0); - CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding_, paddings_size.data()); // The left-upper most offset of the padded input - int padded_h_start = 0 - paddings_size[0] / 2; - int padded_w_start = 0 - paddings_size[1] / 2; - index_t padded_h_stop = input_height + paddings_size[0] - paddings_size[0] / 2; - index_t padded_w_stop = input_width + paddings_size[1] - paddings_size[1] / 2; + int padded_h_start = 0 - paddings_[0] / 2; + int padded_w_start = 0 - paddings_[1] / 2; + index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2; + index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2; index_t kernel_size = filter_shape[1] * kernel_h * kernel_w; index_t multiplier = channels / input_channels; @@ -103,7 +111,7 @@ class DepthwiseConv2dFunctor { } private: const int* strides_; // [stride_h, stride_w] - Padding padding_ ; + std::vector paddings_; // [padding_h, padding_w] const int* dilations_; // [dilation_h, dilation_w] }; diff --git a/mace/kernels/neon/depthwise_conv_neon.cc b/mace/kernels/neon/depthwise_conv_neon.cc index d14ad1a0c5f4757ebb31f44bbec0338dec75fefa..7bf0a839ab0d7db294beb3e3bf073841bc84b986 100644 --- a/mace/kernels/neon/depthwise_conv_neon.cc +++ b/mace/kernels/neon/depthwise_conv_neon.cc @@ -57,17 +57,15 @@ void DepthwiseConv2dFunctor::operator()(const float* in << "filter" << kernel_h << "x" << kernel_w << "," << " stride " << strides_[0] << "x" << strides_[1] << " is not implemented yet, using slow version"; - DepthwiseConv2dFunctor(strides_, padding_, dilations_)( + DepthwiseConv2dFunctor(strides_, paddings_, dilations_)( input, input_shape, filter, filter_shape, bias, output, output_shape); return; } // Keep this alive during kernel execution - vector paddings_size(2, 0); - CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding_, paddings_size.data()); Tensor padded_input; - if (paddings_size[0] > 0 || paddings_size[1] > 0) { - ConstructInputWithPadding(input, input_shape, paddings_size.data(), &padded_input); + if (paddings_[0] > 0 || paddings_[1] > 0) { + ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input); input = padded_input.data(); input_shape = padded_input.shape().data(); } diff --git a/mace/ops/conv_pool_2d_base.h b/mace/ops/conv_pool_2d_base.h index b35b3340e4bc6b74a2adcf97360493066e7f4b4a..96b544d7bdeb52e4a70146873009fee5b0656fdb 100644 --- a/mace/ops/conv_pool_2d_base.h +++ b/mace/ops/conv_pool_2d_base.h @@ -22,15 +22,12 @@ class ConvPool2dOpBase : public Operator { void CalOutputSize(const index_t *input_shape, // NCHW const index_t *filter_shape, // OIHW - const int *dilations, - const int *strides, - Padding padding, index_t *output_shape) { - MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, + MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0, "Invalid dilations, must >= 1"); - MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), + MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) && + (dilations_[1] == 1 || strides_[1] == 1), "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(output_shape); /* @@ -42,21 +39,21 @@ class ConvPool2dOpBase : public Operator { index_t output_height, output_width; - switch (padding) { + switch (padding_) { case VALID: - output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations[0] - 1) / strides[0] + 1; - output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations[1] - 1) / strides[1] + 1; + output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1; + output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1; break; case SAME: - output_height = (input_shape[2] - 1) / strides[0] + 1; - output_width = (input_shape[3] - 1) / strides[1] + 1; + output_height = (input_shape[2] - 1) / strides_[0] + 1; + output_width = (input_shape[3] - 1) / strides_[1] + 1; break; case FULL: - output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations[0] - 1) / strides[0] + 1; - output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations[1] - 1) / strides[1] + 1; + output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1; + output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1; break; default: - MACE_CHECK(false, "Unsupported padding type: ", padding); + MACE_CHECK(false, "Unsupported padding type: ", padding_); } output_shape[0] = input_shape[0]; diff --git a/mace/ops/depthwise_conv2d.h b/mace/ops/depthwise_conv2d.h index 5e5d487834e651ba961b22c381d2d88278ad0f13..b6a458ead4e1c8c08630772a5a7f161ace2a3cd8 100644 --- a/mace/ops/depthwise_conv2d.h +++ b/mace/ops/depthwise_conv2d.h @@ -19,7 +19,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { public: DepthwiseConv2dOp(const OperatorDef& op_def, Workspace* ws) : ConvPool2dOpBase(op_def, ws), - functor_(this->strides_.data(), this->padding_, this->dilations_.data()){}; + functor_(this->Input(INPUT)->shape().data(), + this->Input(FILTER)->shape().data(), + this->strides_.data(), this->padding_, this->dilations_.data()){}; bool Run() override { const Tensor* input = this->Input(INPUT); @@ -32,9 +34,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase { filter_shape[0] *= filter_shape[1]; filter_shape[1] = 1; std::vector output_shape(4); - this->CalOutputSize( - input->shape().data(), filter_shape.data(), this->dilations_.data(), - this->strides_.data(), this->padding_, output_shape.data()); + this->CalOutputSize(input->shape().data(), filter_shape.data(), output_shape.data()); output->Resize(output_shape); functor_(input->data(), input->shape().data(), filter->data(),