From ace61b8ba538c77882bcaef78cf25fbf13c4e8ba Mon Sep 17 00:00:00 2001 From: weishengying <63448337+weishengying@users.noreply.github.com> Date: Thu, 4 May 2023 16:30:45 +0800 Subject: [PATCH] Fix a bug in constant folding pass (#53456) --- paddle/fluid/framework/ir/constant_folding_pass.cc | 2 +- paddle/fluid/inference/api/paddle_pass_builder.cc | 2 +- paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/constant_folding_pass.cc b/paddle/fluid/framework/ir/constant_folding_pass.cc index 74d8e4a29a7..9e3d1d5c08c 100644 --- a/paddle/fluid/framework/ir/constant_folding_pass.cc +++ b/paddle/fluid/framework/ir/constant_folding_pass.cc @@ -64,7 +64,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const { platform::errors::Fatal( "scope must not be null when applying constant floding.")); - std::vector<std::string> blacklist{"feed", "matrix_multiply"}; + std::vector<std::string> blacklist{"feed", "matrix_multiply", "save"}; auto op_node_sorted = framework::ir::TopologyVarientSort( *graph, static_cast<framework::ir::SortKind>(0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index e7c24272b81..134d56180b6 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -420,6 +420,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("quant_dequant_mkldnn_pass"); passes_.push_back("mkldnn_placement_pass"); + passes_.push_back("constant_folding_pass"); passes_.push_back("squeeze2_transpose2_onednn_fuse_pass"); passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass"); @@ -474,7 +475,6 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass"); passes_.push_back("int8_scale_calculation_mkldnn_pass"); passes_.push_back("params_quantization_mkldnn_pass"); - 
passes_.push_back("constant_folding_pass"); } use_mkldnn_int8_ = true; #else diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index a5274c5f7ae..d408514e839 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -192,7 +192,9 @@ class FCMKLDNNHandler } else { auto scale_in_data = ctx.Attr<float>("Scale_in"); auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights"); - bool has_activation = !ctx.Attr<std::string>("activation_type").empty(); + bool has_activation = !ctx.Attr<std::string>("activation_type").empty() || + (ctx.HasAttr("fuse_activation") && + !ctx.Attr<std::string>("fuse_activation").empty()); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool fuse_residual_conn = ctx.HasAttr("fuse_residual_connection") && ctx.Attr<bool>("fuse_residual_connection"); -- GitLab