From 1549fbd30dd41470c60b2faf081a666dd312c3a4 Mon Sep 17 00:00:00 2001 From: Zhang Jun Date: Thu, 25 May 2023 12:55:57 +0800 Subject: [PATCH] fix bias and scale shared by multi OP (#54099) --- .../framework/ir/split_layernorm_to_math_ops_pass.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc b/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc index 592ad4e2d72..f28b768513f 100644 --- a/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc +++ b/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc @@ -425,7 +425,14 @@ void SplitLayerNormPass::ApplyImpl(Graph* graph) const { IR_NODE_LINK_TO(new_bias_node, elementwise_add1_node); IR_NODE_LINK_TO(elementwise_add1_node, layer_norm_out); - GraphSafeRemoveNodes(g, {layer_norm_op, layer_norm_bias, layer_norm_scale}); + std::unordered_set nodes2rm = {}; + nodes2rm.insert(layer_norm_op); + if (layer_norm_bias->outputs.size() <= 1UL) + nodes2rm.insert(layer_norm_bias); + if (layer_norm_scale->outputs.size() <= 1UL) + nodes2rm.insert(layer_norm_scale); + + GraphSafeRemoveNodes(g, nodes2rm); found_layer_norm_count++; }; -- GitLab