diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 40e01c75bb99157aedccd0692d7410b99393c009..198107ea082dc86d9e65a926bf9befe2fc4abfa4 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -615,6 +615,16 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + // If weights or biases in qkv's fc are shared by multiple multihead_matmul + // patterns, we do not support this kind of fusion, this pass will not take + // effect. + bool is_fc_params_shared = + mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 || + mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 || + eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1; + if (is_fc_params_shared) { + return; + } fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);