Unverified · Commit 363b25aa authored by Ouyang Chao, committed by GitHub

improve performance of DepthwiseConv(NHWC) (#31677)

* improve performance of DepthwiseConv(NHWC)
Parent 10af966a
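What the commit does, in short: the DepthwiseConv functors are made layout-aware (they now take a data_layout argument), so NHWC tensors are convolved in place instead of being transposed to NCHW, convolved, and transposed back. As a rough illustration of what reading NHWC data directly means, here is a naive CPU reference sketch; the names are hypothetical and this is not the optimized CUDA functor the commit modifies:

```cpp
#include <cstddef>
#include <vector>

// Naive NHWC depthwise convolution: output channel c depends only on input
// channel c, and in NHWC the channel index is the fastest-varying one, so
// neighbouring channels sit next to each other in memory.
void NaiveDepthwiseConvNHWC(const std::vector<float>& in,      // N x H x W x C
                            const std::vector<float>& filter,  // C x kH x kW
                            std::vector<float>* out,           // N x OH x OW x C
                            int N, int H, int W, int C, int kH, int kW,
                            int stride, int pad) {
  const int OH = (H + 2 * pad - kH) / stride + 1;
  const int OW = (W + 2 * pad - kW) / stride + 1;
  out->assign(static_cast<size_t>(N) * OH * OW * C, 0.f);
  for (int n = 0; n < N; ++n)
    for (int oh = 0; oh < OH; ++oh)
      for (int ow = 0; ow < OW; ++ow)
        for (int c = 0; c < C; ++c) {
          float acc = 0.f;
          for (int fh = 0; fh < kH; ++fh)
            for (int fw = 0; fw < kW; ++fw) {
              const int ih = oh * stride - pad + fh;
              const int iw = ow * stride - pad + fw;
              if (ih < 0 || ih >= H || iw < 0 || iw >= W) continue;
              // NHWC flat offset: ((n * H + ih) * W + iw) * C + c
              acc += in[((n * H + ih) * W + iw) * C + c] *
                     filter[(c * kH + fh) * kW + fw];
            }
          (*out)[((n * OH + oh) * OW + ow) * C + c] = acc;
        }
}
```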
@@ -903,29 +903,19 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
                 "and input channel number is %d",
                 output->dims()[1], input->dims()[1]));
     }
-    // transform tensor
-    Tensor transformed_input(input->type());
-    Tensor transformed_output(output->type());
-    if (channel_last) {
-      ResizeToChannelFirst<DeviceContext, T>(context, input,
-                                             &transformed_input);
-      TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
-      ResizeToChannelFirst<DeviceContext, T>(context, output,
-                                             &transformed_output);
-    } else {
-      transformed_input = *input;
-      transformed_output = *output;
-    }
     // update padding and dilation
-    auto in_dims = transformed_input.dims();
+    auto in_dims = input->dims();
     auto filter_dims = filter.dims();
     framework::DDim in_data_dims;
-    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_format);
+    if (data_layout != framework::DataLayout::kNHWC) {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    }
     framework::DDim filter_data_dims =
         framework::slice_ddim(filter_dims, 2, filter_dims.size());
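The new branch above selects the spatial extent from different positions of the input shape: positions 2..rank-1 for NCHW, positions 1..rank-2 for NHWC. A minimal standalone sketch of the same rule, using std::vector in place of framework::DDim (the helper name is hypothetical):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Spatial (H, W) extent of a 4-D shape: positions 2..3 for NCHW,
// positions 1..2 for NHWC -- mirroring the slice_ddim branch above.
std::vector<int64_t> SpatialDims(const std::vector<int64_t>& dims, bool nhwc) {
  if (!nhwc) return {dims.begin() + 2, dims.end()};
  return {dims.begin() + 1, dims.end() - 1};
}

int main() {
  assert((SpatialDims({8, 3, 32, 32}, false) == std::vector<int64_t>{32, 32}));  // NCHW
  assert((SpatialDims({8, 32, 32, 3}, true) == std::vector<int64_t>{32, 32}));   // NHWC
  return 0;
}
```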
@@ -944,16 +934,12 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
     if (fuse_relu) {
       math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
-      depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings,
-                    dilations, &transformed_output);
+      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
+                    output, data_layout);
     } else {
       math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
-      depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings,
-                    dilations, &transformed_output);
-    }
-    if (channel_last) {
-      TransToChannelLast<DeviceContext, T>(context, &transformed_output,
-                                           output);
+      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
+                    output, data_layout);
     }
   }
 };
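Note that fuse_relu is a compile-time template argument of the functor, so the fused-ReLU and plain variants are separate instantiations and no per-element branch is paid inside the kernel loop. A generic sketch of that dispatch pattern (illustrative names, not Paddle APIs):

```cpp
#include <algorithm>

// The activation choice is a template parameter: "fused ReLU" and
// "no activation" compile to two distinct specializations.
template <typename T, bool fuse_relu>
struct StoreOutput {
  static void Run(T value, T* dst) {
    if (fuse_relu) value = std::max(value, static_cast<T>(0));
    *dst = value;
  }
};

// Caller-side pattern, analogous to the if (fuse_relu) dispatch above:
// the runtime flag picks an instantiation once, outside the hot loop.
template <typename T>
void Store(bool fuse_relu, T value, T* dst) {
  if (fuse_relu) {
    StoreOutput<T, true>::Run(value, dst);
  } else {
    StoreOutput<T, false>::Run(value, dst);
  }
}
```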
@@ -981,33 +967,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("padding_algorithm");
     const std::string data_format = context.Attr<std::string>("data_format");
     const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
-    // transform Tensor
-    Tensor transformed_input(input->type());
-    Tensor transformed_output_grad(output_grad->type());
-    if (channel_last) {
-      ResizeToChannelFirst<DeviceContext, T>(context, input,
-                                             &transformed_input);
-      TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
-      ResizeToChannelFirst<DeviceContext, T>(context, output_grad,
-                                             &transformed_output_grad);
-      TransToChannelFirst<DeviceContext, T>(context, output_grad,
-                                            &transformed_output_grad);
-    } else {
-      transformed_input = *input;
-      transformed_output_grad = *output_grad;
-    }
     // update padding and dilation
-    auto in_dims = transformed_input.dims();
+    auto in_dims = input->dims();
     auto filter_dims = filter.dims();
     framework::DDim in_data_dims;
-    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    const framework::DataLayout data_layout =
+        framework::StringToDataLayout(data_format);
+    if (data_layout != framework::DataLayout::kNHWC) {
+      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+    } else {
+      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
+    }
     framework::DDim filter_data_dims =
         framework::slice_ddim(filter_dims, 2, filter_dims.size());
     std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
@@ -1025,33 +996,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
-      Tensor transformed_input_grad(input_grad->type());
-      if (channel_last) {
-        ResizeToChannelFirst<DeviceContext, T>(context, input_grad,
-                                               &transformed_input_grad);
-      } else {
-        transformed_input_grad = *input_grad;
-      }
-      set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
+      set_zero(dev_ctx, input_grad, static_cast<T>(0));
       if (fuse_relu) {
         math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
             depthwiseConvInputGrad;
-        depthwiseConvInputGrad(dev_ctx, transformed_input, filter,
-                               transformed_output_grad, strides, paddings,
-                               dilations, &transformed_input_grad);
+        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
+                               paddings, dilations, input_grad, data_layout);
       } else {
         math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
             depthwiseConvInputGrad;
-        depthwiseConvInputGrad(dev_ctx, transformed_input, filter,
-                               transformed_output_grad, strides, paddings,
-                               dilations, &transformed_input_grad);
-      }
-      if (channel_last) {
-        TransToChannelLast<DeviceContext, T>(context, &transformed_input_grad,
-                                             input_grad);
+        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
+                               paddings, dilations, input_grad, data_layout);
       }
     }
@@ -1061,15 +1017,13 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
       if (fuse_relu) {
         math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
             depthwiseConvFilterGrad;
-        depthwiseConvFilterGrad(dev_ctx, transformed_input,
-                                transformed_output_grad, strides, paddings,
-                                dilations, filter_grad);
+        depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
+                                paddings, dilations, filter_grad, data_layout);
       } else {
         math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
             depthwiseConvFilterGrad;
-        depthwiseConvFilterGrad(dev_ctx, transformed_input,
-                                transformed_output_grad, strides, paddings,
-                                dilations, filter_grad);
+        depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
+                                paddings, dilations, filter_grad, data_layout);
       }
     }
   }
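The backward kernels follow the same pattern: the input- and filter-gradient functors now read and write NHWC tensors in place, guided by data_layout, instead of round-tripping through NCHW temporaries. For reference, a naive NHWC depthwise filter-gradient accumulation might look like this (an illustrative sketch with hypothetical names, not the CUDA functor used here):

```cpp
#include <cstddef>
#include <vector>

// Naive NHWC depthwise filter gradient: dFilter[c, fh, fw] accumulates
// in[n, ih, iw, c] * dOut[n, oh, ow, c] over batch and spatial positions.
void NaiveDepthwiseFilterGradNHWC(const std::vector<float>& in,    // N x H x W x C
                                  const std::vector<float>& dout,  // N x OH x OW x C
                                  std::vector<float>* dfilter,     // C x kH x kW
                                  int N, int H, int W, int C, int kH, int kW,
                                  int stride, int pad, int OH, int OW) {
  dfilter->assign(static_cast<size_t>(C) * kH * kW, 0.f);
  for (int n = 0; n < N; ++n)
    for (int oh = 0; oh < OH; ++oh)
      for (int ow = 0; ow < OW; ++ow)
        for (int c = 0; c < C; ++c)
          for (int fh = 0; fh < kH; ++fh)
            for (int fw = 0; fw < kW; ++fw) {
              const int ih = oh * stride - pad + fh;
              const int iw = ow * stride - pad + fw;
              if (ih < 0 || ih >= H || iw < 0 || iw >= W) continue;
              (*dfilter)[(c * kH + fh) * kW + fw] +=
                  in[((n * H + ih) * W + iw) * C + c] *
                  dout[((n * OH + oh) * OW + ow) * C + c];
            }
}
```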
@@ -112,10 +112,6 @@ def _conv_nd(x,
-    # Due to the poor performance of NHWC, we transpose the input to NCHW.
-    origin_format = data_format
-    if origin_format == "NHWC" and op_type == "depthwise_conv2d":
-        x = nn.transpose(x, perm=[0, 3, 1, 2])
-        data_format = "NCHW"
-        channel_dim = 1
     if in_dygraph_mode():
         attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
                  'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
@@ -159,10 +155,6 @@ def _conv_nd(x,
                      'use_mkldnn': use_mkldnn})
     else:
         out = pre_bias
-    if origin_format == "NHWC" and op_type == "depthwise_conv2d":
-        out = nn.transpose(out, perm=[0, 2, 3, 1])
     return out