From 8d4450db4c36be8b0178479d6f606c540b38592e Mon Sep 17 00:00:00 2001 From: RichardWooSJTU <37864677+RichardWooSJTU@users.noreply.github.com> Date: Mon, 12 Dec 2022 17:41:36 +0800 Subject: [PATCH] update fused_multi_transformer_encoder_pass support GPT new matmul API (#48953) * fit paddle.matmul in fleetx.gpt --- .../fused_multi_transformer_decoder_pass.cc | 86 +++++++++++++++--- .../ir/fused_multi_transformer_decoder_pass.h | 4 + ...d_multi_transformer_decoder_pass_tester.cc | 30 ++++--- .../fused_multi_transformer_encoder_pass.cc | 88 ++++++++++++++++--- .../ir/fused_multi_transformer_encoder_pass.h | 4 + ...d_multi_transformer_encoder_pass_tester.cc | 33 +++---- 6 files changed, 193 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index bc1a2dd0ed..b5d0661ae7 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -478,7 +478,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) ->assert_is_op_output("split") ->AsIntermediate() - ->assert_is_op_input("matmul", "X"); + ->assert_is_op_input("matmul_v2", "X"); auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) ->assert_is_op_output("split") ->AsIntermediate() @@ -496,7 +496,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr()) ->assert_is_op_output("concat") ->AsIntermediate() - ->assert_is_op_input("matmul") + ->assert_is_op_input("matmul_v2") ->assert_is_op_input("assign"); auto* concat_v_in_var = pattern ->NewNode(concat_v_in_repr()) @@ -529,10 +529,16 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { assign_v->LinksFrom({concat_v_out_var}); // QK path Nodes - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); - matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); + auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "X"); auto* eltadd_qk = pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); @@ -554,7 +560,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() { // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) .LinksTo({matmul_qk_out_var}); - eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); + eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); @@ -799,7 +806,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) ->assert_is_op_output("split") ->AsIntermediate() - ->assert_is_op_input("matmul", "X"); + ->assert_is_op_input("matmul_v2", "X"); auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) ->assert_is_op_output("split") ->AsIntermediate() @@ -817,7 +824,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr()) ->assert_is_op_output("concat") ->AsIntermediate() - ->assert_is_op_input("matmul") + ->assert_is_op_input("matmul_v2") ->assert_is_op_input("assign"); auto* concat_v_in_var = pattern ->NewNode(concat_v_in_repr()) @@ -850,10 +857,16 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { assign_v->LinksFrom({concat_v_out_var}); // QK path Nodes - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); - matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); + auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "X"); auto* eltadd_qk = pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); @@ -875,7 +888,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var}) .LinksTo({matmul_qk_out_var}); - eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); + eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); @@ -2192,6 +2206,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH( matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( @@ -2296,6 +2315,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( assign_v, matmul_qk, matmul_qk_out, + scale_qk, + scale_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, @@ -2382,6 +2403,23 @@ FusedMultiTransformerDecoderFuseQKVPass:: .IsNumGT(0) .End(); + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() + .End(); + AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") // the shape shoule be (B, S, N*H) .IsTensor() @@ -2917,6 +2955,11 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH( matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( @@ -3031,6 +3074,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( assign_v, matmul_qk, matmul_qk_out, + scale_qk, + scale_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, @@ -3124,6 +3169,23 @@ MultiDevicesFusedMultiTransformerDecoderFuseQKVPass:: .IsNumGT(0) .End(); + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() + .End(); + AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") // the shape shoule be (B, S, N*H) .IsTensor() diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h index fd2cfc8c66..bfdc38c708 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h @@ -182,6 +182,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { // Q, K matmul PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(scale_qk); + PATTERN_DECL_NODE(scale_qk_out); PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_out); @@ -282,6 +284,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern // Q, K matmul PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(scale_qk); + PATTERN_DECL_NODE(scale_qk_out); PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_out); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc index 2e54196e59..73026660c8 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc @@ -239,12 +239,13 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { // (eltadd_0) reshape2 -> reshape_0 // (reshape_0) transpose2 -> transpose_0 // (transpose_0) split -> split_q, split_k, - // split_v (split_k) concat -> concat_k + // split_v (split_k) concat -> concat_k // (split_v) concat -> concat_v // (concat_k) assign -> assign_k // (concat_v) assign -> assign_v - // (split_q, split_k) matmul -> matmul_qk - // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (split_q, split_k) matmul_v2 -> matmul_qk + // (matmul_qk) scale -> scale_qk + // (scale_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv @@ -298,10 +299,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { layers.assign(concat_v); // MHA: QK matmul - auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true); + auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true); + auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); - auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); // MHA: QKV matmul @@ -361,11 +363,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 50, + num_nodes_after + 52, platform::errors::InvalidArgument( "After the fused_multi_transformer_decoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 50, + num_nodes_before - 52, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -396,8 +398,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { // (split_v) concat -> concat_v // (concat_k) assign -> assign_k // (concat_v) assign -> assign_v - // (split_q, split_k) matmul -> matmul_qk - // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (split_q, split_k) matmul_v2 -> matmul_qk + // (matmul_qk) scale -> scale_qk + // (scale_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv @@ -455,10 +458,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { layers.assign(concat_v); // MHA: QK matmul - auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true); + auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true); + auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); - auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); // MHA: QKV matmul @@ -523,11 +527,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 58, + num_nodes_after + 60, platform::errors::InvalidArgument( "After the fused_multi_transformer_decoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 58, + num_nodes_before - 60, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index 6f0ef5b755..b7a723d813 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -472,11 +472,11 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) ->assert_is_op_output("split") ->AsIntermediate() - ->assert_is_op_input("matmul", "X"); + ->assert_is_op_input("matmul_v2", "X"); auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) ->assert_is_op_output("split") ->AsOutput() - ->assert_is_op_input("matmul", "Y") + ->assert_is_op_input("matmul_v2", "Y") ->assert_is_op_input("while"); auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) ->assert_is_op_output("split") @@ -499,10 +499,17 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { while0->LinksFrom({split0_k_out_var, split0_v_out_var}); // QK path Nodes - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); - matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); + auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "X"); auto* eltadd_qk = pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); @@ -524,7 +531,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) .LinksTo({matmul_qk_out_var}); - eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); + eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); @@ -769,11 +777,11 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr()) ->assert_is_op_output("split") ->AsIntermediate() - ->assert_is_op_input("matmul", "X"); + ->assert_is_op_input("matmul_v2", "X"); auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr()) ->assert_is_op_output("split") ->AsOutput() - ->assert_is_op_input("matmul", "Y") + ->assert_is_op_input("matmul_v2", "Y") ->assert_is_op_input("while"); auto* split0_v_out_var = pattern->NewNode(split0_v_out_repr()) ->assert_is_op_output("split") @@ -796,10 +804,17 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { while0->LinksFrom({split0_k_out_var, split0_v_out_var}); // QK path Nodes - auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2"); auto* matmul_qk_out_var = - pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); - matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale"); + auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr()) + ->assert_is_op_output("scale") + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "X"); auto* eltadd_qk = pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); @@ -821,7 +836,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { // QK path Linsk matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) .LinksTo({matmul_qk_out_var}); - eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var}); + eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var}) .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); @@ -2637,6 +2653,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH( matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( @@ -2739,6 +2760,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( split0_v_out, matmul_qk, matmul_qk_out, + scale_qk, + scale_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, @@ -2826,6 +2849,23 @@ FusedMultiTransformerEncoderFuseQKVPass:: .IsNumGT(0) .End(); + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() + .End(); + AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") // the shape shoule be (B, S, N*H) .IsTensor() @@ -3468,6 +3508,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH( matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern); + GET_IR_NODE_FROM_SUBGRAPH( eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern); GET_IR_NODE_FROM_SUBGRAPH( @@ -3580,6 +3625,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( split0_v_out, matmul_qk, matmul_qk_out, + scale_qk, + scale_qk_out, eltadd_qk, eltadd_qk_out, softmax_qk, @@ -3675,6 +3722,23 @@ MultiDevicesFusedMultiTransformerEncoderFuseQKVPass:: .IsNumGT(0) .End(); + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsType() // copy to new op. so unconstrained. + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() + .End(); + AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") // the shape shoule be (B, S, N*H) .IsTensor() diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h index 55792456b8..ae7f0e9761 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.h @@ -168,6 +168,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { // Q, K matmul PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(scale_qk); + PATTERN_DECL_NODE(scale_qk_out); PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_out); @@ -263,6 +265,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern // Q, K matmul PATTERN_DECL_NODE(matmul_qk); PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(scale_qk); + PATTERN_DECL_NODE(scale_qk_out); PATTERN_DECL_NODE(eltadd_qk); PATTERN_DECL_NODE(eltadd_qk_b); PATTERN_DECL_NODE(eltadd_qk_out); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc index 08f4dc06f5..5542a802b2 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc @@ -234,11 +234,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { // (eltadd_0) reshape2 -> reshape_0 // (reshape_0) transpose2 -> transpose_0 // (transpose_0) split -> split_q, split_k, - // split_v (split_k) assign -> assign_k + // split_v (split_k) assign -> assign_k // (split_v) assign -> assign_v - // (split_q, split_k) matmul -> matmul_qk - // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk - // (eltadd_qk) softmax -> softmax_qk + // (split_q, split_k) matmul_v2 -> matmul_qk + // (matmul_qk) scale -> scale_qk + // (scale_qk, eltadd_qk) softmax -> softmax_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv // (transpose_qkv) reshape -> reshape_qkv @@ -289,10 +289,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { layers.while_loop({split_k, split_v}); // MHA: QK matmul - auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true); + auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true); + auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); - auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); // MHA: QKV matmul @@ -352,11 +353,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 44, + num_nodes_after + 46, platform::errors::InvalidArgument( "After the fused_multi_transformer_encoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 44, + num_nodes_before - 46, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, @@ -383,10 +384,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { // (eltadd_0) reshape2 -> reshape_0 // (reshape_0) transpose2 -> transpose_0 // (transpose_0) split -> split_q, split_k, - // split_v (split_k) assign -> assign_k + // split_v (split_k) assign -> assign_k // (split_v) assign -> assign_v - // (split_q, split_k) matmul -> matmul_qk - // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (split_q, split_k) matmul_v2 -> matmul_qk + // (matmul_qk) scale -> scale_qk + // (scale_qk, bias_qk) elementwise_add -> eltadd_qk // (eltadd_qk) softmax -> softmax_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv // (matmul_qkv) transpose -> transpose_qkv @@ -442,10 +444,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { layers.while_loop({split_k, split_v}); // MHA: QK matmul - auto* matmul_qk = layers.matmul(split_q, split_k, nullptr, false, true); + auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true); + auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); - auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk); auto* softmax_qk = layers.softmax(elementwise_qk, -1); // MHA: QKV matmul @@ -510,11 +513,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { PADDLE_ENFORCE_EQ( num_nodes_before, - num_nodes_after + 52, + num_nodes_after + 54, platform::errors::InvalidArgument( "After the fused_multi_transformer_encoder_fuse_qkv_pass, " "The node num in graph should be %d, but the result is %d", - num_nodes_before - 52, + num_nodes_before - 54, num_nodes_after)); PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, -- GitLab