Unverified · Commit 1fbd4440, authored by Wangzheee, committed by GitHub

[Paddle-Inference]support matmulv2 in multihead (#43269)

* support matmulv2 in multihead
Parent: e1a34bc4
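The diff below generalizes the TensorRT multihead-matmul fusion pattern: node asserts that previously matched only the `mul` op (the Q/K/V projections) or the `matmul` op (the attention matmuls) now accept either spelling through `std::unordered_set`-based asserts (`assert_is_ops`, `assert_is_ops_input`, `assert_is_ops_output`), so graphs exported with `matmul_v2` also fuse. A minimal sketch of the set-based matching idea, assuming nothing about Paddle's PDNode/PDPattern internals beyond what the hunks show (the `matches` helper is hypothetical):

```cpp
#include <string>
#include <unordered_set>

int main() {
  // The two op sets introduced at the top of the pattern:
  const std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
  const std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};

  // A node now matches when its op type is any member of the set,
  // instead of being string-equal to a single op name.
  auto matches = [](const std::unordered_set<std::string>& ops,
                    const std::string& op_type) {
    return ops.count(op_type) > 0;
  };

  // matmul_v2 now satisfies the same pattern nodes as mul/matmul:
  return (matches(mul_ops, "matmul_v2") && matches(matmul_ops, "matmul_v2"))
             ? 0
             : 1;
}
```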
@@ -235,16 +235,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
 }
 
 PDNode* TrtMultiHeadMatmulPattern::operator()() {
+  std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
+  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
   auto* input0 = pattern->NewNode(input0_repr());
-  input0->assert_is_op_input("mul");
+  input0->assert_is_ops_input(mul_ops);
 
   // First path with scale
-  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
+  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
   auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul0_out_var =
-      pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops);
 
   decltype(mul0) eltadd0;
   decltype(mul0) eltadd0_b_var;
@@ -277,11 +279,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
   auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
   auto* scale_out_var =
       pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
-  scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
+  scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
 
-  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk =
+      pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
   auto* matmul_qk_out_var =
-      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops);
   matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
 
   auto* eltadd_qk =
@@ -297,12 +300,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
   auto* softmax_qk_out_var =
       pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
-  softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
+  softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
 
   auto* matmul_qkv =
-      pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
+      pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
   auto* matmul_qkv_out_var =
-      pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
+      pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
   matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
 
   auto* transpose2_qkv =
@@ -315,15 +318,15 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
   auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
                                    ->assert_is_op_output("reshape2");
-  reshape2_qkv_out_var->assert_is_op_input("mul");
+  reshape2_qkv_out_var->assert_is_ops_input(mul_ops);
 
   // Second path to matmul
-  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
+  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
   auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul1_out_var =
-      pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops);
 
   decltype(mul1) eltadd1;
   decltype(mul1) eltadd1_b_var;
@@ -350,16 +353,16 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
   auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
                                    ->assert_is_op_output("transpose2");
-  transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
-      "matmul");  // link to matmul qk
+  transpose2_1_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qk
 
   // Third path to matmul
-  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
+  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
   auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
                          ->AsInput()
-                         ->assert_is_op_input("mul", "Y");
+                         ->assert_is_ops_input(mul_ops, "Y");
   auto* mul2_out_var =
-      pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
+      pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops);
 
   decltype(mul2) eltadd2;
   decltype(mul2) eltadd2_b_var;
@@ -386,8 +389,8 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() {
       pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
   auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
                                    ->assert_is_op_output("transpose2");
-  transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
-      "matmul");  // link to matmul qkv
+  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qkv
 
   // Q path
   mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
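For orientation: the three parallel branches matched above (the mul0/mul1/mul2 paths whose transpose2 outputs "link to matmul qk" and "link to matmul qkv") are the Q, K, and V projection paths, and the fused subgraph computes scaled dot-product attention. A self-contained sketch of that computation, under assumptions the pattern itself does not spell out (a single head, no attention mask, tiny 2x2 float matrices):

```cpp
// softmax(scale * Q * K^T) * V, mirroring the op chain in the pattern:
// scale -> matmul_qk -> softmax -> matmul_qkv.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<float>>;

Mat MatMul(const Mat& a, const Mat& b, bool trans_b = false) {
  const size_t m = a.size(), k = a[0].size();
  const size_t n = trans_b ? b.size() : b[0].size();
  Mat out(m, std::vector<float>(n, 0.f));
  for (size_t i = 0; i < m; ++i)
    for (size_t j = 0; j < n; ++j)
      for (size_t p = 0; p < k; ++p)
        out[i][j] += a[i][p] * (trans_b ? b[j][p] : b[p][j]);
  return out;
}

void SoftmaxRows(Mat* m) {
  for (auto& row : *m) {
    float mx = row[0], sum = 0.f;
    for (float v : row) mx = std::max(mx, v);
    for (float& v : row) { v = std::exp(v - mx); sum += v; }
    for (float& v : row) v /= sum;
  }
}

int main() {
  Mat q{{1, 0}, {0, 1}}, k{{1, 0}, {0, 1}}, v{{1, 2}, {3, 4}};
  const float scale = 1.f / std::sqrt(2.f);  // 1/sqrt(head_dim), head_dim = 2
  for (auto& row : q)
    for (float& x : row) x *= scale;         // the "scale" op (applied to Q)
  Mat qk = MatMul(q, k, /*trans_b=*/true);   // matmul_qk: Q * K^T
  SoftmaxRows(&qk);                          // the "softmax" op
  Mat out = MatMul(qk, v);                   // matmul_qkv: attn * V
  std::printf("%.3f %.3f\n", out[0][0], out[0][1]);
  return 0;
}
```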
@@ -734,6 +737,23 @@ TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() {
       .IsType<bool>()
       .End();
 
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
+
   AddOpCompat(OpCompat("softmax"))
       .AddInput("X")
       .IsTensor()
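The new `AddOpCompat(OpCompat("matmul_v2"))` entry declares the signature the pass accepts for that op: tensor inputs X and Y, tensor output Out, and boolean `trans_x`/`trans_y` attributes. This is needed because the pass's compatibility check presumably rejects any matched subgraph containing an op with no registered compat entry, so without it the pattern changes above would never fire on `matmul_v2` graphs.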
@@ -866,7 +886,7 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     auto* mul0_op_desc = mul0->Op();
 
     // all mul op has same input.
-    if (multihead_op_desc.HasAttr("Input_scale")) {
+    if (mul0_op_desc->HasAttr("Input_scale")) {
       multihead_op_desc.SetAttr("Input_scale",
                                 mul0_op_desc->GetAttr("Input_scale"));
     }
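This last hunk is a small bug fix: the `Input_scale` check previously queried `multihead_op_desc`, the fused op being assembled, rather than `mul0_op_desc`, the matched source op that actually carries the quantization scale. The condition now tests the op the attribute is copied from, so the scale is propagated whenever the source op has one.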
...