From e47d8a572be7c20a49ee5df1752d01bea0e8300a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com>
Date: Thu, 5 Aug 2021 13:12:56 +0800
Subject: [PATCH] [pass_enhance]fix the mkldnn model performance drop problem.
 test=develop (#34625)

---
 .../ir/adaptive_pool2d_convert_global_pass.cc          |  1 +
 .../framework/ir/conv_affine_channel_fuse_pass.cc      |  6 ++++--
 paddle/fluid/framework/ir/conv_bn_fuse_pass.cc         |  2 ++
 .../ir/conv_elementwise_add2_act_fuse_pass.cc          |  1 +
 .../framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc  |  5 ++++-
 .../mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc    | 11 ++++++-----
 .../framework/ir/mkldnn/cpu_quantize_squash_pass.cc    |  2 +-
 .../framework/ir/quant_conv2d_dequant_fuse_pass.cc     |  1 +
 8 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
index 0e2bb3eaad5..c280b7c32ed 100644
--- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
+++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
@@ -60,6 +60,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
       .IsStringIn({"NHWC", "NCHW"})
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End();
 }
diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
index e4ac89f04ff..3875d856d20 100644
--- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc
@@ -120,6 +120,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -129,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("affine_channel"))
@@ -267,6 +268,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -276,7 +278,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
   AddOpCompat(OpCompat("affine_channel"))
       .AddInput("X")
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index c362eec34b0..3a012b90848 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -620,6 +620,7 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
@@ -663,6 +664,7 @@ ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
index 573436d393b..3d1c1eb55aa 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
@@ -68,6 +68,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
       .AddAttr("paddings")
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
index a7514038d40..41539a05b37 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
@@ -47,6 +47,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -56,7 +57,7 @@ ConvBiasFusePass::ConvBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
@@ -110,6 +111,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
@@ -135,6 +137,7 @@ Conv3DBiasFusePass::Conv3DBiasFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
index bd65ad8e643..b07cc58959f 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -158,11 +158,6 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
   Node* elementwise_add_op;
   Node* elementwise_add_identity;
   Node* elementwise_add_out;
-  if (!pass_->IsCompat(subgraph, graph)) {
-    LOG(WARNING)
-        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
-    return;
-  }
 
   std::tie(conv_op, conv_input, conv_filter, conv_output) =
       get_node_from_conv_op(subgraph);
@@ -175,6 +170,12 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
 
   if (HasFusedActivation(conv_op)) return;
 
+  if (!pass_->IsCompat(subgraph, graph)) {
+    LOG(WARNING)
+        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
+    return;
+  }
+
   conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
   conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
   conv_op->Op()->SetAttr("fuse_residual_connection", true);
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
index 2483a506a8f..2b9419a5502 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -77,7 +77,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
       .End()
       .AddAttr("data_format")
       .IsOptional()
-      .IsStringIn({"NCHW", "NHWC"})
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 }
 
diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 068a50a1dc0..b48c8c6e70a 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -243,6 +243,7 @@ QuantDequantFusePass::QuantDequantFusePass() {
      .IsType<std::vector<int>>()
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("data_format")
--
GitLab
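Note on the change. Every hunk above touches the same gate: each pass constructor registers an OpCompat description, and the fuse handlers call IsCompat() before rewriting the graph; when that check rejects an op, the pass logs "in op compat failed" and returns, so the fusion is skipped and the MKLDNN model runs unfused (the reported performance drop). The sketch below is a minimal, self-contained C++ mock of that fluent attribute-check pattern; OpCompatMock, its methods, and the sample attribute maps are illustrative stand-ins, not PaddlePaddle's real OpCompat/AttrCompat API. It only demonstrates why marking padding_algorithm as .IsOptional() and accepting "AnyLayout" for data_format lets older op descriptions pass the check.

// Mock of the compat-check builder used by the fuse passes above.
// Names here are hypothetical; only the IsOptional()/IsStringIn() semantics
// shown in the patch are being illustrated.
#include <iostream>
#include <map>
#include <set>
#include <string>

class OpCompatMock {
 public:
  // Begin a constraint for one attribute.
  OpCompatMock& AddAttr(const std::string& name) {
    current_ = name;
    optional_[name] = false;
    return *this;
  }
  // The attribute may be absent from the op description.
  OpCompatMock& IsOptional() {
    optional_[current_] = true;
    return *this;
  }
  // If present, the attribute value must be one of the given strings.
  OpCompatMock& IsStringIn(const std::set<std::string>& allowed) {
    allowed_[current_] = allowed;
    return *this;
  }
  // Close the current attribute constraint.
  OpCompatMock& End() { return *this; }

  // Returns true when every registered constraint is satisfied.
  bool Check(const std::map<std::string, std::string>& op_attrs) const {
    for (const auto& kv : allowed_) {
      auto it = op_attrs.find(kv.first);
      if (it == op_attrs.end()) {
        if (!optional_.at(kv.first)) return false;  // required attribute missing
        continue;                                   // optional attribute may be absent
      }
      if (kv.second.count(it->second) == 0) return false;  // value not in allowed set
    }
    return true;
  }

 private:
  std::string current_;
  std::map<std::string, bool> optional_;
  std::map<std::string, std::set<std::string>> allowed_;
};

int main() {
  // A conv2d-like op that never sets padding_algorithm, as older exported
  // programs do, and uses the "AnyLayout" data format.
  std::map<std::string, std::string> old_conv2d = {{"data_format", "AnyLayout"}};

  // Pre-patch style: padding_algorithm required, "AnyLayout" rejected.
  OpCompatMock strict;
  strict.AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC"})
      .End();

  // Post-patch style: optional attribute, "AnyLayout" accepted.
  OpCompatMock relaxed;
  relaxed.AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  std::cout << "strict:  " << strict.Check(old_conv2d) << "\n";   // 0 -> fusion skipped
  std::cout << "relaxed: " << relaxed.Check(old_conv2d) << "\n";  // 1 -> fusion applied
  return 0;
}

The same reasoning explains the reordering in conv_elementwise_add_mkldnn_fuse_pass.cc: the IsCompat() guard is moved after the cheap structural early-exits so its warning only fires for subgraphs that would actually be fused.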