未验证 提交 ace61b8b 编写于 作者: W weishengying 提交者: GitHub

Fix a bug in constant folding pass (#53456)

上级 49d7bc5c
......@@ -64,7 +64,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
platform::errors::Fatal(
"scope must not be null when applying constant floding."));
std::vector<std::string> blacklist{"feed", "matrix_multiply"};
std::vector<std::string> blacklist{"feed", "matrix_multiply", "save"};
auto op_node_sorted = framework::ir::TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(0));
......
......@@ -420,6 +420,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
......@@ -474,7 +475,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass");
passes_.push_back("int8_scale_calculation_mkldnn_pass");
passes_.push_back("params_quantization_mkldnn_pass");
passes_.push_back("constant_folding_pass");
}
use_mkldnn_int8_ = true;
#else
......
......@@ -192,7 +192,9 @@ class FCMKLDNNHandler
} else {
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
bool has_activation = !ctx.Attr<std::string>("activation_type").empty();
bool has_activation = !ctx.Attr<std::string>("activation_type").empty() ||
(ctx.HasAttr("fuse_activation") &&
!ctx.Attr<std::string>("fuse_activation").empty());
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool fuse_residual_conn = ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册