未验证 提交 5216d055 编写于 作者: P Pei Yang 提交者: GitHub

[cherry-pick]fix multihead matmul shared params (#27321)

* fix multihead matmul shared params (#27121)

* fix multihead matmul shared params
上级 67f87d6d
......@@ -615,53 +615,62 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
eltadd2,
eltadd1_b,
eltadd2_b,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul1_w,
mul2_w,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
// effect.
bool is_fc_params_shared =
mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
if (!is_fc_params_shared) {
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
eltadd2,
eltadd1_b,
eltadd2_b,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul1_w,
mul2_w,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}
};
gpd(graph, handler);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册