From 0f59d4e6d25de756d9e69659149418323500a058 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com>
Date: Mon, 28 Jun 2021 10:19:17 +0800
Subject: [PATCH] add compat precondition for multihead_matmul_fuse_pass_v2,v3, test=develop (#33786)

---
 .../ir/multihead_matmul_fuse_pass.cc          | 669 ++++++++++++------
 .../framework/ir/multihead_matmul_fuse_pass.h |  22 +-
 .../ir/multihead_matmul_fuse_pass_tester.cc   |  38 +-
 .../fluid/framework/ir/pass_tester_helper.h   |   6 +-
 paddle/fluid/operators/compat/matmul.pbtxt    |   4 +
 paddle/fluid/operators/compat/softmax.pbtxt   |   4 +-
 6 files changed, 498 insertions(+), 245 deletions(-)

diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
index 57bee20247..5a97727da3 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -422,13 +422,335 @@ PDNode* MultiHeadMatmulPattern::operator()() {
   return transpose2_2_out_var;
 }
 
-static int BuildFusionV2(Graph* graph, const std::string& name_scope,
-                         Scope* scope) {
+PDNode* MultiHeadMatmulV3Pattern::operator()() {
+  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
+  auto* input0 = pattern->NewNode(input0_repr());
+  input0->assert_is_op_input("matmul");
+
+  // First path with scale
+  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
+  auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
+                         ->AsInput()
+                         ->assert_is_op_input("matmul", "Y");
+  auto* mul0_out_var =
+      pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
+
+  decltype(mul0) eltadd0;
+  decltype(mul0) eltadd0_b_var;
+  decltype(mul0) eltadd0_out_var;
+
+  mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
+  eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
+                      ->AsInput()
+                      ->assert_is_op_input("elementwise_add", "Y");
+
+  eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
+                        ->assert_is_op_output("elementwise_add");
+  eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_0 =
+      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
+
+  auto* reshape2_0_out_var =
+      pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
+  reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_0 =
+      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
+  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
+                                   ->assert_is_op_output("transpose2");
+  transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
+
+  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk_out_var =
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_qk =
+      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
+  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
+                              ->AsInput()
+                              ->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
+                                ->assert_is_op_output("elementwise_add");
+  eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
+
+  auto* softmax_qk =
+      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
+  auto* softmax_qk_out_var = 
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + reshape2_qkv_out_var->assert_is_op_input("matmul"); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul"); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul", "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul"); + + decltype(mul1) eltadd1; + decltype(mul1) eltadd1_b_var; + decltype(mul1) eltadd1_out_var; + + mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate()->assert_is_op_input( + "matmul", "Y"); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul"); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_op_input("matmul", "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul"); + + decltype(mul2) eltadd2; + decltype(mul2) eltadd2_b_var; + decltype(mul2) eltadd2_out_var; + + mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + 
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
+  auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
+                                   ->assert_is_op_output("transpose2");
+  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
+      matmul_ops);  // link to matmul qkv
+
+  // Q path
+  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
+  eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
+
+  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
+  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
+  // K path
+  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
+  eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
+  reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
+  transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
+  // compute q*k
+  matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
+      .LinksTo({matmul_qk_out_var});
+  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
+      .LinksTo({eltadd_qk_out_var});
+  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
+  // V path
+  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
+  eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
+  reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
+  transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
+  // compute q*k*v
+  matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
+      .LinksTo({matmul_qkv_out_var});
+  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
+      .LinksTo({transpose2_qkv_out_var});
+  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
+      .LinksTo({reshape2_qkv_out_var});
+
+  return transpose2_2_out_var;
+}
+}  // namespace patterns
+
+void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+
+  int fusion_count = patterns::BuildFusion(graph, name_scope_);
+  AddStatis(fusion_count);
+}
+
+MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddInput("Y")  // the shape should be (N*H, N*H)
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(2)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      // in bias, shape is (B, S, N*H),
+      // in biasqk, shape is (B, H, S, S)
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      // in bias, shape is (N*H)
+      // in biasqk, shape is (B, H, S, S)
+      .IsTensor()
+      .End()
+      // in bias, shape is (B, S, N*H)
+      // in biasqk, shape is (B, H, S, S)
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // in bias, it equals 2
+      // in biasqk, it equals -1 or 0
+      .AddAttr("axis")
+      .IsIntIn({2, -1, 0})
+      .End();
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Shape")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
+      .IsType<std::vector<int>>()
+      .End();
+
+  // -->: (B, S, H, N) -> (B, H, S, N)
+  // <--: (B, H, S, N) -> (B, S, H, N)
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")  // {0, 2, 1, 3}
+      .IsType<std::vector<int>>()
+      .End();
+
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("scale")
+      .IsType<float>()  // copied to the new op, so unconstrained
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.f)
+      .End()
+      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained
+      .IsType<bool>()
+      .End();
+
+  // QK  (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
+  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumEQ(1.0f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")  // QK(true) QKV(false)
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
+      .End();
+}
+
+int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
+                                             const std::string& name_scope,
+                                             Scope* scope) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
   // Create pattern.
-  MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
+  patterns::MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
   multihead_pattern();
 
   // Create New OpDesc
@@ -580,6 +902,11 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   int fusion_count{0};
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING)
+          << "Op compat check in multihead_matmul_fuse_pass_v2 failed.";
+      return;
+    }
     // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
 
@@ -714,197 +1041,141 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
 
-PDNode* MultiHeadMatmulV3Pattern::operator()() {
-  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
-  auto* input0 = pattern->NewNode(input0_repr());
-  input0->assert_is_op_input("matmul");
-
-  // First path with scale
-  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
-  auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
-                         ->AsInput()
-                         ->assert_is_op_input("matmul", "Y");
-  auto* mul0_out_var =
-      pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
-
-  decltype(mul0) eltadd0;
-  decltype(mul0) eltadd0_b_var;
-  decltype(mul0) eltadd0_out_var;
-
-  mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
-
-  eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
-  eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
-                      ->AsInput()
-                      ->assert_is_op_input("elementwise_add", "Y");
-
-  eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
-                        ->assert_is_op_output("elementwise_add");
-  eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
-
-  auto* reshape2_0 =
-      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
-
-  auto* reshape2_0_out_var =
-      pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
-  reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
-
-  auto* transpose2_0 =
-      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
-  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
-                                   ->assert_is_op_output("transpose2");
-  
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X"); - - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); - auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); - matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); - - auto* eltadd_qk = - pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); - auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) - ->AsInput() - ->assert_is_op_input("elementwise_add", "Y"); - auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) - ->assert_is_op_output("elementwise_add"); - eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax"); - - auto* softmax_qk = - pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); - auto* softmax_qk_out_var = - pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); - softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); - - auto* matmul_qkv = - pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); - auto* matmul_qkv_out_var = - pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); - matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); - - auto* transpose2_qkv = - pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); - auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) - ->assert_is_op_output("transpose2"); - transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); - - auto* reshape2_qkv = - pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); - auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) - ->assert_is_op_output("reshape2"); - reshape2_qkv_out_var->assert_is_op_input("matmul"); - - // Second path to matmul - auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul"); - auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) - ->AsInput() - ->assert_is_op_input("matmul", "Y"); - auto* mul1_out_var = - pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul"); - - decltype(mul1) eltadd1; - decltype(mul1) eltadd1_b_var; - decltype(mul1) eltadd1_out_var; - - mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); - eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); - eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) - ->AsInput() - ->assert_is_op_input("elementwise_add", "Y"); - - eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) - ->assert_is_op_output("elementwise_add"); - eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); - - auto* reshape2_1 = - pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); - - auto* reshape2_1_out_var = - pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); - reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); - - auto* transpose2_1 = - pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); - auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) - ->assert_is_op_output("transpose2"); - transpose2_1_out_var->AsIntermediate()->assert_is_op_input( - "matmul", "Y"); // link to matmul qk - - // Third path to matmul - auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul"); - auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) - ->AsInput() - ->assert_is_op_input("matmul", "Y"); - auto* mul2_out_var = - pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul"); - - decltype(mul2) eltadd2; - 
decltype(mul2) eltadd2_b_var;
-  decltype(mul2) eltadd2_out_var;
-
-  mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
-  eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
-  eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
-                      ->AsInput()
-                      ->assert_is_op_input("elementwise_add", "Y");
-
-  eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
-                        ->assert_is_op_output("elementwise_add");
-  eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
-
-  auto* reshape2_2 =
-      pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
-
-  auto* reshape2_2_out_var =
-      pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
-  reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
-
-  auto* transpose2_2 =
-      pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
-  auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
-                                   ->assert_is_op_output("transpose2");
-  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
-      matmul_ops);  // link to matmul qkv
-
-  // Q path
-  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
-  eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
+void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  auto* scope = param_scope();
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::Fatal(
+          "During the multiheadMatmul pass, the scope should not be null."));
 
-  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
-  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
-  // K path
-  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
-  eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
-  reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
-  transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
-  // compute q*k
-  matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
-      .LinksTo({matmul_qk_out_var});
-  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
-      .LinksTo({eltadd_qk_out_var});
-  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
-  // V path
-  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
-  eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
-  reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
-  transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
-  // compute q*k*v
-  matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
-      .LinksTo({matmul_qkv_out_var});
-  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
-      .LinksTo({transpose2_qkv_out_var});
-  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
-      .LinksTo({reshape2_qkv_out_var});
+  int fusion_count = BuildFusionV2(graph, name_scope_, scope);
+  if (fusion_count > 0) {
+    graph->Set(kMultiheadMatmulPass, new bool(true));
+  }
+  AddStatis(fusion_count);
+}
 
-  return transpose2_2_out_var;
+MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddInput("Y")  // the shape should be (N*H, N*H)
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(2)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      // in bias, shape is (B, S, N*H),
+      // in biasqk, shape is (B, H, S, S)
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      // in bias, shape is (N*H)
+      // in biasqk, shape is (B, H, S, S)
+      .IsTensor()
+      .End()
+      // in bias, shape is (B, S, N*H)
+      // in biasqk, shape is (B, H, S, S)
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // in bias, it equals 2
+      // in biasqk, it equals -1 or 0
+      .AddAttr("axis")
+      .IsIntIn({2, -1, 0})
+      .End();
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Shape")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
+      .IsType<std::vector<int>>()
+      .End();
+
+  // -->: (B, S, H, N) -> (B, H, S, N)
+  // <--: (B, H, S, N) -> (B, S, H, N)
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")  // {0, 2, 1, 3}
+      .IsType<std::vector<int>>()
+      .End();
+
+  // QK  (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
+  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()  // QK (any value; copied to the new op), QKV (1.0)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")  // QK(true) QKV(false)
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
+      .End();
 }
 
-static int BuildFusionV3(Graph* graph, const std::string& name_scope,
-                         Scope* scope) {
+int MultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
+                                             const std::string& name_scope,
+                                             Scope* scope) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
   // Create pattern.
-  MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
+  patterns::MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
   multihead_pattern();
 
   // Create New OpDesc
@@ -1155,30 +1426,6 @@ static int BuildFusionV3(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
 
-}  // namespace patterns
-
-void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
-  FusePassBase::Init(name_scope_, graph);
-
-  int fusion_count = patterns::BuildFusion(graph, name_scope_);
-  AddStatis(fusion_count);
-}
-
-void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
-  FusePassBase::Init(name_scope_, graph);
-  auto* scope = param_scope();
-  PADDLE_ENFORCE_NOT_NULL(
-      scope,
-      platform::errors::Fatal(
-          "During the multiheadMatmul pass, the scope should not be null."));
-
-  int fusion_count = patterns::BuildFusionV2(graph, name_scope_, scope);
-  if (fusion_count > 0) {
-    graph->Set(kMultiheadMatmulPass, new bool(true));
-  }
-  AddStatis(fusion_count);
-}
-
 void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
@@ -1187,7 +1434,7 @@ void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
     platform::errors::Fatal(
         "During the multiheadMatmul pass, the scope should not be null."));
 
-  int fusion_count = patterns::BuildFusionV3(graph, name_scope_, scope);
+  int fusion_count = BuildFusionV3(graph, name_scope_, scope);
   if (fusion_count > 0) {
     graph->Set(kMultiheadMatmulPass, new bool(true));
   }
diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h
index c7f1336211..c39823e732 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h
@@ -18,16 +18,6 @@
 #include <string>
 
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-class Graph;
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
 
 namespace paddle {
 namespace framework {
@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase {
 class MultiHeadMatmulV2FusePass : public FusePassBase {
  public:
-  virtual ~MultiHeadMatmulV2FusePass() {}
+  MultiHeadMatmulV2FusePass();
 
  protected:
   void ApplyImpl(Graph* graph) const;
 
   const std::string name_scope_{"multihead_matmul_fuse_v2"};
+
+ private:
+  int BuildFusionV2(Graph* graph, const std::string& name_scope,
+                    Scope* scope) const;
 };
 
 class MultiHeadMatmulV3FusePass : public FusePassBase {
  public:
-  virtual ~MultiHeadMatmulV3FusePass() {}
+  MultiHeadMatmulV3FusePass();
 
  protected:
   void ApplyImpl(Graph* graph) const;
 
   const std::string name_scope_{"multihead_matmul_fuse_v3"};
+
+ private:
+  int BuildFusionV3(Graph* graph, const std::string& name_scope,
+                    Scope* scope) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
index 2eda643d4e..b121436ee8 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) {
   // (transpose_qkv) reshape -> reshape_qkv
   // (reshape_qkv)   mul -> mul_qkv
   Layers layers;
-  auto* x = layers.data("x", {128, 768});
+  auto* x = layers.data("x", {1, 128, 768});
   auto out = layers.layer_norm(x);
   auto* layer_out = out[0];
 
@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) {
   auto* weights_1 = layers.data("weights1", {768, 768}, true);
   auto* weights_2 = layers.data("weights2", {768, 768}, true);
 
-  auto* mul_out_0 = layers.mul(layer_out, weights_0);
-  auto* mul_out_1 = layers.mul(layer_out, weights_1);
-  auto* mul_out_2 = layers.mul(layer_out, weights_2);
+  auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2);
+  auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2);
+  auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2);
 
   auto* b0 = layers.data("bias_0", {768}, true);
   auto* b1 = layers.data("bias_1", {768}, true);
   auto* b2 = layers.data("bias_2", {768}, true);
 
-  auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0);
-  auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1);
-  auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2);
+  auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2);
+  auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2);
+  auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2);
 
-  std::vector<int> shape = {128, 12, 64};
-  auto* reshape_0 = layers.reshape2(elementwise_out_0, shape);
-  auto* reshape_1 = layers.reshape2(elementwise_out_1, shape);
-  auto* reshape_2 = layers.reshape2(elementwise_out_2, shape);
+  std::vector<int> shape = {1, 128, 12, 64};
+  auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
+  auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
+  auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
 
   std::vector<int> axis = {0, 2, 1, 3};
-  auto* transpose_0 = layers.transpose2(reshape_0, axis);
-  auto* transpose_1 = layers.transpose2(reshape_1, axis);
-  auto* transpose_2 = layers.transpose2(reshape_2, axis);
+  auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
+  auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
+  auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
 
   auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
-  auto* matmul_qk = layers.matmul(scale_0, transpose_1);
+  auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true);
 
-  auto* bqk = layers.data("biasqk", {768}, true);
+  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
 
   auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);
 
-  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3});
-  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768});
+  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
+  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true);
 
   auto* weights_l = layers.data("weightsl", {768, 768}, true);
-  layers.mul(reshape_qkv_out, weights_l);
+  layers.mul(reshape_qkv_out, weights_l, nullptr, 2);
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index f5639e7bc9..284e54b3cb 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -293,13 +293,17 @@ struct Layers {
     return outs;
   }
 
-  VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) {
+  VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
+                  bool transpose_x = false, bool transpose_y = false) 
{ VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("matmul"); op->SetInput("X", {x->Name()}); op->SetInput("Y", {y->Name()}); op->SetOutput("Out", {out->Name()}); + op->SetAttr("transpose_X", transpose_x); + op->SetAttr("transpose_Y", transpose_y); + op->SetAttr("alpha", 1.0f); return out; } diff --git a/paddle/fluid/operators/compat/matmul.pbtxt b/paddle/fluid/operators/compat/matmul.pbtxt index e68a7f31b6..8f29d93660 100644 --- a/paddle/fluid/operators/compat/matmul.pbtxt +++ b/paddle/fluid/operators/compat/matmul.pbtxt @@ -23,6 +23,10 @@ def { } } extra { + attrs { + name: "head_number" + type: INT + } attrs { name: "Scale_out" type: FLOAT diff --git a/paddle/fluid/operators/compat/softmax.pbtxt b/paddle/fluid/operators/compat/softmax.pbtxt index 5cd155ed1c..04f15ace15 100644 --- a/paddle/fluid/operators/compat/softmax.pbtxt +++ b/paddle/fluid/operators/compat/softmax.pbtxt @@ -10,12 +10,12 @@ def { name: "axis" type: INT } +} +extra { attrs { name: "data_format" type: STRING } -} -extra { attrs { name: "op_role" type: INT -- GitLab
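Review note: every guard this patch adds follows the same two-step recipe. First, the pass constructor declares per-op attribute and I/O constraints with `AddOpCompat(OpCompat(...))`; second, the GraphPatternDetector handler calls `IsCompat(subgraph, g)` before rewriting and skips any candidate that violates the declared constraints. The sketch below shows that shape in isolation. It is illustrative only: the pass name `ExampleFusePass`, the init name `example_fuse`, and the single softmax rule are hypothetical placeholders, not code from this patch; the real rules live in the V2/V3 constructors above.

```cpp
// Illustrative sketch of an OpCompat-guarded fuse pass (hypothetical names;
// the real constraints for this patch are in the constructors above).
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class ExampleFusePass : public FusePassBase {
 public:
  ExampleFusePass() {
    // Step 1: declare the constraints that matched ops must satisfy.
    AddOpCompat(OpCompat("softmax"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({-1, 3})
        .End();
  }

 protected:
  void ApplyImpl(Graph* graph) const {
    FusePassBase::Init("example_fuse", graph);
    GraphPatternDetector gpd;
    // ... build the pattern on gpd.mutable_pattern() here ...
    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
      // Step 2: reject candidates that violate the declared constraints,
      // exactly as the BuildFusionV2/BuildFusionV3 handlers now do.
      if (!IsCompat(subgraph, g)) {
        LOG(WARNING) << "Op compat check in example_fuse failed.";
        return;
      }
      // ... rewrite the matched subgraph ...
    };
    gpd(graph, handler);
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```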