From 5fb8c920541bb3fe4accca4407bb73aea495ba63 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Tue, 8 Sep 2020 10:42:15 +0800 Subject: [PATCH] fix multihead matmul shared params (#27121) --- .../fluid/framework/ir/multihead_matmul_fuse_pass.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 40e01c75bb9..198107ea082 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -615,6 +615,16 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + // If weights or biases in qkv's fc are shared by multiple multihead_matmul + // patterns, we do not support this kind of fusion, this pass will not take + // effect. + bool is_fc_params_shared = + mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 || + mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 || + eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1; + if (is_fc_params_shared) { + return; + } fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); -- GitLab