From 90e9a486e83c5bd1cd173f9ac89e3305960c1ee4 Mon Sep 17 00:00:00 2001 From: wenbin Date: Wed, 22 Dec 2021 10:42:28 +0800 Subject: [PATCH] CE fix (#38324) * CE fix * more format --- .../ir/conv_affine_channel_fuse_pass.cc | 21 +++++++++++++++---- .../ir/conv_elementwise_add2_act_fuse_pass.cc | 2 +- .../ir/conv_elementwise_add_fuse_pass.cc | 9 +++++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index af6773042b6..6cd16132c2a 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -130,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { .IsType>() .End() .AddAttr("data_format") - .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) + .IsStringIn({"NCHW", "AnyLayout"}) .End(); AddOpCompat(OpCompat("affine_channel")) @@ -148,7 +148,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { .IsTensor() .End() .AddAttr("data_layout") - .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) + .IsStringIn({"NCHW", "AnyLayout"}) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -197,6 +197,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { GET_CONV_BN_NODES(conv_ac_pattern); + auto data_format = conv->Op()->GetAttrIfExists("data_format"); + if (data_format == "AnyLayout") { + LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, " + "it's wrong if data_format of conv is not " + "NCHW."; + } + // Get affine_channel bias for resizing eltwise_y! auto* ac_bias_tensor = scope->FindVar(ac_bias->Name())->GetMutable(); @@ -282,7 +289,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { .IsType>() .End() .AddAttr("data_format") - .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) + .IsStringIn({"NCHW", "AnyLayout"}) .End(); AddOpCompat(OpCompat("affine_channel")) .AddInput("X") @@ -299,7 +306,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { .IsTensor() .End() .AddAttr("data_layout") - .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) + .IsStringIn({"NCHW", "AnyLayout"}) .End(); AddOpCompat(OpCompat("elementwise_add")) .AddInput("X") @@ -347,6 +354,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { VLOG(4) << "handle ConvBN fuse"; GET_CONV_BN_NODES(conv_ac_pattern); + auto data_format = conv->Op()->GetAttrIfExists("data_format"); + if (data_format == "AnyLayout") { + LOG_FIRST_N(WARNING, 1) << "conv_eltwiseadd_affine_channel_fuse_pass is " + "enabled, it's wrong if data_format of conv " + "is not NCHW."; + } // OPERATORS GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern); // BIAS inputs diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index 3d1c1eb55aa..27e52167f31 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -77,7 +77,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() { .AddAttr("dilations") .End() .AddAttr("data_format") - .IsStringIn({"NHWC", "NCHW"}) + .IsStringIn({"NHWC", "NCHW", "AnyLayout"}) .End(); AddOpCompat(OpCompat("elementwise_add")) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index 439b85ffb9f..545e4a7b9e6 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -57,7 +57,7 @@ ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() { .AddAttr("dilations") .End() .AddAttr("data_format") - .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) + .IsStringIn({"NCHW", "AnyLayout"}) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -97,6 +97,13 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { GET_NODES; auto base_op_desc = *conv_op->Op()->Proto(); + auto data_format = + conv_op->Op()->GetAttrIfExists("data_format"); + if (data_format == "AnyLayout") { + LOG_FIRST_N(WARNING, 1) << "conv_elementwise_add_fuse_pass is enabled, " + "it's wrong if data_format of conv is not " + "NCHW."; + } std::string bias_name = elementwise_add_in_y->Name(); std::string output_name = elementwise_add_out->Name(); -- GitLab