未验证 提交 8ccf549b 编写于 作者: P Pei Yang 提交者: GitHub

specify multihead_matmul_fuse_pass_v3 QK path (#32659)

上级 f46f15a0
......@@ -753,7 +753,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
......@@ -827,7 +827,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk
"matmul", "Y"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册