未验证 提交 363b25aa 编写于 作者: O Ouyang Chao 提交者: GitHub

improve performance of DepthwiseConv(NHWC) (#31677)

* improve performance of DepthwiseConv(NHWC)
上级 10af966a
...@@ -903,29 +903,19 @@ class DepthwiseConvKernel : public framework::OpKernel<T> { ...@@ -903,29 +903,19 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
"and input channel number is %d", "and input channel number is %d",
output->dims()[1], input->dims()[1])); output->dims()[1], input->dims()[1]));
} }
// transform tensor
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input,
&transformed_input);
TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
ResizeToChannelFirst<DeviceContext, T>(context, output,
&transformed_output);
} else {
transformed_input = *input;
transformed_output = *output;
}
// update padding and dilation // update padding and dilation
auto in_dims = transformed_input.dims(); auto in_dims = input->dims();
auto filter_dims = filter.dims(); auto filter_dims = filter.dims();
framework::DDim in_data_dims; framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); const framework::DataLayout data_layout =
framework::StringToDataLayout(data_format);
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims = framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
...@@ -944,16 +934,12 @@ class DepthwiseConvKernel : public framework::OpKernel<T> { ...@@ -944,16 +934,12 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
if (fuse_relu) { if (fuse_relu) {
math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv; math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings, depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
dilations, &transformed_output); output, data_layout);
} else { } else {
math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv; math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
depthwiseConv(dev_ctx, transformed_input, filter, strides, paddings, depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
dilations, &transformed_output); output, data_layout);
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(context, &transformed_output,
output);
} }
} }
}; };
...@@ -981,33 +967,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> { ...@@ -981,33 +967,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
context.Attr<std::string>("padding_algorithm"); context.Attr<std::string>("padding_algorithm");
const std::string data_format = context.Attr<std::string>("data_format"); const std::string data_format = context.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// transform Tensor
Tensor transformed_input(input->type());
Tensor transformed_output_grad(output_grad->type());
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input,
&transformed_input);
TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);
ResizeToChannelFirst<DeviceContext, T>(context, output_grad,
&transformed_output_grad);
TransToChannelFirst<DeviceContext, T>(context, output_grad,
&transformed_output_grad);
} else {
transformed_input = *input;
transformed_output_grad = *output_grad;
}
// update padding and dilation // update padding and dilation
auto in_dims = transformed_input.dims(); auto in_dims = input->dims();
auto filter_dims = filter.dims(); auto filter_dims = filter.dims();
framework::DDim in_data_dims; framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); const framework::DataLayout data_layout =
framework::StringToDataLayout(data_format);
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
}
framework::DDim filter_data_dims = framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims); std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
...@@ -1025,33 +996,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> { ...@@ -1025,33 +996,18 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
Tensor transformed_input_grad(input_grad->type()); set_zero(dev_ctx, input_grad, static_cast<T>(0));
if (channel_last) {
ResizeToChannelFirst<DeviceContext, T>(context, input_grad,
&transformed_input_grad);
} else {
transformed_input_grad = *input_grad;
}
set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
if (fuse_relu) { if (fuse_relu) {
math::DepthwiseConvInputGradFunctor<DeviceContext, T, true> math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
depthwiseConvInputGrad; depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, transformed_input, filter, depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
transformed_output_grad, strides, paddings, paddings, dilations, input_grad, data_layout);
dilations, &transformed_input_grad);
} else { } else {
math::DepthwiseConvInputGradFunctor<DeviceContext, T, false> math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
depthwiseConvInputGrad; depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, transformed_input, filter, depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
transformed_output_grad, strides, paddings, paddings, dilations, input_grad, data_layout);
dilations, &transformed_input_grad);
}
if (channel_last) {
TransToChannelLast<DeviceContext, T>(context, &transformed_input_grad,
input_grad);
} }
} }
...@@ -1061,15 +1017,13 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> { ...@@ -1061,15 +1017,13 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
if (fuse_relu) { if (fuse_relu) {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true> math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
depthwiseConvFilterGrad; depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, transformed_input, depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
transformed_output_grad, strides, paddings, paddings, dilations, filter_grad, data_layout);
dilations, filter_grad);
} else { } else {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false> math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
depthwiseConvFilterGrad; depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, transformed_input, depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
transformed_output_grad, strides, paddings, paddings, dilations, filter_grad, data_layout);
dilations, filter_grad);
} }
} }
} }
......
...@@ -112,10 +112,6 @@ def _conv_nd(x, ...@@ -112,10 +112,6 @@ def _conv_nd(x,
# Due to the poor performance of NHWC, we transpose the input to NCHW. # Due to the poor performance of NHWC, we transpose the input to NCHW.
origin_format = data_format origin_format = data_format
if origin_format == "NHWC" and op_type == "depthwise_conv2d":
x = nn.transpose(x, perm=[0, 3, 1, 2])
data_format = "NCHW"
channel_dim = 1
if in_dygraph_mode(): if in_dygraph_mode():
attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', 'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
...@@ -159,10 +155,6 @@ def _conv_nd(x, ...@@ -159,10 +155,6 @@ def _conv_nd(x,
'use_mkldnn': use_mkldnn}) 'use_mkldnn': use_mkldnn})
else: else:
out = pre_bias out = pre_bias
if origin_format == "NHWC" and op_type == "depthwise_conv2d":
out = nn.transpose(out, perm=[0, 2, 3, 1])
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册