diff --git a/paddle/fluid/framework/ir/constant_folding_pass.cc b/paddle/fluid/framework/ir/constant_folding_pass.cc
index 74d8e4a29a73cbf2bfe66424dfcbfee0f3003b72..9e3d1d5c08c72d8d0d912d315f678357f83c057c 100644
--- a/paddle/fluid/framework/ir/constant_folding_pass.cc
+++ b/paddle/fluid/framework/ir/constant_folding_pass.cc
@@ -64,7 +64,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
       platform::errors::Fatal(
           "scope must not be null when applying constant floding."));
 
-  std::vector<std::string> blacklist{"feed", "matrix_multiply"};
+  std::vector<std::string> blacklist{"feed", "matrix_multiply", "save"};
 
   auto op_node_sorted = framework::ir::TopologyVarientSort(
       *graph, static_cast<framework::ir::SortKind>(0));
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index e7c24272b81c5568fa39e2d25df2d734e5d70a57..134d56180b6177350b191e947bcfab4ade59f75d 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -420,6 +420,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("simplify_with_basic_ops_pass");
     passes_.push_back("quant_dequant_mkldnn_pass");
     passes_.push_back("mkldnn_placement_pass");
+    passes_.push_back("constant_folding_pass");
     passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
     passes_.push_back("layer_norm_fuse_pass");
     passes_.push_back("attention_lstm_fuse_pass");
@@ -474,7 +475,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
     passes_.push_back("int8_scale_calculation_mkldnn_pass");
     passes_.push_back("params_quantization_mkldnn_pass");
-    passes_.push_back("constant_folding_pass");
   }
   use_mkldnn_int8_ = true;
 #else
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index a5274c5f7ae7c8fdbb092563ec290415a93b8791..d408514e83921896fdf7eacd931885a5d583d579 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -192,7 +192,9 @@ class FCMKLDNNHandler
     } else {
       auto scale_in_data = ctx.Attr<float>("Scale_in");
       auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
-      bool has_activation = !ctx.Attr<std::string>("activation_type").empty();
+      bool has_activation = !ctx.Attr<std::string>("activation_type").empty() ||
+                            (ctx.HasAttr("fuse_activation") &&
+                             !ctx.Attr<std::string>("fuse_activation").empty());
       bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
       bool fuse_residual_conn = ctx.HasAttr("fuse_residual_connection") &&
                                 ctx.Attr<bool>("fuse_residual_connection");
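
The fc_mkldnn_op.cc hunk widens the activation detection: an FC op is treated as having a fused activation if either the legacy `activation_type` attribute or the `fuse_activation` attribute (the name set by the oneDNN fuse passes) is present and non-empty. Below is a minimal standalone sketch of that condition; `AttrMap`, `HasAttr`, and `Attr` are illustrative stand-ins, not the Paddle `ExecutionContext` API:

```cpp
#include <iostream>
#include <map>
#include <string>

// Stand-in for the op's attribute store (hypothetical, for illustration only).
using AttrMap = std::map<std::string, std::string>;

bool HasAttr(const AttrMap &attrs, const std::string &name) {
  return attrs.count(name) > 0;
}

std::string Attr(const AttrMap &attrs, const std::string &name) {
  auto it = attrs.find(name);
  return it == attrs.end() ? std::string() : it->second;
}

// Mirrors the patched condition: report an activation if either attribute
// names one. HasAttr guards the "fuse_activation" lookup because older
// graphs may not carry that attribute at all.
bool HasActivation(const AttrMap &attrs) {
  return !Attr(attrs, "activation_type").empty() ||
         (HasAttr(attrs, "fuse_activation") &&
          !Attr(attrs, "fuse_activation").empty());
}

int main() {
  AttrMap legacy{{"activation_type", "relu"}};
  AttrMap fused{{"activation_type", ""}, {"fuse_activation", "gelu"}};
  AttrMap none{{"activation_type", ""}};
  std::cout << HasActivation(legacy) << HasActivation(fused)
            << HasActivation(none) << "\n";  // prints 110
  return 0;
}
```

Checking both names matters here because `has_activation` feeds the INT8 output-scale decision: if an activation fused via `fuse_activation` went undetected, the handler would pick the wrong quantization scale for the FC output.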