未验证 提交 1549fbd3 编写于 作者: Z Zhang Jun 提交者: GitHub

fix bias and scale shared by multi OP (#54099)

上级 a64a722a
......@@ -425,7 +425,14 @@ void SplitLayerNormPass::ApplyImpl(Graph* graph) const {
IR_NODE_LINK_TO(new_bias_node, elementwise_add1_node);
IR_NODE_LINK_TO(elementwise_add1_node, layer_norm_out);
GraphSafeRemoveNodes(g, {layer_norm_op, layer_norm_bias, layer_norm_scale});
std::unordered_set<const Node*> nodes2rm = {};
nodes2rm.insert(layer_norm_op);
if (layer_norm_bias->outputs.size() <= 1UL)
nodes2rm.insert(layer_norm_bias);
if (layer_norm_scale->outputs.size() <= 1UL)
nodes2rm.insert(layer_norm_scale);
GraphSafeRemoveNodes(g, nodes2rm);
found_layer_norm_count++;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册