From 178b2440648856547228246d0a91d293e66cb952 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 23 Jun 2022 13:18:52 +0800 Subject: [PATCH] general_prelayernorm_transformer (#43748) --- ...n_embedding_eltwise_layernorm_fuse_pass.cc | 71 ++-- ...ln_embedding_eltwise_layernorm_fuse_pass.h | 4 +- .../ir/preln_skip_layernorm_fuse_pass.cc | 30 +- .../ir/remove_padding_recover_padding_pass.cc | 177 ++++++++- .../ir/remove_padding_recover_padding_pass.h | 23 ++ .../ir/trt_multihead_matmul_fuse_pass.cc | 342 ++++++++++++------ .../ir/trt_skip_layernorm_fuse_pass.cc | 14 +- .../convert/preln_emb_eltwise_layernorm.cc | 49 ++- 8 files changed, 511 insertions(+), 199 deletions(-) diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc index 929ffa2cadb..ddcde5014a4 100644 --- a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc @@ -31,7 +31,8 @@ namespace framework { namespace ir { namespace patterns { -static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name, +static PDNode* create_emb_vars(PDPattern* pattern, + const std::string& name, const std::string& arg, bool is_persist = false) { std::unordered_set embedding_ops{"lookup_table", @@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name, if (is_persist) return node->assert_is_persistable_var(); return node; } -static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name, +static PDNode* create_emb_out_vars(PDPattern* pattern, + const std::string& name, const std::string& arg) { std::unordered_set embedding_ops{"lookup_table", "lookup_table_v2"}; @@ -62,6 +64,9 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() { create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); std::unordered_set embedding_ops{"lookup_table", "lookup_table_v2"}; + auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed"); + auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed"); + auto* lookup_table1 = pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); auto* lookup_table2 = @@ -74,8 +79,10 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() { pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) ->assert_is_op_output("elementwise_add"); + feed1->LinksTo({lookup_table1_x}); lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) .LinksTo({lookup_table1_out}); + feed2->LinksTo({lookup_table2_x}); lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) .LinksTo({lookup_table2_out}); eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) @@ -88,6 +95,8 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() { create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); std::unordered_set embedding_ops{"lookup_table", "lookup_table_v2"}; + auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed"); + auto* lookup_table1 = pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); auto* lookup_table1_out = @@ -101,6 +110,7 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() { ->assert_is_op_output("elementwise_add"); lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) .LinksTo({lookup_table1_out}); + feed1->LinksTo({lookup_table1_x}); eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in}) 
.LinksTo({eltwise_add_out}); } @@ -161,10 +171,10 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, - start_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out, - start_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + lookup_table1_out, lookup_table1_out, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + lookup_table2_out, lookup_table2_out, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern); if (!IsCompat(subgraph, graph)) { @@ -179,8 +189,12 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( start_pattern_out_node.push_back(eltwise_add_out); std::unordered_set rm_nodes; - rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out, - lookup_table2_out, eltwise_add, eltwise_add_out}); + rm_nodes.insert({lookup_table1, + lookup_table2, + lookup_table1_out, + lookup_table2_out, + eltwise_add, + eltwise_add_out}); start_pattern_remove_nodes.push_back(rm_nodes); }; gpd(graph, handler); @@ -200,8 +214,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, - second_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + lookup_table1_out, lookup_table1_out, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern); @@ -236,19 +250,19 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, - skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltwise_add_out, eltwise_add_out, skip_layernorm_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, - skip_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, - skip_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, - skip_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, - skip_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, - skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_mean, layer_norm_mean, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_variance, layer_norm_variance, skip_layernorm_pattern); if (!IsCompat(subgraph, graph)) { LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed."; return; @@ -313,7 +327,7 @@ 
int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( embs.push_back(inner_pattern_ins[js[iter]].second->Name()); } - OpDesc new_op_desc; + OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block()); new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm"); new_op_desc.SetInput("Ids", ids); new_op_desc.SetInput("Embs", embs); @@ -433,16 +447,17 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { bool use_varseqlen = Get("use_varseqlen"); bool with_interleaved = Get("with_interleaved"); bool with_dynamic_shape = Get("with_dynamic_shape"); - if (!(enable_int8 && use_varseqlen && with_interleaved && - with_dynamic_shape)) { - VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, " - "enable_int8, " + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); + if (!(enable_int8 && use_varseqlen && with_interleaved && pos_id != "" && + mask_id != "" && with_dynamic_shape)) { + VLOG(3) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, " + "enable_int8, set pos_id, set mask_id, " "use_varseqlen, with_interleaved, with_dynamic_shape. Stop this " "pass, " "please reconfig."; return; } - int fusion_count = PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_); if (fusion_count > 0) { diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h index 1ccc6c85d48..7ca60901ebd 100644 --- a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h @@ -51,7 +51,8 @@ struct PrelnEmbedding2Eltwise1Pattern : public PatternBase { : PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {} void operator()(); - + PATTERN_DECL_NODE(feed1); + PATTERN_DECL_NODE(feed2); PATTERN_DECL_NODE(lookup_table1_x); PATTERN_DECL_NODE(lookup_table2_x); PATTERN_DECL_NODE(lookup_table1_w); @@ -81,6 +82,7 @@ struct PrelnEmbedding1Eltwise1Pattern : public PatternBase { const std::string& name_scope) : PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {} void operator()(); + PATTERN_DECL_NODE(feed1); PATTERN_DECL_NODE(lookup_table1_x); PATTERN_DECL_NODE(lookup_table1_w); PATTERN_DECL_NODE(lookup_table1); diff --git a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc index 80e6c2b7967..84546c1db6a 100644 --- a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc @@ -112,15 +112,21 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { bool use_varseqlen = Get("use_varseqlen"); bool with_interleaved = Get("with_interleaved"); bool with_dynamic_shape = Get("with_dynamic_shape"); + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); if (!(enable_int8 && use_varseqlen && with_interleaved && - with_dynamic_shape)) { - VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, " - "use_varseqlen, " - "with_interleaved, with_dynamic_shape. Stop this pass, please " - "reconfig. 
"; + graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass) && pos_id != "" && + mask_id != "" && with_dynamic_shape)) { + VLOG(3) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, " + "with_interleaved" + "use_varseqlen, preln_embedding_eltwise_layernorm_fuse_pass, " + "trt_multihead_matmul_fuse_pass" + "set pos_id, set mask_id, with_dynamic_shape. Stop this pass, " + "please " + "reconfig."; return; } - int found_subgraph_count = 0; GraphPatternDetector gpd; @@ -155,17 +161,17 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, - fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, - fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_variance, layer_norm_variance, fused_pattern); std::unordered_set del_node_set; // Create an PrelnSkipLayerNorm op node - OpDesc new_desc; + OpDesc new_desc(elementwise->Op()->Block()); new_desc.SetType("preln_skip_layernorm"); // inputs @@ -209,8 +215,8 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { found_subgraph_count++; }; - gpd(graph, handler); + AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc index 51280420124..d0023798bb4 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc @@ -35,6 +35,25 @@ void EmbEltwiseLayernorm::operator()() { emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out}); } +void PrelnEmbEltwiseLayernorm::operator()() { + // Create nodes for fused_preln_embedding_eltwise_layernorm. + auto* preln_emb_elt_layernorm_op = + pattern->NewNode(preln_emb_elt_layernorm_op_repr()) + ->assert_is_op("fused_preln_embedding_eltwise_layernorm"); + auto* preln_emb_elt_layernorm_out_0 = + pattern->NewNode(preln_emb_elt_layernorm_out_0_repr()) + ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm", + "Out_0"); + auto* preln_emb_elt_layernorm_out_1 = + pattern->NewNode(preln_emb_elt_layernorm_out_1_repr()) + ->assert_is_op_output("fused_preln_embedding_eltwise_layernorm", + "Out_1"); + + // Add links for fused_preln_embedding_eltwise_layernorm op. + preln_emb_elt_layernorm_op->LinksTo( + {preln_emb_elt_layernorm_out_0, preln_emb_elt_layernorm_out_1}); +} + void SkipLayernorm::operator()() { // Create nodes for skip_layernorm. auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr()) @@ -51,6 +70,30 @@ void SkipLayernorm::operator()() { .LinksTo({skip_layernorm_out}); } +void PrelnSkipLayernorm::operator()() { + // Create nodes for preln_skip_layernorm. 
+ auto* preln_skip_layernorm_x = + pattern->NewNode(preln_skip_layernorm_x_repr()) + ->assert_is_op_input("preln_skip_layernorm", "X"); + auto* preln_skip_layernorm_y = + pattern->NewNode(preln_skip_layernorm_y_repr()) + ->assert_is_op_input("preln_skip_layernorm", "Y"); + auto* preln_skip_layernorm_op = + pattern->NewNode(preln_skip_layernorm_op_repr()) + ->assert_is_op("preln_skip_layernorm"); + auto* preln_skip_layernorm_out_0 = + pattern->NewNode(preln_skip_layernorm_out_0_repr()) + ->assert_is_op_output("preln_skip_layernorm", "Out_0"); + auto* preln_skip_layernorm_out_1 = + pattern->NewNode(preln_skip_layernorm_out_1_repr()) + ->assert_is_op_output("preln_skip_layernorm", "Out_1"); + + // Add links for preln_skip_layernorm op. + preln_skip_layernorm_op + ->LinksFrom({preln_skip_layernorm_x, preln_skip_layernorm_y}) + .LinksTo({preln_skip_layernorm_out_0, preln_skip_layernorm_out_1}); +} + void MultiheadMatmul::operator()() { // Create nodes for multihead_matmul. auto* multihead_matmul_input = @@ -96,10 +139,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { std::string mask_id = Get("tensorrt_transformer_maskid"); if (use_varseqlen && pos_id != "" && mask_id != "" && - graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + (graph->Has(framework::ir::kEmbEltwiseLayernormPass) || + graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) && graph->Has(framework::ir::kMultiheadMatmulPass)) { VLOG(3) << "start varseqlen remove_padding_recover_padding_pass"; } else { + VLOG(3) << "remove_padding_recover_padding_pass check failed"; return; } @@ -131,9 +176,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { remove_padding.SetOutput("Out", {remove_padding_out_name}); // set out_threshold for int8 - if (op_node->Op()->HasAttr("out_threshold")) { + if (op_node->Op()->HasAttr("Input_scale")) { remove_padding.SetAttr("out_threshold", - op_node->Op()->GetAttr("out_threshold")); + op_node->Op()->GetAttr("Input_scale")); + } else { + VLOG(3) << "remove_padding_op has not out_threshold, because next op has " + "not Input_scale."; } auto remove_padding_op_node = graph->CreateOpNode(&remove_padding); @@ -194,6 +242,15 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { if (op_node->Op()->HasAttr("out_threshold")) { recover_padding.SetAttr("out_threshold", op_node->Op()->GetAttr("out_threshold")); + } else if (op_node->Op()->HasAttr("out_0_threshold")) { + recover_padding.SetAttr("out_threshold", + op_node->Op()->GetAttr("out_0_threshold")); + } else if (op_node->Op()->HasAttr("out_1_threshold")) { + recover_padding.SetAttr("out_threshold", + op_node->Op()->GetAttr("out_1_threshold")); + } else { + VLOG(3) << "recover_padding_op has not out_threshold, because previous " + "op has not out_*_threshold."; } auto recover_padding_op_node = graph->CreateOpNode(&recover_padding); @@ -241,9 +298,11 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "remove_padding_recover_padding_pass for transformer: " "fused_embedding_eltwise_layernorm"; - GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op, + GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, + emb_elt_layernorm_op, fused_embedding_eltwise_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out, + GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, + emb_elt_layernorm_out, fused_embedding_eltwise_layernorm); insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out); @@ -263,12 +322,12 
@@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "remove_padding_recover_padding_pass for transformer: " "multihead_matmul"; - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, - multihead_matmul); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_op, multihead_matmul_op, - multihead_matmul); - GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, - multihead_matmul); + GET_IR_NODE_FROM_SUBGRAPH( + multihead_matmul_input, multihead_matmul_input, multihead_matmul); + GET_IR_NODE_FROM_SUBGRAPH( + multihead_matmul_op, multihead_matmul_op, multihead_matmul); + GET_IR_NODE_FROM_SUBGRAPH( + multihead_matmul_out, multihead_matmul_out, multihead_matmul); multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape(); @@ -289,14 +348,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "remove_padding_recover_padding_pass for transformer: " "skip_layernorm"; - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x, - skip_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y, - skip_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op, - skip_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out, - skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH( + skip_layernorm_x, skip_layernorm_x, skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH( + skip_layernorm_y, skip_layernorm_y, skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH( + skip_layernorm_op, skip_layernorm_op, skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH( + skip_layernorm_out, skip_layernorm_out, skip_layernorm); std::vector skip_layernorm_x_shape = skip_layernorm_x->Var()->GetShape(); @@ -417,6 +476,86 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { }; gpd4(graph, handler4); + GraphPatternDetector gpd5; + patterns::PrelnEmbEltwiseLayernorm fused_preln_embedding_eltwise_layernorm( + gpd5.mutable_pattern(), "remove_padding_recover_padding_pass"); + fused_preln_embedding_eltwise_layernorm(); + + auto handler5 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(3) << "remove_padding_recover_padding_pass for transformer: " + "fused_preln_embedding_eltwise_layernorm"; + + GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_op, + preln_emb_elt_layernorm_op, + fused_preln_embedding_eltwise_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_0, + preln_emb_elt_layernorm_out_0, + fused_preln_embedding_eltwise_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_1, + preln_emb_elt_layernorm_out_1, + fused_preln_embedding_eltwise_layernorm); + + insert_recover_padding_op(preln_emb_elt_layernorm_op, + preln_emb_elt_layernorm_out_0); + insert_recover_padding_op(preln_emb_elt_layernorm_op, + preln_emb_elt_layernorm_out_1); + + found_subgraph_count++; + }; + gpd5(graph, handler5); + + GraphPatternDetector gpd6; + patterns::PrelnSkipLayernorm preln_skip_layernorm( + gpd6.mutable_pattern(), "remove_padding_recover_padding_pass"); + preln_skip_layernorm(); + + auto handler6 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(3) << "remove_padding_recover_padding_pass for transformer: " + "preln_skip_layernorm"; + + GET_IR_NODE_FROM_SUBGRAPH( + preln_skip_layernorm_x, preln_skip_layernorm_x, preln_skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH( + preln_skip_layernorm_y, preln_skip_layernorm_y, preln_skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH( + preln_skip_layernorm_op, 
preln_skip_layernorm_op, preln_skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_0, + preln_skip_layernorm_out_0, + preln_skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_1, + preln_skip_layernorm_out_1, + preln_skip_layernorm); + + std::vector skip_layernorm_x_shape = + preln_skip_layernorm_x->Var()->GetShape(); + if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) { + check_flag = false; + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) { + if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) { + check_flag = false; + } + } + if (!check_flag) { + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + insert_remove_padding_op(preln_skip_layernorm_x, preln_skip_layernorm_op); + insert_remove_padding_op(preln_skip_layernorm_y, preln_skip_layernorm_op); + insert_recover_padding_op(preln_skip_layernorm_op, + preln_skip_layernorm_out_0); + insert_recover_padding_op(preln_skip_layernorm_op, + preln_skip_layernorm_out_1); + found_subgraph_count++; + }; + gpd6(graph, handler6); + AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h index 7b8075644cb..f93ee4bc7c4 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h @@ -41,6 +41,16 @@ struct EmbEltwiseLayernorm : public PatternBase { PATTERN_DECL_NODE(emb_elt_layernorm_out); }; +struct PrelnEmbEltwiseLayernorm : public PatternBase { + PrelnEmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "preln_emb_elt_layernorm") {} + + void operator()(); + PATTERN_DECL_NODE(preln_emb_elt_layernorm_op); + PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_0); + PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_1); +}; + struct SkipLayernorm : public PatternBase { SkipLayernorm(PDPattern *pattern, const std::string &name_scope) : PatternBase(pattern, name_scope, "skip_layernorm") {} @@ -53,6 +63,19 @@ struct SkipLayernorm : public PatternBase { PATTERN_DECL_NODE(skip_layernorm_out); }; +struct PrelnSkipLayernorm : public PatternBase { + PrelnSkipLayernorm(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "preln_skip_layernorm") {} + + void operator()(); + + PATTERN_DECL_NODE(preln_skip_layernorm_x); + PATTERN_DECL_NODE(preln_skip_layernorm_y); + PATTERN_DECL_NODE(preln_skip_layernorm_op); + PATTERN_DECL_NODE(preln_skip_layernorm_out_0); + PATTERN_DECL_NODE(preln_skip_layernorm_out_1); +}; + struct MultiheadMatmul : public PatternBase { MultiheadMatmul(PDPattern *pattern, const std::string &name_scope) : PatternBase(pattern, name_scope, "multihead_matmul") {} diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc index 8fff2f953c3..eb5b734291d 100644 --- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc @@ -51,11 +51,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { multihead_pattern(); // Create New OpDesc - auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2, - Node* mul0_out, Node* mul1_out, Node* 
mul2_out, - Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, - Node* eltadd_qk_b, Node* reshape2, - Node* reshape2_qkv_out, Node* scale, + auto fuse_creater = [&](Node* input0, + Node* mul0, + Node* mul1, + Node* mul2, + Node* mul0_out, + Node* mul1_out, + Node* mul2_out, + Node* eltadd0_b, + Node* eltadd1_b, + Node* eltadd2_b, + Node* eltadd_qk_b, + Node* reshape2, + Node* reshape2_qkv_out, + Node* scale, Node* scale_out) { auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale")); // auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias")); @@ -123,11 +132,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); @@ -135,21 +144,21 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); // nodes need be removed GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); @@ -172,24 +181,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, 
reshape2_qkv, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, - multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, - multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, - multihead_pattern); - - fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, - eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, - reshape2_qkv_out, scale, scale_out); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + + fuse_creater(input0, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + eltadd0_b, + eltadd1_b, + eltadd2_b, + eltadd_qk_b, + reshape2_0, + reshape2_qkv_out, + scale, + scale_out); std::unordered_set marked_nodes( {eltadd0, @@ -777,14 +798,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, multihead_pattern(); // Create New OpDesc - auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2, - Node* mul0_out, Node* mul1_out, Node* mul2_out, - Node* mul0_w, Node* mul1_w, Node* mul2_w, - Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, - Node* eltadd_qk_b, Node* reshape2, - Node* reshape2_qkv_out, Node* scale, Node* scale_out, - Node* softmax_qk, Node* eltadd0, Node* eltadd1, - Node* eltadd2, Node* matmul_qk, Node* reshape2_qkv) { + auto fuse_creater = [&](Node* input0, + Node* mul0, + Node* mul1, + Node* mul2, + Node* mul0_out, + Node* mul1_out, + Node* mul2_out, + Node* mul0_w, + Node* mul1_w, + Node* mul2_w, + Node* eltadd0_b, + Node* eltadd1_b, + Node* eltadd2_b, + Node* eltadd_qk_b, + Node* reshape2, + Node* reshape2_qkv_out, + Node* scale, + Node* scale_out, + Node* softmax_qk, + Node* eltadd0, + Node* eltadd1, + Node* eltadd2, + Node* matmul_qk, + Node* reshape2_qkv) { auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale")); // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H) @@ -842,7 +879,8 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, wq_tensor->Resize(combined_w_dims); auto* new_combined_w_data = wq_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_combined_w_data, tmp_combined_w_data, + memcpy(new_combined_w_data, + tmp_combined_w_data, sizeof(float) * wq_tensor->numel()); scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); @@ -854,15 +892,17 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, size_t bias_size = bq_tensor->numel(); memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); - memcpy(tmp_combined_bias_data + bias_size, bk_data, - sizeof(float) * bias_size); - memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data, + memcpy( + tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + 2 * bias_size, + bv_data, sizeof(float) * bias_size); bq_tensor->Resize(combined_bias_dims); auto* new_combined_bias_data = bq_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_combined_bias_data, tmp_combined_bias_data, + memcpy(new_combined_bias_data, + tmp_combined_bias_data, sizeof(float) * bq_tensor->numel()); scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); @@ -944,11 +984,11 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); 
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); @@ -956,21 +996,21 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); // nodes need be removed GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); @@ -993,20 +1033,20 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, - multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, - multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); // If weights or biases in qkv's fc are shared by multiple multihead_matmul // patterns, we do not support this kind of fusion, this pass will not take @@ -1018,10 +1058,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, if (is_fc_params_shared) { 
return; } - fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, - mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, - reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk, - eltadd0, eltadd1, eltadd2, matmul_qk, reshape2_qkv); + fuse_creater(input0, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + eltadd0_b, + eltadd1_b, + eltadd2_b, + eltadd_qk_b, + reshape2_0, + reshape2_qkv_out, + scale, + scale_out, + softmax_qk, + eltadd0, + eltadd1, + eltadd2, + matmul_qk, + reshape2_qkv); std::unordered_set marked_nodes({eltadd0, eltadd1, @@ -1083,19 +1143,28 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusionV2(graph, name_scope_, scope); if (fusion_count > 0) { bool use_varseqlen = Get("use_varseqlen"); + bool with_interleaved = Get("with_interleaved"); std::string pos_id = Get("tensorrt_transformer_posid"); std::string mask_id = Get("tensorrt_transformer_maskid"); if (use_varseqlen && pos_id != "" && mask_id != "") { - if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) { - VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v2"; + if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) || + graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) { + if (with_interleaved) { + VLOG(3) << "start interleaved_format " + "varseqlen_trt_multihead_matmul_fuse_pass_v2"; + } else { + VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v2"; + } } else { - PADDLE_THROW(platform::errors::Fatal( - "Use transformer'varseqlen need " - "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); + PADDLE_THROW( + platform::errors::Fatal("Use transformer'varseqlen need " + "embedding_eltwise_layernorm_fuse_pass or " + "preln_embedding_eltwise_layernorm_fuse_" + "pass. 
please use no_varseqlen")); } } else if (!use_varseqlen && pos_id == "" && mask_id == "") { - VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v2"; + VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " @@ -1251,12 +1320,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, multihead_pattern(); // Create New OpDesc - auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2, - Node* mul0_out, Node* mul1_out, Node* mul2_out, - Node* mul0_w, Node* mul1_w, Node* mul2_w, - Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, - Node* eltadd_qk_b, Node* reshape2, - Node* reshape2_qkv_out, Node* matmul_qk) { + auto fuse_creater = [&](Node* input0, + Node* mul0, + Node* mul1, + Node* mul2, + Node* mul0_out, + Node* mul1_out, + Node* mul2_out, + Node* mul0_w, + Node* mul1_w, + Node* mul2_w, + Node* eltadd0_b, + Node* eltadd1_b, + Node* eltadd2_b, + Node* eltadd_qk_b, + Node* reshape2, + Node* reshape2_qkv_out, + Node* matmul_qk) { auto scale_attr = BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha")); // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H) @@ -1314,7 +1394,8 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, wq_tensor->Resize(combined_w_dims); auto* new_combined_w_data = wq_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_combined_w_data, tmp_combined_w_data, + memcpy(new_combined_w_data, + tmp_combined_w_data, sizeof(float) * wq_tensor->numel()); scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); @@ -1326,15 +1407,17 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, size_t bias_size = bq_tensor->numel(); memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); - memcpy(tmp_combined_bias_data + bias_size, bk_data, - sizeof(float) * bias_size); - memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data, + memcpy( + tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + 2 * bias_size, + bv_data, sizeof(float) * bias_size); bq_tensor->Resize(combined_bias_dims); auto* new_combined_bias_data = bq_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_combined_bias_data, tmp_combined_bias_data, + memcpy(new_combined_bias_data, + tmp_combined_bias_data, sizeof(float) * bq_tensor->numel()); scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); @@ -1375,31 +1458,31 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, 
multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); // nodes need be removed GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); @@ -1422,20 +1505,20 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, - multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, - multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); // If weights or biases in qkv's fc are shared by multiple multihead_matmul // patterns, we do not support this kind of fusion, this pass will not take @@ -1447,9 +1530,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, if (is_fc_params_shared) { return; } - fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, - mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, - reshape2_0, reshape2_qkv_out, matmul_qk); + fuse_creater(input0, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + eltadd0_b, + eltadd1_b, + eltadd2_b, + eltadd_qk_b, + reshape2_0, + reshape2_qkv_out, + matmul_qk); std::unordered_set marked_nodes({eltadd0, eltadd1, @@ -1510,19 +1607,28 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusionV3(graph, name_scope_, scope); if (fusion_count > 0) { bool use_varseqlen = Get("use_varseqlen"); + bool with_interleaved = Get("with_interleaved"); std::string pos_id = Get("tensorrt_transformer_posid"); std::string mask_id = Get("tensorrt_transformer_maskid"); if (use_varseqlen && pos_id != "" && mask_id != "") { - if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) { - VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3"; + if 
(graph->Has(framework::ir::kEmbEltwiseLayernormPass) || + graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) { + if (with_interleaved) { + VLOG(3) << "start interleaved_format " + "varseqlen_trt_multihead_matmul_fuse_pass_v3"; + } else { + VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v3"; + } } else { - PADDLE_THROW(platform::errors::Fatal( - "Use transformer'varseqlen need " - "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); + PADDLE_THROW( + platform::errors::Fatal("Use transformer'varseqlen need " + "embedding_eltwise_layernorm_fuse_pass or " + "preln_embedding_eltwise_layernorm_fuse_" + "pass. please use no_varseqlen")); } } else if (!use_varseqlen && pos_id == "" && mask_id == "") { - VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3"; + VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 13883909435..d33adab8b3e 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, - fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, - fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_variance, layer_norm_variance, fused_pattern); std::unordered_set del_node_set; @@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { std::string mask_id = Get("tensorrt_transformer_maskid"); if (use_varseqlen && pos_id != "" && mask_id != "") { - if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + if ((graph->Has(framework::ir::kEmbEltwiseLayernormPass) || + graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) && graph->Has(framework::ir::kMultiheadMatmulPass)) { VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass"; } else { PADDLE_THROW(platform::errors::Fatal( "Use transformer'varseqlen need " - "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); + "trt_embedding_eltwise_layernorm_fuse_pass, " + "trt_multihead_matmul_fuse_pass. 
please use no_varseqlen")); } } else if (!use_varseqlen && pos_id == "" && mask_id == "") { VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass"; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 4ee8db7c69d..78dd812e035 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -28,13 +28,21 @@ namespace tensorrt { class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { #if IS_TRT_VERSION_GE(7000) VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer"; - if (!(engine_->use_varseqlen() && engine_->with_interleaved())) { + auto pos_id_name = engine_->tensorrt_transformer_posid(); + auto mask_id_name = engine_->tensorrt_transformer_maskid(); + bool flag_prelayernorm = engine_->with_interleaved() && + engine_->use_varseqlen() && pos_id_name != "" && + mask_id_name != ""; + + if (!flag_prelayernorm) { PADDLE_THROW(platform::errors::Fatal( - "PrelnErnie: If you want to use oss, must be with interleaved")); + "PrelnErnie: If you want to use varseqlen, must be with interleaved, " + "set pos_id_name, set mask_id_name.")); } framework::OpDesc op_desc(op, nullptr); bool enable_int8 = op_desc.HasAttr("enable_int8"); @@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { platform::errors::Fatal("use with_interleaved must be int8.")); } auto word_id_name = op_desc.Input("WordId").front(); - auto pos_id_name = op_desc.Input("PosId").front(); engine_->Set("ernie_pos_name", new std::string(pos_id_name)); auto sent_id_name = op_desc.Input("SentId").front(); @@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { auto pos_emb_name = op_desc.Input("PosEmbedding").front(); auto sent_emb_name = op_desc.Input("SentEmbedding").front(); + engine_->SetITensor("word_id", engine_->GetITensor(word_id_name)); + engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); + engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); + std::vector emb_names; emb_names = std::vector{word_emb_name, pos_emb_name, sent_emb_name}; @@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { input_embs.push_back(emb_data); emb_sizes.push_back(emb_size); PADDLE_ENFORCE_EQ( - emb_dims.size(), 2, + emb_dims.size(), + 2, platform::errors::InvalidArgument( "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims.")); } @@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { int output_int8 = 1; PADDLE_ENFORCE_EQ( - input_num, 3, + input_num, + 3, platform::errors::InvalidArgument( "When using oss and var-len, embedding_eltwise_layernorm op" "should have 3 inputs only, but got %d.", input_num)); const std::vector fields{ - {"bert_embeddings_layernorm_beta", bias, - nvinfer1::PluginFieldType::kFLOAT32, static_cast(bias_size)}, - {"bert_embeddings_layernorm_gamma", scale, - nvinfer1::PluginFieldType::kFLOAT32, static_cast(scale_size)}, - {"bert_embeddings_word_embeddings", input_embs[0], + {"bert_embeddings_layernorm_beta", + bias, + nvinfer1::PluginFieldType::kFLOAT32, + static_cast(bias_size)}, + {"bert_embeddings_layernorm_gamma", + scale, + nvinfer1::PluginFieldType::kFLOAT32, + 
static_cast(scale_size)}, + {"bert_embeddings_word_embeddings", + input_embs[0], nvinfer1::PluginFieldType::kFLOAT32, static_cast(emb_sizes[0])}, - {"bert_embeddings_token_type_embeddings", input_embs[2], + {"bert_embeddings_token_type_embeddings", + input_embs[2], nvinfer1::PluginFieldType::kFLOAT32, static_cast(emb_sizes[2])}, - {"bert_embeddings_position_embeddings", input_embs[1], + {"bert_embeddings_position_embeddings", + input_embs[1], nvinfer1::PluginFieldType::kFLOAT32, static_cast(emb_sizes[1])}, {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1}, @@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { plugin_inputs.emplace_back( engine_->GetITensor(pos_id_name)); // cu_seqlens, // eval_placeholder_2 - auto max_seqlen_tensor = - engine_->GetITensor(engine_->network()->getInput(3)->getName()); + auto max_seqlen_tensor = engine_->GetITensor(mask_id_name); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor); nvinfer1::Dims shape_dim; -- GitLab
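
Taken together, the pass-side changes in this patch gate the pre-LayerNorm (Preln) fusion path on one recurring precondition set: enable_int8, use_varseqlen, with_interleaved, with_dynamic_shape, plus non-empty tensorrt_transformer_posid and tensorrt_transformer_maskid. Below is a minimal standalone sketch of that gate only, under assumed names (PassConfig and PrelnVarseqlenPreconditionsMet are illustrative, not Paddle APIs); the real passes read these switches through the pass attribute map and additionally check graph markers such as kPrelnEmbEltwiseLayernormPass and kMultiheadMatmulPass.

#include <iostream>
#include <string>

// Illustrative stand-in for the attributes the fuse passes read; not Paddle's API.
struct PassConfig {
  bool enable_int8 = false;
  bool use_varseqlen = false;
  bool with_interleaved = false;
  bool with_dynamic_shape = false;
  std::string pos_id;   // value of "tensorrt_transformer_posid"
  std::string mask_id;  // value of "tensorrt_transformer_maskid"
};

// True only when every switch required by the varseqlen pre-LayerNorm path is
// set; otherwise the fuse passes log at VLOG(3) and return without fusing.
bool PrelnVarseqlenPreconditionsMet(const PassConfig& cfg) {
  return cfg.enable_int8 && cfg.use_varseqlen && cfg.with_interleaved &&
         cfg.with_dynamic_shape && !cfg.pos_id.empty() && !cfg.mask_id.empty();
}

int main() {
  PassConfig cfg;
  cfg.enable_int8 = true;
  cfg.use_varseqlen = true;
  cfg.with_interleaved = true;
  cfg.with_dynamic_shape = true;
  cfg.pos_id = "pos_id";
  cfg.mask_id = "mask_id";
  std::cout << (PrelnVarseqlenPreconditionsMet(cfg) ? "run Preln fusion passes"
                                                    : "skip Preln fusion passes")
            << std::endl;
  return 0;
}

This mirrors the early-return blocks added to preln_embedding_eltwise_layernorm_fuse_pass.cc and preln_skip_layernorm_fuse_pass.cc above.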
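
The new handlers in remove_padding_recover_padding_pass.cc (gpd5/gpd6) reuse the existing shape-consistency rule: remove_padding/recover_padding ops are only wrapped around a preln_skip_layernorm input when its shape matches the multihead_matmul input shape recorded earlier in the pass. A simplified sketch of that comparison, with a hypothetical helper name (ShapesMatch is not a Paddle function):

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the element-wise comparison done in handler6
// before remove_padding/recover_padding ops are inserted.
bool ShapesMatch(const std::vector<int64_t>& skip_layernorm_x_shape,
                 const std::vector<int64_t>& multihead_matmul_input_shape) {
  if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
    return false;  // rank mismatch: leave this subgraph padded
  }
  for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
    if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
      return false;  // any differing dimension disables the transformation
    }
  }
  return true;
}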