未验证 提交 5216d055 编写于 作者: P Pei Yang 提交者: GitHub

[cherry-pick]fix multihead matmul shared params (#27321)

* fix multihead matmul shared params (#27121)

* fix multihead matmul shared params
上级 67f87d6d
...@@ -615,9 +615,17 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -615,9 +615,17 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, // If weights or biases in qkv's fc are shared by multiple multihead_matmul
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, // patterns, we do not support this kind of fusion, this pass will not take
reshape2_0, reshape2_qkv_out, scale, scale_out); // effect.
bool is_fc_params_shared =
mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
if (!is_fc_params_shared) {
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0, std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1, eltadd1,
...@@ -662,6 +670,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -662,6 +670,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
// Remove unneeded nodes. // Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count; ++fusion_count;
}
}; };
gpd(graph, handler); gpd(graph, handler);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册