未验证 提交 30f4ef7f 编写于 作者: Z zhoutianzi666 提交者: GitHub

Support NHWC data format in conv2d_fusion (#48642)

上级 4639d65d
...@@ -81,12 +81,16 @@ class Conv2DFusionOp : public operators::ConvOp { ...@@ -81,12 +81,16 @@ class Conv2DFusionOp : public operators::ConvOp {
std::string data_format = ctx->Attrs().Get<std::string>("data_format"); std::string data_format = ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
data_format, data_format,
"NHWC", "NDHWC",
platform::errors::PermissionDenied( platform::errors::PermissionDenied(
"Operator(Conv2DFusion) only supports data format of " "Operator(Conv2DFusion) supports data format of "
"channel first (NCHW) now. But received: data_format = '%s'.", "channel first (NCHW,NCDHW) and data format of channel last(NHWC) "
"now. But received: data_format = '%s'.",
data_format)); data_format));
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
std::vector<int64_t> output_shape = ComputeOutputShape(ctx); std::vector<int64_t> output_shape = ComputeOutputShape(ctx);
ctx->SetOutputDim("Output", phi::make_ddim(output_shape)); ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output"); ctx->ShareLoD("Input", "Output");
...@@ -112,21 +116,31 @@ class Conv2DFusionOp : public operators::ConvOp { ...@@ -112,21 +116,31 @@ class Conv2DFusionOp : public operators::ConvOp {
std::vector<framework::DDim> output_shapes(split_channels.size()); std::vector<framework::DDim> output_shapes(split_channels.size());
for (size_t i = 0; i < split_channels.size(); ++i) { for (size_t i = 0; i < split_channels.size(); ++i) {
split_channels_sum += split_channels[i]; split_channels_sum += split_channels[i];
output_shapes[i] = phi::make_ddim({output_shape[0], if (channel_last) {
split_channels[i], output_shapes[i] = phi::make_ddim({output_shape[0],
output_shape[2], output_shape[1],
output_shape[3]}); output_shape[2],
split_channels[i]});
} else {
output_shapes[i] = phi::make_ddim({output_shape[0],
split_channels[i],
output_shape[2],
output_shape[3]});
}
} }
int output_channels = output_shape[1];
// for NHWC
if (channel_last) output_channels = output_shape[3];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
split_channels_sum, split_channels_sum,
output_shape[1], output_channels,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The sum of Attr(split_channels) is expected to be equal to the " "The sum of Attr(split_channels) is expected to be equal to "
"the "
"total output channels. But received: the sum of " "total output channels. But received: the sum of "
"Attr(split_channels) = %d, the total output channels = %d.", "Attr(split_channels) = %d, the total output channels = %d.",
split_channels_sum, split_channels_sum,
output_shape[1])); output_channels));
ctx->SetOutputsDim("Outputs", output_shapes); ctx->SetOutputsDim("Outputs", output_shapes);
} }
} }
...@@ -159,6 +173,17 @@ class Conv2DFusionOp : public operators::ConvOp { ...@@ -159,6 +173,17 @@ class Conv2DFusionOp : public operators::ConvOp {
const std::string data_format = const std::string data_format =
ctx->Attrs().Get<std::string>("data_format"); ctx->Attrs().Get<std::string>("data_format");
// if data_format is NHWC, we convert the weight dimension to the form of
// nchw to minimize program changes.
if (data_format == "NHWC") {
int kh = filter_dims[1];
int kw = filter_dims[2];
int ic = filter_dims[3];
filter_dims[1] = ic;
filter_dims[2] = kh;
filter_dims[3] = kw;
}
// MKL-DNN Kernels are using NCHW order of dims description // MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel // so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) && const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
......
...@@ -56,6 +56,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -56,6 +56,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const std::string activation = ctx.Attr<std::string>("activation"); const std::string activation = ctx.Attr<std::string>("activation");
std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format,
"NHWC",
platform::errors::PermissionDenied(
"Operator(Conv2DFusion) in cuDNN only supports data format of "
"channel first (NCHW) now. But received: data_format = '%s'.",
data_format));
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size = int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB")); static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册