diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc index 43f48f87c09a513cb52088e7527cb3c2df13eb1a..79d27948954278227f07ba044bf955426bf75862 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc @@ -143,10 +143,7 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { LOG(WARNING) << "The subgraph is empty."; return; } - if (!IsCompat(subgraph, graph)) { - LOG(WARNING) << "preln_residual_bias pass in op compat failed."; - return; - } + VLOG(4) << "handle PrelnResidualBias fuse"; GET_IR_NODE_FROM_SUBGRAPH( elementwise_bias, elementwise_bias, fused_pattern); @@ -164,6 +161,21 @@ void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH( layer_norm_variance, layer_norm_variance, fused_pattern); + + // We cannot accept two or more layer_norm ops being connected to + // elementwise1_out. This would lead to two or more PrelnResidualBias + // patterns being found near elementwise1_out, and these patterns would + // interfere with each other, so we make the check below to ensure only one + // PrelnResidualBias pattern is dealt with. + for (auto op : elementwise1_out->inputs) { + if (op->Name() == "preln_residual_bias") return; + } + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "preln_residual_bias pass in op compat failed."; + return; + } + std::unordered_set del_node_set; // Create an PrelnResidualBias op node OpDesc new_desc;