From 1fbd44403fe09e3d9c57d453cf5d99a94a116f51 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 8 Jun 2022 10:42:47 +0800 Subject: [PATCH] [Paddle-Inference]support matmulv2 in multihead (#43269) * support matmulv2 in multihead --- .../ir/trt_multihead_matmul_fuse_pass.cc | 64 ++++++++++++------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc index 2e3e957fd15..8fff2f953c3 100644 --- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc @@ -235,16 +235,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { } PDNode* TrtMultiHeadMatmulPattern::operator()() { + std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"}; + std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"}; auto* input0 = pattern->NewNode(input0_repr()); - input0->assert_is_op_input("mul"); + input0->assert_is_ops_input(mul_ops); // First path with scale - auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops); auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) ->AsInput() - ->assert_is_op_input("mul", "Y"); + ->assert_is_ops_input(mul_ops, "Y"); auto* mul0_out_var = - pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul"); + pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops); decltype(mul0) eltadd0; decltype(mul0) eltadd0_b_var; @@ -277,11 +279,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() { auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); auto* scale_out_var = pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); - scale_out_var->AsIntermediate()->assert_is_op_input("matmul"); + scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); 
+ auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops); matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); auto* eltadd_qk = @@ -297,12 +300,12 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() { pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); - softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul"); + softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); auto* matmul_qkv = - pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul"); + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); auto* matmul_qkv_out_var = - pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul"); + pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); auto* transpose2_qkv = @@ -315,15 +318,15 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() { pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) ->assert_is_op_output("reshape2"); - reshape2_qkv_out_var->assert_is_op_input("mul"); + reshape2_qkv_out_var->assert_is_ops_input(mul_ops); // Second path to matmul - auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul"); + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops); auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) ->AsInput() - ->assert_is_op_input("mul", "Y"); + ->assert_is_ops_input(mul_ops, "Y"); auto* mul1_out_var = - pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul"); + pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops); decltype(mul1) eltadd1; decltype(mul1) 
eltadd1_b_var; @@ -350,16 +353,16 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() { pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) ->assert_is_op_output("transpose2"); - transpose2_1_out_var->AsIntermediate()->assert_is_op_input( - "matmul"); // link to matmul qk + transpose2_1_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops); // link to matmul qk // Third path to matmul - auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul"); + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops); auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) ->AsInput() - ->assert_is_op_input("mul", "Y"); + ->assert_is_ops_input(mul_ops, "Y"); auto* mul2_out_var = - pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul"); + pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops); decltype(mul2) eltadd2; decltype(mul2) eltadd2_b_var; @@ -386,8 +389,8 @@ PDNode* TrtMultiHeadMatmulPattern::operator()() { pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) ->assert_is_op_output("transpose2"); - transpose2_2_out_var->AsIntermediate()->assert_is_op_input( - "matmul"); // link to matmul qkv + transpose2_2_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops); // link to matmul qkv // Q path mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); @@ -734,6 +737,23 @@ TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() { .IsType<bool>() .End(); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType<bool>() + .End() + .AddAttr("trans_y") + .IsType<bool>() + .End(); + AddOpCompat(OpCompat("softmax")) .AddInput("X") .IsTensor() .End() @@ -866,7 +886,7 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, auto* mul0_op_desc = 
mul0->Op(); // all mul op has same input. - if (multihead_op_desc.HasAttr("Input_scale")) { + if (mul0_op_desc->HasAttr("Input_scale")) { multihead_op_desc.SetAttr("Input_scale", mul0_op_desc->GetAttr("Input_scale")); } -- GitLab