Unverified commit 1fbd4440, authored by Wangzheee, committed by GitHub

[Paddle-Inference]support matmulv2 in multihead (#43269)

* support matmulv2 in multihead
Parent e1a34bc4
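For context: this commit generalizes the TensorRT multihead-matmul fusion pattern from matching a single hard-coded op type to matching any op type in a set, so attention subgraphs built from matmul_v2 (as exported by newer models) hit the same fusion as the legacy mul/matmul graphs. A minimal standalone C++ sketch of that set-based matching idea — FakeNode and its methods are illustrative stand-ins, not Paddle's real PDNode API, which attaches predicates to pattern nodes rather than returning bool:

#include <cassert>
#include <string>
#include <unordered_set>

// Illustrative stand-in for a pattern node.
struct FakeNode {
  std::string op_type;

  // Old style: the pattern accepts exactly one op type ("mul" or "matmul").
  bool assert_is_op(const std::string& type) const { return op_type == type; }

  // New style: the pattern accepts any op type in the given set.
  bool assert_is_ops(const std::unordered_set<std::string>& types) const {
    return types.count(op_type) > 0;
  }
};

int main() {
  const std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
  FakeNode qkv_proj{"matmul_v2"};
  assert(!qkv_proj.assert_is_op("mul"));    // old pattern missed matmul_v2
  assert(qkv_proj.assert_is_ops(mul_ops));  // generalized pattern matches it
  return 0;
}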
@@ -235,16 +235,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
 }
 PDNode* TrtMultiHeadMatmulPattern::operator()() {
+  std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
+  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
   auto* input0 = pattern->NewNode(input0_repr());
-  input0->assert_is_op_input("mul");
+  input0->assert_is_ops_input(mul_ops);
   // First path with scale
-  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
+  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
   auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul0_out_var =
-      pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops);
   decltype(mul0) eltadd0;
   decltype(mul0) eltadd0_b_var;
@@ -277,11 +279,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
   auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
   auto* scale_out_var =
       pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
-  scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
+  scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
-  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk =
+      pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
   auto* matmul_qk_out_var =
-      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
   matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
   auto* eltadd_qk =
@@ -297,12 +300,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
   auto* softmax_qk_out_var =
       pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
-  softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
+  softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
   auto* matmul_qkv =
-      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
+      pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
   auto* matmul_qkv_out_var =
-      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
   matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
   auto* transpose2_qkv =
@@ -315,15 +318,15 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
   auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
                                    ->assert_is_op_output("reshape2");
-  reshape2_qkv_out_var->assert_is_op_input("mul");
+  reshape2_qkv_out_var->assert_is_ops_input(mul_ops);
   // Second path to matmul
-  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
+  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
   auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul1_out_var =
-      pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops);
   decltype(mul1) eltadd1;
   decltype(mul1) eltadd1_b_var;
@@ -350,16 +353,16 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
   auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
                                    ->assert_is_op_output("transpose2");
-  transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
-      "matmul");  // link to matmul qk
+  transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qk
   // Third path to matmul
-  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
+  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
   auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul2_out_var =
-      pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops);
   decltype(mul2) eltadd2;
   decltype(mul2) eltadd2_b_var;
@@ -386,8 +389,8 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
   auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
                                    ->assert_is_op_output("transpose2");
-  transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
-      "matmul");  // link to matmul qkv
+  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qkv
   // Q path
   mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
@@ -734,6 +737,23 @@ TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() {
       .IsType<bool>()
       .End();
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
   AddOpCompat(OpCompat("softmax"))
       .AddInput("X")
       .IsTensor()
@@ -866,7 +886,7 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     auto* mul0_op_desc = mul0->Op();
     // all mul op has same input.
-    if (multihead_op_desc.HasAttr("Input_scale")) {
+    if (mul0_op_desc->HasAttr("Input_scale")) {
       multihead_op_desc.SetAttr("Input_scale",
                                 mul0_op_desc->GetAttr("Input_scale"));
     }
......
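A reading of the last hunk, inferred from the surrounding code rather than stated by the author: multihead_op_desc is the fused op descriptor being assembled, so the old guard multihead_op_desc.HasAttr("Input_scale") tested an attribute that had not been set on the new op yet; checking mul0_op_desc->HasAttr("Input_scale") instead copies the quantization scale whenever the matched mul op actually carries one.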