From 30dfa745c7613604fc3073de90fe4abefcb2cef7 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Thu, 29 Apr 2021 14:49:37 +0800 Subject: [PATCH] specify multihead_matmul_fuse_pass_v3 QK path (#32659) (#32668) --- paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 1e8349e878..57bee20247 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -753,7 +753,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) ->assert_is_op_output("transpose2"); - transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul"); + transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X"); auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk_out_var = @@ -827,7 +827,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) ->assert_is_op_output("transpose2"); transpose2_1_out_var->AsIntermediate()->assert_is_op_input( - "matmul"); // link to matmul qk + "matmul", "Y"); // link to matmul qk // Third path to matmul auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul"); -- GitLab