Commit 2cb646dc, authored by liuqi

Update the padding calculation logic of depthwise conv2d.

Parent commit: de985851
...@@ -15,11 +15,21 @@ namespace kernels { ...@@ -15,11 +15,21 @@ namespace kernels {
template<DeviceType D, typename T> template<DeviceType D, typename T>
class DepthwiseConv2dFunctor { class DepthwiseConv2dFunctor {
public: public:
// Constructs the functor and eagerly computes the concrete padding sizes,
// so operator() does not need to re-derive them on every call.
//
// input_shape:  input tensor shape (presumably NCHW, matching the
//               "// NCHW" note on operator() — TODO confirm)
// filter_shape: filter tensor shape (presumably OIHW — TODO confirm)
// strides:      [stride_h, stride_w]; caller-owned pointer that must
//               outlive this functor (stored, not copied)
// padding:      padding policy (e.g. VALID/SAME/FULL) translated into
//               concrete sizes by CalPaddingSize
// dilations:    [dilation_h, dilation_w]; caller-owned pointer that must
//               outlive this functor (stored, not copied)
DepthwiseConv2dFunctor(const index_t* input_shape,
const index_t* filter_shape,
const int* strides,
const Padding padding,
const int* dilations) :
strides_(strides),
// Two zero-initialized entries ([padding_h, padding_w]); CalPaddingSize
// fills them in place via paddings_.data() below.
paddings_(2, 0),
dilations_(dilations) {
// Safe to pass the members here: strides_, paddings_ and dilations_ are
// all initialized by the time the constructor body runs.
CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
}
DepthwiseConv2dFunctor(const int* strides, DepthwiseConv2dFunctor(const int* strides,
Padding paddings, const std::vector<int>& paddings,
const int* dilations) : const int* dilations) :
strides_(strides), strides_(strides),
padding_(paddings), paddings_(paddings),
dilations_(dilations) {} dilations_(dilations) {}
void operator()(const T* input, // NCHW void operator()(const T* input, // NCHW
...@@ -53,13 +63,11 @@ class DepthwiseConv2dFunctor { ...@@ -53,13 +63,11 @@ class DepthwiseConv2dFunctor {
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
vector<int> paddings_size(2, 0);
CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding_, paddings_size.data());
// The left-upper most offset of the padded input // The left-upper most offset of the padded input
int padded_h_start = 0 - paddings_size[0] / 2; int padded_h_start = 0 - paddings_[0] / 2;
int padded_w_start = 0 - paddings_size[1] / 2; int padded_w_start = 0 - paddings_[1] / 2;
index_t padded_h_stop = input_height + paddings_size[0] - paddings_size[0] / 2; index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
index_t padded_w_stop = input_width + paddings_size[1] - paddings_size[1] / 2; index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
index_t kernel_size = filter_shape[1] * kernel_h * kernel_w; index_t kernel_size = filter_shape[1] * kernel_h * kernel_w;
index_t multiplier = channels / input_channels; index_t multiplier = channels / input_channels;
...@@ -103,7 +111,7 @@ class DepthwiseConv2dFunctor { ...@@ -103,7 +111,7 @@ class DepthwiseConv2dFunctor {
} }
private: private:
const int* strides_; // [stride_h, stride_w] const int* strides_; // [stride_h, stride_w]
Padding padding_ ; std::vector<int> paddings_; // [padding_h, padding_w]
const int* dilations_; // [dilation_h, dilation_w] const int* dilations_; // [dilation_h, dilation_w]
}; };
......
...@@ -57,17 +57,15 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float* in ...@@ -57,17 +57,15 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(const float* in
<< "filter" << kernel_h << "x" << kernel_w << "," << "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1] << " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version"; << " is not implemented yet, using slow version";
DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, padding_, dilations_)( DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input, input_shape, filter, filter_shape, bias, output, output_shape); input, input_shape, filter, filter_shape, bias, output, output_shape);
return; return;
} }
// Keep this alive during kernel execution // Keep this alive during kernel execution
vector<int> paddings_size(2, 0);
CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding_, paddings_size.data());
Tensor padded_input; Tensor padded_input;
if (paddings_size[0] > 0 || paddings_size[1] > 0) { if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, input_shape, paddings_size.data(), &padded_input); ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input);
input = padded_input.data<float>(); input = padded_input.data<float>();
input_shape = padded_input.shape().data(); input_shape = padded_input.shape().data();
} }
......
...@@ -22,15 +22,12 @@ class ConvPool2dOpBase : public Operator<D, T> { ...@@ -22,15 +22,12 @@ class ConvPool2dOpBase : public Operator<D, T> {
void CalOutputSize(const index_t *input_shape, // NCHW void CalOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape) { index_t *output_shape) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0,
"Invalid dilations, must >= 1"); "Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && MACE_CHECK((dilations_[0] == 1 || strides_[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1), (dilations_[1] == 1 || strides_[1] == 1),
"If dilations > 1, strides should be 1"); "If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape); MACE_CHECK_NOTNULL(output_shape);
/* /*
...@@ -42,21 +39,21 @@ class ConvPool2dOpBase : public Operator<D, T> { ...@@ -42,21 +39,21 @@ class ConvPool2dOpBase : public Operator<D, T> {
index_t output_height, output_width; index_t output_height, output_width;
switch (padding) { switch (padding_) {
case VALID: case VALID:
output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations[0] - 1) / strides[0] + 1; output_height = (input_shape[2] - (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations[1] - 1) / strides[1] + 1; output_width = (input_shape[3] - (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
break; break;
case SAME: case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1; output_height = (input_shape[2] - 1) / strides_[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1; output_width = (input_shape[3] - 1) / strides_[1] + 1;
break; break;
case FULL: case FULL:
output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations[0] - 1) / strides[0] + 1; output_height = (input_shape[2] + (filter_shape[2] - 1) * dilations_[0] - 1) / strides_[0] + 1;
output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations[1] - 1) / strides[1] + 1; output_width = (input_shape[3] + (filter_shape[3] - 1) * dilations_[1] - 1) / strides_[1] + 1;
break; break;
default: default:
MACE_CHECK(false, "Unsupported padding type: ", padding); MACE_CHECK(false, "Unsupported padding type: ", padding_);
} }
output_shape[0] = input_shape[0]; output_shape[0] = input_shape[0];
......
...@@ -19,7 +19,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -19,7 +19,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
public: public:
DepthwiseConv2dOp(const OperatorDef& op_def, Workspace* ws) DepthwiseConv2dOp(const OperatorDef& op_def, Workspace* ws)
: ConvPool2dOpBase<D, T>(op_def, ws), : ConvPool2dOpBase<D, T>(op_def, ws),
functor_(this->strides_.data(), this->padding_, this->dilations_.data()){}; functor_(this->Input(INPUT)->shape().data(),
this->Input(FILTER)->shape().data(),
this->strides_.data(), this->padding_, this->dilations_.data()){};
bool Run() override { bool Run() override {
const Tensor* input = this->Input(INPUT); const Tensor* input = this->Input(INPUT);
...@@ -32,9 +34,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -32,9 +34,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
filter_shape[0] *= filter_shape[1]; filter_shape[0] *= filter_shape[1];
filter_shape[1] = 1; filter_shape[1] = 1;
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
this->CalOutputSize( this->CalOutputSize(input->shape().data(), filter_shape.data(), output_shape.data());
input->shape().data(), filter_shape.data(), this->dilations_.data(),
this->strides_.data(), this->padding_, output_shape.data());
output->Resize(output_shape); output->Resize(output_shape);
functor_(input->data<T>(), input->shape().data(), filter->data<T>(), functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.