From 8ccf549be194acbee4e01d3530b1b7439629ba07 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Thu, 29 Apr 2021 10:39:58 +0800 Subject: [PATCH] specify multihead_matmul_fuse_pass_v3 QK path (#32659) --- paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 1e8349e8787..57bee20247c 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -753,7 +753,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) ->assert_is_op_output("transpose2"); - transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul"); + transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X"); auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk_out_var = @@ -827,7 +827,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) ->assert_is_op_output("transpose2"); transpose2_1_out_var->AsIntermediate()->assert_is_op_input( - "matmul"); // link to matmul qk + "matmul", "Y"); // link to matmul qk // Third path to matmul auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul"); -- GitLab