From 5216d0551073fccb7ed63dde9d7c9eefd912f1fa Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Wed, 16 Sep 2020 12:36:44 +0800 Subject: [PATCH] [cherry-pick]fix multihead matmul shared params (#27321) * fix multihead matmul shared params (#27121) * fix multihead matmul shared params --- .../ir/multihead_matmul_fuse_pass.cc | 103 ++++++++++-------- 1 file changed, 56 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index bb375d760ad..05558b265a2 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -615,53 +615,62 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); - fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, - mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, - reshape2_0, reshape2_qkv_out, scale, scale_out); - - std::unordered_set marked_nodes({eltadd0, - eltadd1, - eltadd2, - eltadd1_b, - eltadd2_b, - eltadd0_out, - eltadd1_out, - eltadd2_out, - reshape2_0, - reshape2_1, - reshape2_2, - reshape2_0_out, - reshape2_1_out, - reshape2_2_out, - transpose2_0, - transpose2_1, - transpose2_2, - transpose2_0_out, - transpose2_1_out, - transpose2_2_out, - matmul_qk, - matmul_qk_out, - eltadd_qk, - eltadd_qk_out, - softmax_qk, - softmax_qk_out, - transpose2_qkv, - transpose2_qkv_out, - matmul_qkv, - matmul_qkv_out, - mul0, - mul1, - mul2, - mul0_out, - mul1_out, - mul2_out, - mul1_w, - mul2_w, - reshape2_qkv, - scale}); - // Remove unneeded nodes. - GraphSafeRemoveNodes(graph, marked_nodes); - ++fusion_count; + // If weights or biases in qkv's fc are shared by multiple multihead_matmul + // patterns, we do not support this kind of fusion, this pass will not take + // effect. + bool is_fc_params_shared = + mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 || + mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 || + eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1; + if (!is_fc_params_shared) { + fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, + mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, + eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); + + std::unordered_set marked_nodes({eltadd0, + eltadd1, + eltadd2, + eltadd1_b, + eltadd2_b, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul1_w, + mul2_w, + reshape2_qkv, + scale}); + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + } }; gpd(graph, handler); -- GitLab