未验证 提交 30f4ef7f 编写于 作者: Z zhoutianzi666 提交者: GitHub

support nhwc in conv2d_fusion (#48642)

上级 4639d65d
......@@ -81,12 +81,16 @@ class Conv2DFusionOp : public operators::ConvOp {
std::string data_format = ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format,
"NHWC",
"NDHWC",
platform::errors::PermissionDenied(
"Operator(Conv2DFusion) only supports data format of "
"channel first (NCHW) now. But received: data_format = '%s'.",
"Operator(Conv2DFusion) supports data format of "
"channel first (NCHW,NCDHW) and data format of channel last(NHWC) "
"now. But received: data_format = '%s'.",
data_format));
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
std::vector<int64_t> output_shape = ComputeOutputShape(ctx);
ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
......@@ -112,21 +116,31 @@ class Conv2DFusionOp : public operators::ConvOp {
std::vector<framework::DDim> output_shapes(split_channels.size());
for (size_t i = 0; i < split_channels.size(); ++i) {
split_channels_sum += split_channels[i];
if (channel_last) {
output_shapes[i] = phi::make_ddim({output_shape[0],
output_shape[1],
output_shape[2],
split_channels[i]});
} else {
output_shapes[i] = phi::make_ddim({output_shape[0],
split_channels[i],
output_shape[2],
output_shape[3]});
}
}
int output_channels = output_shape[1];
// for NHWC
if (channel_last) output_channels = output_shape[3];
PADDLE_ENFORCE_EQ(
split_channels_sum,
output_shape[1],
output_channels,
platform::errors::InvalidArgument(
"The sum of Attr(split_channels) is expected to be equal to the "
"The sum of Attr(split_channels) is expected to be equal to "
"the "
"total output channels. But received: the sum of "
"Attr(split_channels) = %d, the total output channels = %d.",
split_channels_sum,
output_shape[1]));
output_channels));
ctx->SetOutputsDim("Outputs", output_shapes);
}
}
......@@ -159,6 +173,17 @@ class Conv2DFusionOp : public operators::ConvOp {
const std::string data_format =
ctx->Attrs().Get<std::string>("data_format");
// if data_format is NHWC, we convert the weight dimension to the form of
// nchw to minimize program changes.
if (data_format == "NHWC") {
int kh = filter_dims[1];
int kw = filter_dims[2];
int ic = filter_dims[3];
filter_dims[1] = ic;
filter_dims[2] = kh;
filter_dims[3] = kw;
}
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
......
......@@ -56,6 +56,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const std::string activation = ctx.Attr<std::string>("activation");
std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format,
"NHWC",
platform::errors::PermissionDenied(
"Operator(Conv2DFusion) in cuDNN only supports data format of "
"channel first (NCHW) now. But received: data_format = '%s'.",
data_format));
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册