未验证 提交 5216d055 编写于 作者: P Pei Yang 提交者: GitHub

[cherry-pick]fix multihead matmul shared params (#27321)

* fix multihead matmul shared params (#27121)

* fix multihead matmul shared params
上级 67f87d6d
...@@ -615,53 +615,62 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -615,53 +615,62 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, // If weights or biases in qkv's fc are shared by multiple multihead_matmul
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, // patterns, we do not support this kind of fusion, this pass will not take
reshape2_0, reshape2_qkv_out, scale, scale_out); // effect.
bool is_fc_params_shared =
std::unordered_set<const Node*> marked_nodes({eltadd0, mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
eltadd1, mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
eltadd2, eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
eltadd1_b, if (!is_fc_params_shared) {
eltadd2_b, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_out, mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
eltadd1_out, eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
eltadd2_out,
reshape2_0, std::unordered_set<const Node*> marked_nodes({eltadd0,
reshape2_1, eltadd1,
reshape2_2, eltadd2,
reshape2_0_out, eltadd1_b,
reshape2_1_out, eltadd2_b,
reshape2_2_out, eltadd0_out,
transpose2_0, eltadd1_out,
transpose2_1, eltadd2_out,
transpose2_2, reshape2_0,
transpose2_0_out, reshape2_1,
transpose2_1_out, reshape2_2,
transpose2_2_out, reshape2_0_out,
matmul_qk, reshape2_1_out,
matmul_qk_out, reshape2_2_out,
eltadd_qk, transpose2_0,
eltadd_qk_out, transpose2_1,
softmax_qk, transpose2_2,
softmax_qk_out, transpose2_0_out,
transpose2_qkv, transpose2_1_out,
transpose2_qkv_out, transpose2_2_out,
matmul_qkv, matmul_qk,
matmul_qkv_out, matmul_qk_out,
mul0, eltadd_qk,
mul1, eltadd_qk_out,
mul2, softmax_qk,
mul0_out, softmax_qk_out,
mul1_out, transpose2_qkv,
mul2_out, transpose2_qkv_out,
mul1_w, matmul_qkv,
mul2_w, matmul_qkv_out,
reshape2_qkv, mul0,
scale}); mul1,
// Remove unneeded nodes. mul2,
GraphSafeRemoveNodes(graph, marked_nodes); mul0_out,
++fusion_count; mul1_out,
mul2_out,
mul1_w,
mul2_w,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}
}; };
gpd(graph, handler); gpd(graph, handler);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册