From bfee398bb61d82bb1eb880ea94780797a9333948 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 20 Sep 2022 09:50:25 +0800 Subject: [PATCH] [Inference] fix preln_residual_bias_fuse_pass bug in TNT_small model (#46178) * fix preln_residual_bias_fuse_pass bug in TNT_small model --- .../ir/preln_residual_bias_fuse_pass.cc | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc index 43f48f87c0..79d2794895 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc @@ -143,10 +143,7 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { LOG(WARNING) << "The subgraph is empty."; return; } - if (!IsCompat(subgraph, graph)) { - LOG(WARNING) << "preln_residual_bias pass in op compat failed."; - return; - } + VLOG(4) << "handle PrelnResidualBias fuse"; GET_IR_NODE_FROM_SUBGRAPH( elementwise_bias, elementwise_bias, fused_pattern); @@ -164,6 +161,21 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH( layer_norm_variance, layer_norm_variance, fused_pattern); + + // We can not accept that two or more layer_norm is connected to + // elementwise1_out. This will lead to two or more PrelnResidualBias + // patterns is found near elementwise1_out, and these patterns will interact + // on each other, so we make below check to ensure only one + // PrelnResidualBias pattern is delalted with. + for (auto op : elementwise1_out->inputs) { + if (op->Name() == "preln_residual_bias") return; + } + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "preln_residual_bias pass in op compat failed."; + return; + } + std::unordered_set del_node_set; // Create an PrelnResidualBias op node OpDesc new_desc; -- GitLab