From 30f4ef7f685260c80c449ac3d8467b729ed2a8f8 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Mon, 5 Dec 2022 20:53:20 +0800 Subject: [PATCH] support nhwc in conv2d_fusion (#48642) --- .../fluid/operators/fused/conv_fusion_op.cc | 49 ++++++++++++++----- .../fluid/operators/fused/conv_fusion_op.cu | 9 ++++ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cc b/paddle/fluid/operators/fused/conv_fusion_op.cc index 18dd6e8e75..27440c9408 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cc +++ b/paddle/fluid/operators/fused/conv_fusion_op.cc @@ -81,12 +81,16 @@ class Conv2DFusionOp : public operators::ConvOp { std::string data_format = ctx->Attrs().Get("data_format"); PADDLE_ENFORCE_NE( data_format, - "NHWC", + "NDHWC", platform::errors::PermissionDenied( - "Operator(Conv2DFusion) only supports data format of " - "channel first (NCHW) now. But received: data_format = '%s'.", + "Operator(Conv2DFusion) supports data format of " + "channel first (NCHW,NCDHW) and data format of channel last(NHWC) " + "now. But received: data_format = '%s'.", data_format)); - + // MKL-DNN Kernels are using NCHW order of dims description + // so we ignore data_format consideration for MKL-DNN kernel + const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) && + (data_format == "NHWC" || data_format == "NDHWC"); std::vector output_shape = ComputeOutputShape(ctx); ctx->SetOutputDim("Output", phi::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); @@ -112,21 +116,31 @@ class Conv2DFusionOp : public operators::ConvOp { std::vector output_shapes(split_channels.size()); for (size_t i = 0; i < split_channels.size(); ++i) { split_channels_sum += split_channels[i]; - output_shapes[i] = phi::make_ddim({output_shape[0], - split_channels[i], - output_shape[2], - output_shape[3]}); + if (channel_last) { + output_shapes[i] = phi::make_ddim({output_shape[0], + output_shape[1], + output_shape[2], + split_channels[i]}); + } else { + output_shapes[i] = phi::make_ddim({output_shape[0], + split_channels[i], + output_shape[2], + output_shape[3]}); + } } + int output_channels = output_shape[1]; + // for NHWC + if (channel_last) output_channels = output_shape[3]; PADDLE_ENFORCE_EQ( split_channels_sum, - output_shape[1], + output_channels, platform::errors::InvalidArgument( - "The sum of Attr(split_channels) is expected to be equal to the " + "The sum of Attr(split_channels) is expected to be equal to " + "the " "total output channels. But received: the sum of " "Attr(split_channels) = %d, the total output channels = %d.", split_channels_sum, - output_shape[1])); - + output_channels)); ctx->SetOutputsDim("Outputs", output_shapes); } } @@ -159,6 +173,17 @@ class Conv2DFusionOp : public operators::ConvOp { const std::string data_format = ctx->Attrs().Get("data_format"); + // if data_format is NHWC, we convert the weight dimension to the form of + // nchw to minimize program changes. + if (data_format == "NHWC") { + int kh = filter_dims[1]; + int kw = filter_dims[2]; + int ic = filter_dims[3]; + filter_dims[1] = ic; + filter_dims[2] = kh; + filter_dims[3] = kw; + } + // MKL-DNN Kernels are using NCHW order of dims description // so we ignore data_format consideration for MKL-DNN kernel const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) && diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index a9b577e7f4..87ed8fb68f 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -56,6 +56,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); const std::string activation = ctx.Attr("activation"); + std::string data_format = ctx.Attr("data_format"); + PADDLE_ENFORCE_NE( + data_format, + "NHWC", + platform::errors::PermissionDenied( + "Operator(Conv2DFusion) in cuDNN only supports data format of " + "channel first (NCHW) now. But received: data_format = '%s'.", + data_format)); + int groups = ctx.Attr("groups"); int64_t user_workspace_size = static_cast(ctx.Attr("workspace_size_MB")); -- GitLab