未验证 提交 30dfa745 编写于 作者: P Pei Yang 提交者: GitHub

specify multihead_matmul_fuse_pass_v3 QK path (#32659) (#32668)

上级 ef7b6d55
...@@ -753,7 +753,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { ...@@ -753,7 +753,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2"); ->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul"); transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var = auto* matmul_qk_out_var =
...@@ -827,7 +827,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { ...@@ -827,7 +827,7 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2"); ->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input( transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk "matmul", "Y"); // link to matmul qk
// Third path to matmul // Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul"); auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册