diff --git a/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc b/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc index 592ad4e2d7291d1231b05d8b469546d4e8bebf36..f28b768513f9f5ed3b03b00660708a79286a0300 100644 --- a/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc +++ b/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc @@ -425,7 +425,14 @@ void SplitLayerNormPass::ApplyImpl(Graph* graph) const { IR_NODE_LINK_TO(new_bias_node, elementwise_add1_node); IR_NODE_LINK_TO(elementwise_add1_node, layer_norm_out); - GraphSafeRemoveNodes(g, {layer_norm_op, layer_norm_bias, layer_norm_scale}); + std::unordered_set nodes2rm = {}; + nodes2rm.insert(layer_norm_op); + if (layer_norm_bias->outputs.size() <= 1UL) + nodes2rm.insert(layer_norm_bias); + if (layer_norm_scale->outputs.size() <= 1UL) + nodes2rm.insert(layer_norm_scale); + + GraphSafeRemoveNodes(g, nodes2rm); found_layer_norm_count++; };