diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 1e8349e878781dccc622580f5e80b803e2194dee..57bee20247c9644941f87db48406ef2b097a23fb 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -753,7 +753,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) ->assert_is_op_output("transpose2"); - transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul"); + transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X"); auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk_out_var = @@ -827,7 +827,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) ->assert_is_op_output("transpose2"); transpose2_1_out_var->AsIntermediate()->assert_is_op_input( - "matmul"); // link to matmul qk + "matmul", "Y"); // link to matmul qk // Third path to matmul auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");