diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
index af6773042b67870d715ed37894e0321015485d0d..6cd16132c2a10f77b08b199d6b6c4adbf1a64b72 100644
--- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
@@ -130,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
+      .IsStringIn({"NCHW", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("affine_channel"))
@@ -148,7 +148,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
       .IsTensor()
       .End()
      .AddAttr("data_layout")
-      .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
+      .IsStringIn({"NCHW", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
@@ -197,6 +197,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
 
     GET_CONV_BN_NODES(conv_ac_pattern);
 
+    auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
+    if (data_format == "AnyLayout") {
+      LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, "
+                                 "it's wrong if data_format of conv is not "
+                                 "NCHW.";
+    }
+
     // Get affine_channel bias for resizing eltwise_y!
     auto* ac_bias_tensor =
         scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
@@ -282,7 +289,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
+      .IsStringIn({"NCHW", "AnyLayout"})
       .End();
   AddOpCompat(OpCompat("affine_channel"))
       .AddInput("X")
@@ -299,7 +306,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
       .IsTensor()
       .End()
       .AddAttr("data_layout")
-      .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
+      .IsStringIn({"NCHW", "AnyLayout"})
       .End();
   AddOpCompat(OpCompat("elementwise_add"))
       .AddInput("X")
@@ -347,6 +354,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
     VLOG(4) << "handle ConvBN fuse";
 
     GET_CONV_BN_NODES(conv_ac_pattern);
+    auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
+    if (data_format == "AnyLayout") {
+      LOG_FIRST_N(WARNING, 1) << "conv_eltwiseadd_affine_channel_fuse_pass is "
+                                 "enabled, it's wrong if data_format of conv "
+                                 "is not NCHW.";
+    }
     // OPERATORS
     GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern);
     // BIAS inputs
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
index 3d1c1eb55aa079ab26240fffe77747781785e0fe..27e52167f313721b7b149b77951dfab4f8eb217c 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
@@ -77,7 +77,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
       .AddAttr("dilations")
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NHWC", "NCHW"})
+      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
index 439b85ffb9f10dd9e50aab6353cf0df33cc6f166..545e4a7b9e616691e6b8859bae31877af7ce1427 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
@@ -57,7 +57,7 @@ ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() {
       .AddAttr("dilations")
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
+      .IsStringIn({"NCHW", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
@@ -97,6 +97,13 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
     GET_NODES;
 
     auto base_op_desc = *conv_op->Op()->Proto();
+    auto data_format =
+        conv_op->Op()->GetAttrIfExists<std::string>("data_format");
+    if (data_format == "AnyLayout") {
+      LOG_FIRST_N(WARNING, 1) << "conv_elementwise_add_fuse_pass is enabled, "
+                                 "it's wrong if data_format of conv is not "
+                                 "NCHW.";
+    }
     std::string bias_name = elementwise_add_in_y->Name();
     std::string output_name = elementwise_add_out->Name();