From 2810dfea475e811d7a919d3ee9c5317e0f865da3 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 2 Jun 2022 12:21:17 +0800 Subject: [PATCH] [Paddle-Inference] new general transformer inference support (#43077) * new general transformer inference support --- paddle/fluid/framework/ir/CMakeLists.txt | 3 + ...n_embedding_eltwise_layernorm_fuse_pass.cc | 8 +- .../ir/preln_skip_layernorm_fuse_pass.cc | 7 +- .../ir/remove_padding_recover_padding_pass.cc | 204 ++- .../ir/remove_padding_recover_padding_pass.h | 8 + .../ir/set_transformer_input_convert_pass.cc | 185 +- .../ir/set_transformer_input_convert_pass.h | 43 +- ...t_embedding_eltwise_layernorm_fuse_pass.cc | 477 +++++ ...rt_embedding_eltwise_layernorm_fuse_pass.h | 167 ++ .../ir/trt_multihead_matmul_fuse_pass.cc | 1546 +++++++++++++++++ .../ir/trt_multihead_matmul_fuse_pass.h | 179 ++ .../ir/trt_skip_layernorm_fuse_pass.cc | 232 +++ .../ir/trt_skip_layernorm_fuse_pass.h | 87 + paddle/fluid/inference/analysis/argument.h | 6 +- .../inference/analysis/ir_pass_manager.cc | 6 +- .../ir_passes/tensorrt_subgraph_pass.cc | 10 +- paddle/fluid/inference/api/analysis_config.cc | 12 +- .../fluid/inference/api/analysis_predictor.cc | 21 +- .../inference/api/paddle_analysis_config.h | 8 +- paddle/fluid/inference/api/paddle_api.h | 6 + .../inference/api/paddle_pass_builder.cc | 32 +- paddle/fluid/inference/capi_exp/pd_config.cc | 6 +- paddle/fluid/inference/capi_exp/pd_config.h | 2 +- paddle/fluid/inference/goapi/config.go | 4 +- paddle/fluid/inference/goapi/config_test.go | 4 +- .../inference/tensorrt/convert/CMakeLists.txt | 4 + .../tensorrt/convert/emb_eltwise_layernorm.cc | 28 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 3 +- .../tensorrt/convert/multihead_matmul_op.cc | 30 +- .../convert/preln_emb_eltwise_layernorm.cc | 2 +- .../tensorrt/convert/preln_skip_layernorm.cc | 5 +- .../tensorrt/convert/recover_padding_op.cc | 76 + .../tensorrt/convert/remove_padding_op.cc | 69 + .../tensorrt/convert/skip_layernorm.cc | 9 +- .../inference/tensorrt/convert/slice_op.cc | 48 +- .../convert/transformer_input_convert_op.cc | 72 + paddle/fluid/inference/tensorrt/engine.h | 21 +- paddle/fluid/inference/tensorrt/op_teller.cc | 10 +- .../inference/tensorrt/plugin/CMakeLists.txt | 5 +- .../tensorrt/plugin/recover_padding_plugin.cu | 120 ++ .../tensorrt/plugin/recover_padding_plugin.h | 133 ++ .../tensorrt/plugin/remove_padding_plugin.cu | 118 ++ .../tensorrt/plugin/remove_padding_plugin.h | 133 ++ .../tensorrt/plugin/special_slice_plugin.cu | 197 --- .../tensorrt/plugin/special_slice_plugin.h | 98 -- .../transformer_input_convert_plugin.cu | 110 ++ .../plugin/transformer_input_convert_plugin.h | 134 ++ .../tests/api/analyzer_capi_exp_gpu_tester.cc | 2 +- .../tests/api/trt_dynamic_shape_ernie_test.cc | 6 +- .../tests/infer_ut/test_ernie_xnli_int8.cc | 2 +- .../inference/utils/table_printer_tester.cc | 2 +- paddle/fluid/pybind/inference_api.cc | 5 +- .../ir/inference/inference_pass_test.py | 6 +- .../ir/inference/quant_dequant_test.py | 6 +- .../inference/test_trt_multiclass_nms3_op.py | 6 +- .../inference/test_trt_multiclass_nms_op.py | 6 +- 56 files changed, 4124 insertions(+), 605 deletions(-) create mode 100644 paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h create mode 
100644 paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h create mode 100644 paddle/fluid/inference/tensorrt/convert/recover_padding_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/remove_padding_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h delete mode 100644 paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu delete mode 100644 paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8166c43e65d..3fc938f7641 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -107,6 +107,9 @@ target_link_libraries(generate_pass pass_desc_proto) if(WITH_TENSORRT) pass_library(trt_map_matmul_to_mul_pass inference) + pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference) + pass_library(trt_multihead_matmul_fuse_pass inference) + pass_library(trt_skip_layernorm_fuse_pass inference) pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) pass_library(preln_skip_layernorm_fuse_pass inference) pass_library(set_transformer_input_convert_pass inference) diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc index d6761d2e82e..929ffa2cadb 100644 --- a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc @@ -430,13 +430,15 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { FusePassBase::Init(name_scope_, graph); bool enable_int8 = Get("enable_int8"); - bool use_oss = Get("use_oss"); + bool use_varseqlen = Get("use_varseqlen"); bool with_interleaved = Get("with_interleaved"); bool with_dynamic_shape = Get("with_dynamic_shape"); - if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) { + if (!(enable_int8 && use_varseqlen && with_interleaved && + with_dynamic_shape)) { VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, " "enable_int8, " - "use_oss, with_interleaved, with_dynamic_shape. Stop this pass, " + "use_varseqlen, with_interleaved, with_dynamic_shape. 
Stop this " + "pass, " "please reconfig."; return; } diff --git a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc index 978360d8f0a..6c06b741adb 100644 --- a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc @@ -109,12 +109,13 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("preln_skip_layernorm_fuse", graph); bool enable_int8 = Get("enable_int8"); - bool use_oss = Get("use_oss"); + bool use_varseqlen = Get("use_varseqlen"); bool with_interleaved = Get("with_interleaved"); bool with_dynamic_shape = Get("with_dynamic_shape"); - if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) { + if (!(enable_int8 && use_varseqlen && with_interleaved && + with_dynamic_shape)) { VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, " - "use_oss, " + "use_varseqlen, " "with_interleaved, with_dynamic_shape. Stop this pass, please " "reconfig. "; return; diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc index 67dfe074dc0..ee9474f6fad 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc @@ -22,6 +22,19 @@ namespace paddle { namespace framework { namespace ir { namespace patterns { +void EmbEltwiseLayernorm::operator()() { + // Create nodes for fused_embedding_eltwise_layernorm. + auto* emb_elt_layernorm_op = + pattern->NewNode(emb_elt_layernorm_op_repr()) + ->assert_is_op("fused_embedding_eltwise_layernorm"); + auto* emb_elt_layernorm_out = + pattern->NewNode(emb_elt_layernorm_out_repr()) + ->assert_is_op_output("fused_embedding_eltwise_layernorm", "Out"); + + // Add links for fused_embedding_eltwise_layernorm op. + emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out}); +} + void SkipLayernorm::operator()() { // Create nodes for skip_layernorm. auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr()) @@ -59,16 +72,12 @@ void Fc::operator()() { auto* fc_input = pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input"); auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc"); - auto* fc_out = - pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out"); - - // Add links for fc op. - fc_op->LinksFrom({fc_input}).LinksTo({fc_out}); + fc_op->LinksFrom({fc_input}); } void Activation::operator()() { // Create nodes for activation. 
- std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "tanh"}; + std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "gelu"}; auto* activation_input = pattern->NewNode(activation_input_repr()) ->assert_is_ops_input(activation_ops); auto* activation_op = @@ -82,6 +91,18 @@ void Activation::operator()() { } // namespace patterns void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { + bool use_varseqlen = Get<bool>("use_varseqlen"); + std::string pos_id = Get<std::string>("tensorrt_transformer_posid"); + std::string mask_id = Get<std::string>("tensorrt_transformer_maskid"); + + if (use_varseqlen && pos_id != "" && mask_id != "" && + graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass)) { + VLOG(3) << "start varseqlen remove_padding_recover_padding_pass"; + } else { + return; + } + PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init(name_scope_, graph); @@ -91,14 +112,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { // Create a remove_padding op node auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) { // create op, var in graph - OpDesc remove_padding; + OpDesc remove_padding(op_node->Op()->Block()); std::string remove_padding_out_name = input_node->Name() + ".remove_padding"; - - VarDesc remove_padding_out(remove_padding_out_name); - remove_padding_out.SetDataType(input_node->Var()->GetDataType()); - remove_padding_out.SetShape(input_node->Var()->GetShape()); - remove_padding_out.SetPersistable(false); + auto* remove_padding_out = + op_node->Op()->Block()->Var(remove_padding_out_name); + remove_padding_out->SetDataType(input_node->Var()->GetDataType()); + remove_padding_out->SetShape(input_node->Var()->GetShape()); + remove_padding_out->SetPersistable(false); // remove_padding_op remove_padding.SetType("remove_padding"); @@ -110,7 +131,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { remove_padding.SetOutput("Out", {remove_padding_out_name}); auto remove_padding_op_node = graph->CreateOpNode(&remove_padding); - auto remove_padding_out_node = graph->CreateVarNode(&remove_padding_out); + auto remove_padding_out_node = graph->CreateVarNode(remove_padding_out); // replace link for (size_t i = 0; i < input_node->outputs.size(); ++i) { @@ -145,13 +166,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { // create a recover_padding op node auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) { // create op, var in graph - OpDesc recover_padding; + OpDesc recover_padding(op_node->Op()->Block()); std::string recover_padding_input_name = out_node->Name() + ".recover_padding"; - VarDesc recover_padding_input(recover_padding_input_name); - recover_padding_input.SetDataType(out_node->Var()->GetDataType()); - recover_padding_input.SetShape(out_node->Var()->GetShape()); - recover_padding_input.SetPersistable(false); + auto* recover_padding_input = + op_node->Op()->Block()->Var(recover_padding_input_name); + recover_padding_input->SetDataType(out_node->Var()->GetDataType()); + recover_padding_input->SetShape(out_node->Var()->GetShape()); + recover_padding_input->SetPersistable(false); // recover_padding_op recover_padding.SetType("recover_padding"); @@ -164,7 +186,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { auto recover_padding_op_node = graph->CreateOpNode(&recover_padding); auto recover_padding_input_node = -
graph->CreateVarNode(&recover_padding_input); + graph->CreateVarNode(recover_padding_input); // replace link for (size_t i = 0; i < op_node->outputs.size(); ++i) { @@ -195,39 +217,36 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name); }; - GraphPatternDetector gpd1; - patterns::SkipLayernorm skip_layernorm(gpd1.mutable_pattern(), - "remove_padding_recover_padding_pass"); - skip_layernorm(); + bool check_flag = true; - auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph, + GraphPatternDetector gpd0; + patterns::EmbEltwiseLayernorm fused_embedding_eltwise_layernorm( + gpd0.mutable_pattern(), "remove_padding_recover_padding_pass"); + fused_embedding_eltwise_layernorm(); + + auto handler0 = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(3) << "remove_padding_recover_padding_pass for transformer: " - "skip_layernorm"; + "fused_embedding_eltwise_layernorm"; - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x, - skip_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y, - skip_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op, - skip_layernorm); - GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out, - skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op, + fused_embedding_eltwise_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out, + fused_embedding_eltwise_layernorm); - insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op); - insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op); - insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out); + insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out); found_subgraph_count++; }; - gpd1(graph, handler1); + gpd0(graph, handler0); - GraphPatternDetector gpd2; + GraphPatternDetector gpd1; patterns::MultiheadMatmul multihead_matmul( - gpd2.mutable_pattern(), "remove_padding_recover_padding_pass"); + gpd1.mutable_pattern(), "remove_padding_recover_padding_pass"); multihead_matmul(); - auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, + std::vector multihead_matmul_input_shape; + auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(3) << "remove_padding_recover_padding_pass for transformer: " "multihead_matmul"; @@ -239,11 +258,57 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, multihead_matmul); + multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape(); + insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op); insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out); found_subgraph_count++; }; + gpd1(graph, handler1); + + GraphPatternDetector gpd2; + patterns::SkipLayernorm skip_layernorm(gpd2.mutable_pattern(), + "remove_padding_recover_padding_pass"); + skip_layernorm(); + + auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(3) << "remove_padding_recover_padding_pass for transformer: " + "skip_layernorm"; + + GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x, + skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y, + skip_layernorm); + GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op, + skip_layernorm); + 
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out, + skip_layernorm); + + std::vector skip_layernorm_x_shape = + skip_layernorm_x->Var()->GetShape(); + if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) { + check_flag = false; + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) { + if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) { + check_flag = false; + } + } + if (!check_flag) { + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op); + insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op); + insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out); + found_subgraph_count++; + }; gpd2(graph, handler2); GraphPatternDetector gpd3; @@ -257,11 +322,39 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc); GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc); - GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc); - insert_remove_padding_op(fc_input, fc_op); - insert_recover_padding_op(fc_op, fc_out); + std::vector fc_input_shape = fc_input->Var()->GetShape(); + if ((fc_input_shape.size() != multihead_matmul_input_shape.size()) || + (fc_input_shape.size() != 3)) { + check_flag = false; + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + if (fc_input_shape[0] != multihead_matmul_input_shape[0]) { + check_flag = false; + } + if (fc_input_shape[1] != multihead_matmul_input_shape[1]) { + check_flag = false; + } + if ((fc_input_shape[2] != multihead_matmul_input_shape[2]) && + (fc_input_shape[2] != 4 * multihead_matmul_input_shape[2])) { + check_flag = false; + } + if (BOOST_GET_CONST(int, fc_op->Op()->GetAttr("in_num_col_dims")) != 2) { + check_flag = false; + } + if (!check_flag) { + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + fc_op->Op()->RemoveAttr("in_num_col_dims"); + fc_op->Op()->SetAttr("in_num_col_dims", 1); + + insert_remove_padding_op(fc_input, fc_op); + insert_recover_padding_op(fc_op, fc_op->outputs[0]); found_subgraph_count++; }; gpd3(graph, handler3); @@ -280,6 +373,31 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation); GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation); + std::vector activation_input_shape = + activation_input->Var()->GetShape(); + if ((activation_input_shape.size() != + multihead_matmul_input_shape.size()) || + (activation_input_shape.size() != 3)) { + check_flag = false; + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + if (activation_input_shape[0] != multihead_matmul_input_shape[0]) { + check_flag = false; + } + if (activation_input_shape[1] != multihead_matmul_input_shape[1]) { + check_flag = false; + } + if ((activation_input_shape[2] != multihead_matmul_input_shape[2]) && + (activation_input_shape[2] != 4 * multihead_matmul_input_shape[2])) { + check_flag = false; + } + if (!check_flag) { + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } insert_remove_padding_op(activation_input, activation_op); 
insert_recover_padding_op(activation_op, activation_out); diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h index d7ccfc75c20..7b8075644cb 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h @@ -32,6 +32,14 @@ namespace paddle { namespace framework { namespace ir { namespace patterns { +struct EmbEltwiseLayernorm : public PatternBase { + EmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "emb_elt_layernorm") {} + + void operator()(); + PATTERN_DECL_NODE(emb_elt_layernorm_op); + PATTERN_DECL_NODE(emb_elt_layernorm_out); +}; struct SkipLayernorm : public PatternBase { SkipLayernorm(PDPattern *pattern, const std::string &name_scope) diff --git a/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc index 37e77bc134d..f177f607087 100644 --- a/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc +++ b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc @@ -21,129 +21,134 @@ namespace paddle { namespace framework { namespace ir { -SetTransformerInputConvertPass::SetTransformerInputConvertPass() { - AddOpCompat(OpCompat("elementwise_add")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("axis") - .End(); -} namespace patterns { -void SetTransformerInputConvert::operator()() { +void SetTransformerInputConvert::operator()(const std::string &pos_id) { std::unordered_set lookup_table_ops{"lookup_table", "lookup_table_v2"}; - // Create nodes for lookup_table1 op. - auto *lookup_table1_x = pattern->NewNode(lookup_table1_x_repr()) - ->assert_is_ops_input(lookup_table_ops, "Ids"); - auto *lookup_table1_w = pattern->NewNode(lookup_table1_w_repr()) - ->assert_is_ops_input(lookup_table_ops, "W"); - auto *lookup_table1_op = - pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_table_ops); - auto *lookup_table1_out = pattern->NewNode(lookup_table1_out_repr()) - ->assert_is_ops_output(lookup_table_ops) - ->AsIntermediate() - ->assert_is_op_input("elementwise_add", "X"); - - // Create nodes for lookup_table2 op. - auto *lookup_table2_x = pattern->NewNode(lookup_table2_x_repr()) - ->assert_is_ops_input(lookup_table_ops, "Ids"); - auto *lookup_table2_w = pattern->NewNode(lookup_table2_w_repr()) - ->assert_is_ops_input(lookup_table_ops, "W"); - auto *lookup_table2_op = - pattern->NewNode(lookup_table2_repr())->assert_is_ops(lookup_table_ops); - auto *lookup_table2_out = pattern->NewNode(lookup_table2_out_repr()) - ->assert_is_ops_output(lookup_table_ops) - ->AsIntermediate() - ->assert_is_op_input("elementwise_add", "Y"); - - // Create nodes for elementwise_add op. - auto *elementwise_op = - pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); - auto *elementwise_out = pattern->NewNode(elementwise_out_repr()) - ->AsOutput() - ->assert_is_only_output_of_op("elementwise_add"); + // Create nodes for lookup_table. + auto *lookup_table_id = + pattern->NewNode(lookup_table_id_repr()) + ->assert_is_ops_input(lookup_table_ops, "Ids") + ->assert_more([&](Node *node) { return node->Name() == pos_id; }); + auto *lookup_table_op = + pattern->NewNode(lookup_table_repr())->assert_is_ops(lookup_table_ops); // links nodes. 
- lookup_table1_op->LinksFrom({lookup_table1_x, lookup_table1_w}) - .LinksTo({lookup_table1_out}); - lookup_table2_op->LinksFrom({lookup_table2_x, lookup_table2_w}) - .LinksTo({lookup_table2_out}); - elementwise_op->LinksFrom({lookup_table1_out, lookup_table2_out}) - .LinksTo({elementwise_out}); + lookup_table_op->LinksFrom({lookup_table_id}); } +void MultiheadMatmulOP::operator()() { + // Create nodes for multihead_matmul op. + auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr()) + ->assert_is_op("multihead_matmul"); + auto *multihead_matmul_out = + pattern->NewNode(multihead_matmul_out_repr()) + ->assert_is_op_output("multihead_matmul", "Out"); + + // links nodes. + multihead_matmul_out->LinksFrom({multihead_matmul}); +} } // namespace patterns void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const { + bool with_dynamic_shape = Get("with_dynamic_shape"); + std::string pos_id = Get("tensorrt_transformer_posid"); + + if (!(graph->Has(framework::ir::kMultiheadMatmulPass) && with_dynamic_shape && + (pos_id != ""))) { + VLOG(3) << "Transformer model need MultiheadMatmul, and " + "with_dynamic_shape. Stop this pass, " + "please reconfig."; + return; + } PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init(name_scope_, graph); int found_subgraph_count = 0; - - GraphPatternDetector gpd; + Node *transformer_input_convert_out0_node; + Node *transformer_input_convert_out1_node; + GraphPatternDetector gpd0; patterns::SetTransformerInputConvert fused_pattern( - gpd.mutable_pattern(), "transformer_input_convert_pass"); - fused_pattern(); - - auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, - Graph *graph) { - if (!IsCompat(subgraph, graph)) { - LOG(WARNING) << "transformer_input_convert_pass in op compat failed."; - return; - } - - VLOG(3) << "transformer_input_convert_pass for pos_id, max_seqlen"; - - GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, fused_pattern); + gpd0.mutable_pattern(), "transformer_input_convert_pass"); + fused_pattern(pos_id); + auto handler0 = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + VLOG(3) + << "transformer_input_convert_pass for pos_id, max_seqlen, mask_tensor"; + GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table_id, lookup_table_id, fused_pattern); // create op, var in graph - OpDesc new_desc; + OpDesc new_desc(lookup_table->Op()->Block()); + new_desc.SetType("transformer_input_convert"); // inputs - new_desc.SetInput("X", {lookup_table2_x->Name()}); + new_desc.SetInput("Input", {lookup_table_id->Name()}); // outputs - std::vector output_0 = {"pos_id_tensor"}; - std::vector output_1 = {"max_seqlen_tensor"}; - new_desc.SetOutput("PosId", output_0); - new_desc.SetOutput("MaxSeqlen", output_1); - std::string transformer_input_convert_out0_name = "pos_id_tensor"; std::string transformer_input_convert_out1_name = "max_seqlen_tensor"; - VarDesc transformer_input_convert_out0(transformer_input_convert_out0_name); - VarDesc transformer_input_convert_out1(transformer_input_convert_out1_name); - transformer_input_convert_out0.SetDataType(proto::VarType::INT32); - transformer_input_convert_out1.SetDataType(proto::VarType::INT32); - transformer_input_convert_out0.SetShape({-1}); - transformer_input_convert_out1.SetShape({-1}); - transformer_input_convert_out0.SetPersistable(false); - transformer_input_convert_out1.SetPersistable(false); + std::string 
transformer_input_convert_out2_name = "mask_tensor"; + std::vector output_0 = {transformer_input_convert_out0_name}; + std::vector output_1 = {transformer_input_convert_out1_name}; + std::vector output_2 = {transformer_input_convert_out2_name}; + new_desc.SetOutput("PosId", output_0); + new_desc.SetOutput("MaxSeqlen", output_1); + new_desc.SetOutput("MaskTensor", output_2); + + auto *transformer_input_convert_out0 = + lookup_table->Op()->Block()->Var(transformer_input_convert_out0_name); + auto *transformer_input_convert_out1 = + lookup_table->Op()->Block()->Var(transformer_input_convert_out1_name); + auto *transformer_input_convert_out2 = + lookup_table->Op()->Block()->Var(transformer_input_convert_out2_name); + transformer_input_convert_out0->SetDataType(proto::VarType::INT32); + transformer_input_convert_out1->SetDataType(proto::VarType::INT32); + transformer_input_convert_out2->SetDataType(proto::VarType::INT32); + transformer_input_convert_out0->SetShape({-1}); + transformer_input_convert_out1->SetShape({-1}); + + transformer_input_convert_out2->SetShape({-1}); + + transformer_input_convert_out0->SetPersistable(false); + transformer_input_convert_out1->SetPersistable(false); + transformer_input_convert_out2->SetPersistable(false); auto new_op_node = graph->CreateOpNode(&new_desc); auto transformer_input_convert_out0_node = - graph->CreateVarNode(&transformer_input_convert_out0); + graph->CreateVarNode(transformer_input_convert_out0); auto transformer_input_convert_out1_node = - graph->CreateVarNode(&transformer_input_convert_out1); + graph->CreateVarNode(transformer_input_convert_out1); + auto transformer_input_convert_out2_node = + graph->CreateVarNode(transformer_input_convert_out2); // needn't create variable in scope - IR_NODE_LINK_TO(lookup_table2_x, new_op_node); + IR_NODE_LINK_TO(lookup_table_id, new_op_node); IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node); IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node); - - found_subgraph_count++; + IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out2_node); + }; + gpd0(graph, handler0); + + GraphPatternDetector gpd1; + patterns::MultiheadMatmulOP multihead_matmul_pattern( + gpd1.mutable_pattern(), "transformer_input_convert_pass"); + multihead_matmul_pattern(); + auto handler1 = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + VLOG(3) << "link pos_id, max_seqlen to multihead_matmul."; + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, + multihead_matmul_pattern); + + IR_NODE_LINK_TO(transformer_input_convert_out0_node, multihead_matmul); + IR_NODE_LINK_TO(transformer_input_convert_out1_node, multihead_matmul); }; + gpd1(graph, handler1); - gpd(graph, handler); + found_subgraph_count++; AddStatis(found_subgraph_count); } @@ -153,9 +158,3 @@ void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const { REGISTER_PASS(set_transformer_input_convert_pass, paddle::framework::ir::SetTransformerInputConvertPass); -REGISTER_PASS_CAPABILITY(set_transformer_input_convert_pass) - .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination() - .LE("lookup_table", 1) - .LE("lookup_table_v2", 1) - .LE("elementweise_add", 1)); diff --git a/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h index 5a5843e810f..01c9b1c854b 100644 --- a/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h +++ b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h @@ 
-33,41 +33,36 @@ namespace framework { namespace ir { namespace patterns { -// in_var emb in_var emb -// | | | | -// lookup_table lookup_table -// | | -// lkt_var lkt_var -// \ / -// elementwise_add -// | -// elt_out_var +// in_var emb +// | | +// lookup_table +// | +// lkt_var + // struct SetTransformerInputConvert : public PatternBase { SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope) - : PatternBase(pattern, name_scope, "transformer_input_convert") {} + : PatternBase(pattern, name_scope, "transformer_input_convert_pass") {} + void operator()(const std::string &pos_id); + // declare operator node's name + PATTERN_DECL_NODE(lookup_table); + // declare variable node's name + PATTERN_DECL_NODE(lookup_table_id); +}; +struct MultiheadMatmulOP : public PatternBase { + MultiheadMatmulOP(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "transformer_input_convert_pass") {} void operator()(); - // declare operator node's name - PATTERN_DECL_NODE(lookup_table1); - PATTERN_DECL_NODE(lookup_table2); - PATTERN_DECL_NODE(elementwise); - - // declare variable node's name - PATTERN_DECL_NODE(lookup_table1_x); - PATTERN_DECL_NODE(lookup_table1_w); - PATTERN_DECL_NODE(lookup_table1_out); - PATTERN_DECL_NODE(lookup_table2_x); - PATTERN_DECL_NODE(lookup_table2_w); - PATTERN_DECL_NODE(lookup_table2_out); - PATTERN_DECL_NODE(elementwise_out); + PATTERN_DECL_NODE(multihead_matmul); + PATTERN_DECL_NODE(multihead_matmul_out); }; } // namespace patterns class SetTransformerInputConvertPass : public FusePassBase { public: - SetTransformerInputConvertPass(); + SetTransformerInputConvertPass() {} virtual ~SetTransformerInputConvertPass() {} protected: diff --git a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc new file mode 100644 index 00000000000..8f1fdb0b521 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc @@ -0,0 +1,477 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name, + const std::string& arg, + bool is_persist = false) { + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + PDNode* node = + pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg); + if (is_persist) return node->assert_is_persistable_var(); + return node; +} +static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name, + const std::string& arg) { + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + PDNode* node = pattern->NewNode(name) + ->assert_is_only_output_of_ops(embedding_ops) + ->assert_is_op_input("elementwise_add", arg) + ->AsIntermediate(); + return node; +} +void TrtEmbedding2Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table2_x = + create_emb_vars(pattern, lookup_table2_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + auto* lookup_table2_w = + create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed"); + auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed"); + + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); + auto* lookup_table2 = + pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops); + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "X"); + auto* lookup_table2_out = + create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y"); + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + feed1->LinksTo({lookup_table1_x}); + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + feed2->LinksTo({lookup_table2_x}); + lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) + .LinksTo({lookup_table2_out}); + eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) + .LinksTo({eltwise_add_out}); +} +void TrtEmbedding1Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed"); + + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y"); + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr()) + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_output("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + 
->assert_is_op_output("elementwise_add"); + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + feed1->LinksTo({lookup_table1_x}); + eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in}) + .LinksTo({eltwise_add_out}); +} +void TrtSkipLayerNorm::operator()() { + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add") + ->assert_is_op_input("layer_norm", "X") + ->AsIntermediate(); + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y") + ->AsOutput(); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + eltwise_add->LinksTo({eltwise_add_out}); + layer_norm + ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var}); +} + +} // namespace patterns + +int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion( + Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + bool use_varseqlen = Get<bool>("use_varseqlen"); + std::string pos_id = Get<std::string>("tensorrt_transformer_posid"); + std::string mask_id = Get<std::string>("tensorrt_transformer_maskid"); + std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes; + std::vector<Node*> start_pattern_out_node; + std::vector<std::unordered_set<const Node*>> start_pattern_remove_nodes; + + // Create pattern.
+ patterns::TrtEmbedding2Eltwise1Pattern start_pattern(pattern, + name_scope + "/start"); + start_pattern(); + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, + start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out, + start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(TrtEmbedding2Eltwise1Pattern) in op compat failed."; + return; + } + std::vector<std::pair<Node*, Node*>> ins; + ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w)); + ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w)); + start_pattern_in_nodes.push_back(ins); + start_pattern_out_node.push_back(eltwise_add_out); + + std::unordered_set<const Node*> rm_nodes; + rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out, + lookup_table2_out, eltwise_add, eltwise_add_out}); + start_pattern_remove_nodes.push_back(rm_nodes); + }; + gpd(graph, handler); + + std::vector<std::pair<Node*, Node*>> inner_pattern_ins; + std::vector<Node*> inner_pattern_tmp_in; + std::vector<Node*> inner_pattern_out; + std::vector<std::unordered_set<const Node*>> inner_pattern_remove_nodes; + + GraphPatternDetector gpd2; + auto* pattern2 = gpd2.mutable_pattern(); + patterns::TrtEmbedding1Eltwise1Pattern second_pattern(pattern2, + name_scope + "/second"); + second_pattern(); + auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, + second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(TrtEmbedding1Eltwise1Pattern) in op compat failed."; + return; + } + auto in = std::make_pair(lookup_table1_x, lookup_table1_w); + inner_pattern_ins.push_back(in); + inner_pattern_tmp_in.push_back(eltwise_add_in); + inner_pattern_out.push_back(eltwise_add_out); + + std::unordered_set<const Node*> rm_nodes; + rm_nodes.insert( + {lookup_table1, lookup_table1_out, eltwise_add, eltwise_add_out}); + inner_pattern_remove_nodes.push_back(rm_nodes); + }; + gpd2(graph, handler2); + + std::vector<Node*> end_pattern_elt_out; + std::vector<Node*> end_pattern_scales; + std::vector<Node*> end_pattern_biases; + std::vector<Node*> end_pattern_out; + std::vector<Node*> end_patter_layernorms; + std::vector<std::unordered_set<const Node*>> end_pattern_remove_nodes; + GraphPatternDetector gpd3; + auto* pattern3 = gpd3.mutable_pattern(); + patterns::TrtSkipLayerNorm skip_layernorm_pattern(pattern3, + name_scope + "/third"); + skip_layernorm_pattern(); + auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) {
+ GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, + skip_layernorm_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(TrtSkipLayerNorm) in op compat failed."; + return; + } + end_pattern_elt_out.push_back(eltwise_add_out); + std::unordered_set<const Node*> rm_nodes; + rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance}); + end_pattern_remove_nodes.push_back(rm_nodes); + end_pattern_biases.push_back(layer_norm_bias); + end_pattern_scales.push_back(layer_norm_scale); + end_pattern_out.push_back(layer_norm_out); + end_patter_layernorms.push_back(layer_norm); + }; + gpd3(graph, handler3); + + if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) { + return 0; + } + // only reserve the subgraphs that in connected domains. + int fusion_count = 0; + // fusion_id for (i, k, js) + std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>> + fusion_ids; + for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) { + Node* tmp = start_pattern_out_node[i]; + Node* old_tmp = nullptr; + // get correct inner pattern node order. + std::vector<size_t> js; + while (tmp != old_tmp) { + old_tmp = tmp; + for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) { + if (inner_pattern_tmp_in[j] == tmp) { + tmp = inner_pattern_out[j]; + js.push_back(j); + break; + } + } + } + + for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) { + if (tmp == end_pattern_elt_out[k]) { + fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js))); + break; + } + } + } + + for (size_t num = 0; num < fusion_ids.size(); ++num) { + int i = fusion_ids[num].first; + int k = fusion_ids[num].second.first; + std::vector<size_t> js = fusion_ids[num].second.second; + + std::vector<std::string> ids; + std::vector<std::string> embs; + for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) { + ids.push_back(start_pattern_in_nodes[i][iter].first->Name()); + embs.push_back(start_pattern_in_nodes[i][iter].second->Name()); + } + for (size_t iter = 0; iter < js.size(); ++iter) { + ids.push_back(inner_pattern_ins[js[iter]].first->Name()); + embs.push_back(inner_pattern_ins[js[iter]].second->Name()); + } + + OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block()); + new_op_desc.SetType("fused_embedding_eltwise_layernorm"); + new_op_desc.SetInput("Ids", ids); + new_op_desc.SetInput("Embs", embs); + new_op_desc.SetInput("WordId", {ids[0]}); + if (use_varseqlen && pos_id != "" && mask_id != "") { + new_op_desc.SetInput("PosId", {pos_id}); + new_op_desc.SetInput("MaskId", {mask_id}); + } else { + new_op_desc.SetInput("PosId", {ids[1]}); + } + if (ids.size() > 2) { + new_op_desc.SetInput("SentId", {ids[2]}); + } + + new_op_desc.SetInput("WordEmbedding", {embs[0]}); + new_op_desc.SetInput("PosEmbedding", {embs[1]}); + if (embs.size() > 2) { + new_op_desc.SetInput("SentEmbedding", {embs[2]}); + } + + new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()}); + new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()}); + new_op_desc.SetOutput("Out", 
{end_pattern_out[k]->Name()}); + new_op_desc.SetAttr("epsilon", + end_patter_layernorms[k]->Op()->GetAttr("epsilon")); + + if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) { + new_op_desc.SetAttr("enable_int8", true); + new_op_desc.SetAttr( + "out_threshold", + end_patter_layernorms[k]->Op()->GetAttr("out_threshold")); + } + + auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc); + + for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) { + IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first, + embedding_eltwise_layernorm); + IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second, + embedding_eltwise_layernorm); + } + for (size_t iter = 0; iter < js.size(); ++iter) { + IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first, + embedding_eltwise_layernorm); + IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second, + embedding_eltwise_layernorm); + } + IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm); + IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm); + IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]); + + // Remove unneeded nodes. + std::unordered_set marked_nodes; + marked_nodes.insert(start_pattern_remove_nodes[i].begin(), + start_pattern_remove_nodes[i].end()); + marked_nodes.insert(end_pattern_remove_nodes[k].begin(), + end_pattern_remove_nodes[k].end()); + for (size_t iter = 0; iter < js.size(); ++iter) { + marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(), + inner_pattern_remove_nodes[js[iter]].end()); + } + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + } + + return fusion_count; +} + +TrtEmbeddingEltwiseLayerNormFusePass::TrtEmbeddingEltwiseLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); +} + +void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { + bool with_dynamic_shape = Get("with_dynamic_shape"); + if (!with_dynamic_shape) { + VLOG(3) << "trt_embedding_eltwise_layernorm_fuse_pass need: use_varseqlen, " + "with_dynamic_shape. Stop this pass, " + "please reconfig."; + return; + } + FusePassBase::Init(name_scope_, graph); + int fusion_count = + TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_); + if (fusion_count > 0) { + bool use_varseqlen = Get("use_varseqlen"); + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); + + if ((use_varseqlen && pos_id != "" && mask_id != "") || + (!use_varseqlen && pos_id == "" && mask_id == "")) { + VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass"; + } else { + PADDLE_THROW( + platform::errors::Fatal("Use transformer'varseqlen need config: " + "use_varseqlen, set pos_id, set " + "mask_id. Or not use varseqlen, do not set " + "pos_id, set mask_id. 
Please " + "reconfig")); + } + graph->Set(kEmbEltwiseLayernormPass, new bool(true)); + } + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_embedding_eltwise_layernorm_fuse_pass, + paddle::framework::ir::TrtEmbeddingEltwiseLayerNormFusePass); +REGISTER_PASS_CAPABILITY(trt_embedding_eltwise_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("lookup_table", 1) + .LE("lookup_table_v2", 1) + .LE("elementweise_add", 1)); diff --git a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h new file mode 100644 index 00000000000..2d956a38aac --- /dev/null +++ b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h @@ -0,0 +1,167 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +class Graph; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// detect start pattern. 
+// +// in_var emb in_var emb +// | | | | +// lookup_table lookup_table +// | | +// lkt_var lkt_var +// \ / +// elementwise_add +// | +// elt_out_var +// +struct TrtEmbedding2Eltwise1Pattern : public PatternBase { + TrtEmbedding2Eltwise1Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding2_eltwise1") {} + + void operator()(); + PATTERN_DECL_NODE(feed1); + PATTERN_DECL_NODE(feed2); + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table2_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table2_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table2); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(lookup_table2_out); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; + +// detect repeats inner pattern +// +// elt_out_var in_var emb +// \ | | +// \ lookup_table +// \ | +// \ lkt_var +// \ / +// elementwise_add +// | +// elt_out_var +// +struct TrtEmbedding1Eltwise1Pattern : public PatternBase { + TrtEmbedding1Eltwise1Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding1_eltwise1") {} + void operator()(); + PATTERN_DECL_NODE(feed1); + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(eltwise_add_in); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; + +// detect end pattern +// +// elementwise_add +// | +// elt_out_var +// scale | bias +// \ | / +// layer_norm +// +struct TrtSkipLayerNorm : public PatternBase { + TrtSkipLayerNorm(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "skip_layernorm") {} + void operator()(); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); + // Delete the mean and var nodes in the graph. + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); +}; +} // namespace patterns + +// The TrtEmbeddingEltwiseLayerNormFusePass detect the following pattern: +// +// inputs operator output +// -------------------------------------------------------------------- +// (word, weights_0) lookup_table -> word_emb +// (pos, weights_1) lookup_table -> pos_emb +// (sent, weights_2) lookup_table -> sent_emb +// (word_emb, pos_emb) elementweise_add -> elementwise_out_0 +// (elemtwise_out_0, sent_emb) elementweise_add -> elementwise_out_1 +// (elementwise_out_1, scale, bias) layer_norm -> layer_norm_out +// +// and then convert the corresponding subgraph to: +// +// (word, pos, sent, weights_0, weights_1, weights_2, +// scale, baias) embedding_eltwise_layernorm -> layer_norm_out +// +// +// in_var emb_var in_var emb_var in_var emb_var in_var emb_var +// | | | | | | | | +// lookup_table lookup_table lookup_table ... lookup_table +// | | | | +// lkt_var lkt_var lkt_var lkt_var +// \ / | ... 
| +// elementwise_add | | +// \ / | +// elementwise_add | +// | | +// elt_var / +// \ / +// elementwise_add +// | +// layer_norm + +class TrtEmbeddingEltwiseLayerNormFusePass : public FusePassBase { + public: + TrtEmbeddingEltwiseLayerNormFusePass(); + virtual ~TrtEmbeddingEltwiseLayerNormFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + int BuildFusion(Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const; + const std::string name_scope_{"trt_embedding_eltwise_layernorm_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc new file mode 100644 index 00000000000..798a038f767 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc @@ -0,0 +1,1546 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +static void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) { + if (op->IsOp() && op->Op()) { + new_var->inputs.push_back(op); + for (size_t i = 0; i < op->outputs.size(); ++i) { + if (op->outputs[i] == old_var) { + op->outputs[i] = new_var; + op->Op()->RenameOutput(old_var->Name(), new_var->Name()); + } + } + } +} + +static int BuildFusion(Graph* graph, const std::string& name_scope) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ TrtMultiHeadMatmulPattern multihead_pattern(pattern, name_scope); + + multihead_pattern(); + // Create New OpDesc + auto fuse_creater = [&]( + Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, + Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b, + Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2, + Node* reshape2_qkv_out, Node* scale, Node* scale_out) { + auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale")); + // auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias")); + // bool after_scale = + // BOOST_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); + + // create multihead + OpDesc multihead_op_desc(mul0->Op()->Block()); + + // create tmp tensor + VarDesc k_var_desc(*mul1_out->Var()); + k_var_desc.SetName("K" + mul1_out->Name()); + auto* k_var_node = graph->CreateVarNode(&k_var_desc); + + VarDesc q_var_desc(*mul0_out->Var()); + q_var_desc.SetName("Q" + mul0_out->Name()); + auto* q_var_node = graph->CreateVarNode(&q_var_desc); + + VarDesc v_var_desc(*mul2_out->Var()); + v_var_desc.SetName("V" + mul2_out->Name()); + auto* v_var_node = graph->CreateVarNode(&v_var_desc); + + auto reshape_desc = reshape2->Op(); + int head_number = + BOOST_GET_CONST(std::vector, reshape_desc->GetAttr("shape")).at(2); + + ReplaceOutputVar(mul0, mul0_out, q_var_node); + ReplaceOutputVar(mul1, mul1_out, k_var_node); + ReplaceOutputVar(mul2, mul2_out, v_var_node); + + multihead_op_desc.SetType("multihead_matmul"); + multihead_op_desc.SetInput("Q", {q_var_node->Name()}); + multihead_op_desc.SetInput("K", {k_var_node->Name()}); + multihead_op_desc.SetInput("V", {v_var_node->Name()}); + + multihead_op_desc.SetInput("BiasQ", {eltadd0_b->Name()}); + multihead_op_desc.SetInput("BiasK", {eltadd1_b->Name()}); + multihead_op_desc.SetInput("BiasV", {eltadd2_b->Name()}); + multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); + + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + IR_NODE_LINK_TO(q_var_node, multihead); + IR_NODE_LINK_TO(k_var_node, multihead); + IR_NODE_LINK_TO(v_var_node, multihead); + + IR_NODE_LINK_TO(eltadd0_b, multihead); + IR_NODE_LINK_TO(eltadd1_b, multihead); + IR_NODE_LINK_TO(eltadd2_b, multihead); + IR_NODE_LINK_TO(eltadd_qk_b, multihead); + + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, 
mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, + multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, + multihead_pattern); + + fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, + eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, + reshape2_qkv_out, scale, scale_out); + + std::unordered_set marked_nodes( + {eltadd0, + eltadd1, + eltadd2, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, // dropout_qk, dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0_out, + mul1_out, + mul2_out, + reshape2_qkv, + scale}); + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +PDNode* TrtMultiHeadMatmulPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_op_input("mul"); + + // First path with scale + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_op_input("mul", "Y"); + auto* mul0_out_var = + pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul"); + + decltype(mul0) eltadd0; + decltype(mul0) eltadd0_b_var; + decltype(mul0) eltadd0_out_var; + + mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = + pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2"); + reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_0_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out_var = + pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); + scale_out_var->AsIntermediate()->assert_is_op_input("matmul"); + + auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = + pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul"); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + 
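+  // The fused output must feed another mul, so only attention blocks whose
+  // result is consumed by a following fc are matched by this pattern.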
reshape2_qkv_out_var->assert_is_op_input("mul"); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul"); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_op_input("mul", "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul"); + + decltype(mul1) eltadd1; + decltype(mul1) eltadd1_b_var; + decltype(mul1) eltadd1_out_var; + + mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate()->assert_is_op_input( + "matmul"); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul"); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_op_input("mul", "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul"); + + decltype(mul2) eltadd2; + decltype(mul2) eltadd2_b_var; + decltype(mul2) eltadd2_out_var; + + mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_2_out_var->AsIntermediate()->assert_is_op_input( + "matmul"); // link to matmul qkv + + // Q path + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); + eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); + + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var}); + // K path + mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var}); + eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + // compute q*k + 
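+  // (in this pattern the Q branch passes through scale first, so scale_out
+  // rather than transpose2_0_out links into matmul_qk)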
matmul_qk->LinksFrom({scale_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var}); + eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + + return transpose2_2_out_var; +} + +PDNode* TrtMultiHeadMatmulV3Pattern::operator()() { + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_ops_input(matmul_ops); + + // First path with scale + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(matmul_ops); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto* mul0_out_var = + pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops); + + decltype(mul0) eltadd0; + decltype(mul0) eltadd0_b_var; + decltype(mul0) eltadd0_out_var; + + mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = + pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2"); + reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops, "X"); + + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = + pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); + auto* 
matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + reshape2_qkv_out_var->assert_is_ops_input(matmul_ops); + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops); + + decltype(mul1) eltadd1; + decltype(mul1) eltadd1_b_var; + decltype(mul1) eltadd1_out_var; + + mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops, "Y"); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(matmul_ops); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops); + + decltype(mul2) eltadd2; + decltype(mul2) eltadd2_b_var; + decltype(mul2) eltadd2_out_var; + + mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_2_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops); // link to matmul qkv + 
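+  // Link the detected nodes: the Q/K/V projections feed the QK matmul, bias
+  // add and softmax, then the QKV matmul, transpose and final reshape.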
+ // Q path + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); + eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); + + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + // K path + mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var}); + eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + // compute q*k + matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var}); + eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + + return transpose2_2_out_var; +} +} // namespace patterns + +void TrtMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + + int fusion_count = patterns::BuildFusion(graph, name_scope_); + AddStatis(fusion_count); +} + +TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() { + AddOpCompat(OpCompat("mul")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(2) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + // in bias, shape is (B, S, N*H), + // in biasqk, shape is (B, H, S, S) + .IsTensor() + .End() + .AddInput("Y") + // in bias, shape is (N*H) + // in biasqk, shape is (B, H, S, S) + .IsTensor() + .End() + // in bias, shape is (B, S, N*H) + // in biasqk, shape is (B, H, S, S) + .AddOutput("Out") + .IsTensor() + .End() + // in bias, it equal to 2 + // in biasqk, it equal to -1 or 0 + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + // -->: (B, S, H, N) -> (B, H, S, N) + // <--: (B, H, S, N) -> (B, S, H, N) + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + 
.AddAttr("scale") + .IsType() // copy to new op. so unconstrained. + .End() + .AddAttr("bias") + .IsNumEQ(0.f) + .End() + .AddAttr("bias_after_scale") // bias is 0, so unconstrained. + .IsType() + .End(); + + // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S) + // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N) + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumEQ(1.0f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") // QK(true) QKV(false) + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); +} + +int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. + patterns::TrtMultiHeadMatmulPattern multihead_pattern(pattern, name_scope); + + multihead_pattern(); + // Create New OpDesc + auto fuse_creater = [&]( + Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, + Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, + Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, + Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out, + Node* softmax_qk, Node* eltadd0, Node* eltadd1, Node* eltadd2, + Node* matmul_qk, Node* reshape2_qkv) { + auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale")); + + // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H) + // bias (B * S * 3 * N * H) + bias (3 * N * H) + // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H) + auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable(); + auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable(); + auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable(); + + auto* bq_tensor = + scope->FindVar(eltadd0_b->Name())->GetMutable(); + auto* bk_tensor = + scope->FindVar(eltadd1_b->Name())->GetMutable(); + auto* bv_tensor = + scope->FindVar(eltadd2_b->Name())->GetMutable(); + + auto* wq_data = wq_tensor->mutable_data(platform::CPUPlace()); + auto* wk_data = wk_tensor->mutable_data(platform::CPUPlace()); + auto* wv_data = wv_tensor->mutable_data(platform::CPUPlace()); + auto* bq_data = bq_tensor->mutable_data(platform::CPUPlace()); + auto* bk_data = bk_tensor->mutable_data(platform::CPUPlace()); + auto* bv_data = bv_tensor->mutable_data(platform::CPUPlace()); + + auto combined_w_dims = + phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]}); + + // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. 
+ auto* combined_w_desc = mul0_w->Var(); + combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + combined_w_desc->SetPersistable(true); + + auto* combined_bias_desc = eltadd0_b->Var(); + combined_bias_desc->SetShape({3, bq_tensor->dims()[0]}); + combined_bias_desc->SetPersistable(true); + + framework::LoDTensor tmp_combined_w_tensor; + tmp_combined_w_tensor.Resize(combined_w_dims); + auto* tmp_combined_w_data = + tmp_combined_w_tensor.mutable_data(platform::CPUPlace()); + + std::vector w_vec = {wq_data, wk_data, wv_data}; + int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2]; + // Combine the three fc weights together. + for (int i = 0; i < dims_h; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < dims_w; k++) { + int out_index = i * (3 * dims_w) + j * dims_w + k; + int in_index = i * dims_w + k; + tmp_combined_w_data[out_index] = w_vec[j][in_index]; + } + } + } + + wq_tensor->Resize(combined_w_dims); + auto* new_combined_w_data = + wq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_w_data, tmp_combined_w_data, + sizeof(float) * wq_tensor->numel()); + + scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); + + framework::LoDTensor tmp_combined_bias_tensor; + tmp_combined_bias_tensor.Resize(combined_bias_dims); + auto* tmp_combined_bias_data = + tmp_combined_bias_tensor.mutable_data(platform::CPUPlace()); + + size_t bias_size = bq_tensor->numel(); + memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + bias_size, bk_data, + sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data, + sizeof(float) * bias_size); + + bq_tensor->Resize(combined_bias_dims); + auto* new_combined_bias_data = + bq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_bias_data, tmp_combined_bias_data, + sizeof(float) * bq_tensor->numel()); + + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); + + auto reshape_desc = reshape2->Op(); + int head_number = + BOOST_GET_CONST(std::vector, reshape_desc->GetAttr("shape")).at(2); + + OpDesc multihead_op_desc(mul0->Op()->Block()); + multihead_op_desc.SetType("multihead_matmul"); + + multihead_op_desc.SetInput("Input", {input0->Name()}); + multihead_op_desc.SetInput("W", {mul0_w->Name()}); + multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()}); + multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); + + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* mul0_op_desc = mul0->Op(); + + // all mul op has same input. 
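+    // so a single Input_scale attribute is enough for the fused op.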
+ if (multihead_op_desc.HasAttr("Input_scale")) { + multihead_op_desc.SetAttr("Input_scale", + mul0_op_desc->GetAttr("Input_scale")); + } + auto* add0_op_desc = eltadd0->Op(); + auto* add1_op_desc = eltadd1->Op(); + auto* add2_op_desc = eltadd2->Op(); + if (add0_op_desc->HasAttr("out_threshold")) { + auto out_scale0 = + BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold")); + auto out_scale1 = + BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold")); + auto out_scale2 = + BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold")); + auto out_scale_max = std::max(out_scale0, out_scale1); + out_scale_max = std::max(out_scale_max, out_scale2); + multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max); + } + + auto* softmax_qk_op_desc = softmax_qk->Op(); + auto* matmul_qk_op_desc = matmul_qk->Op(); + if (matmul_qk_op_desc->HasAttr("Input_scale")) { + multihead_op_desc.SetAttr("qkv2context_plugin_int8", true); + if (softmax_qk_op_desc->HasAttr("out_threshold")) { + auto qkv_plugin_scale = BOOST_GET_CONST( + float, softmax_qk_op_desc->GetAttr("out_threshold")); + multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale); + } + } + if (reshape2_qkv->Op()->HasAttr("out_threshold")) { + multihead_op_desc.SetAttr("out_threshold", + reshape2_qkv->Op()->GetAttr("out_threshold")); + } + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(eltadd0_b, multihead); + IR_NODE_LINK_TO(eltadd_qk_b, multihead); + + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) + << "Op compat check in trt_multihead_matmul_fuse_pass_v2 failed."; + return; + } + // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, + multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, + multihead_pattern); + + // If weights or biases in qkv's fc are shared by multiple multihead_matmul + // patterns, we do not support this kind of fusion, this pass will not take + // effect. + bool is_fc_params_shared = + mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 || + mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 || + eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1; + if (is_fc_params_shared) { + return; + } + fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, + mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, + reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk, + eltadd0, eltadd1, eltadd2, matmul_qk, reshape2_qkv); + + std::unordered_set marked_nodes({eltadd0, + eltadd1, + eltadd2, + eltadd1_b, + eltadd2_b, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul1_w, + mul2_w, + reshape2_qkv, + scale}); + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the multiheadMatmul pass, The scope should not be null.")); + + int fusion_count = BuildFusionV2(graph, name_scope_, scope); + if (fusion_count > 0) { + bool use_varseqlen = Get("use_varseqlen"); + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); + + if (use_varseqlen && pos_id != "" && mask_id != "") { + if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) { + VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v2"; + } else { + PADDLE_THROW(platform::errors::Fatal( + "Use transformer'varseqlen need " + "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); + } + } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v2"; + } else { + PADDLE_THROW( + platform::errors::Fatal("Use transformer'varseqlen need config: " + "use_varseqlen, set pos_id, set " + "mask_id. Or not use varseqlen, do not set " + "pos_id, set mask_id. Please " + "reconfig")); + } + graph->Set(kMultiheadMatmulPass, new bool(true)); + } + AddStatis(fusion_count); +} + +TrtMultiHeadMatmulV3FusePass::TrtMultiHeadMatmulV3FusePass() { + AddOpCompat(OpCompat("mul")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(2) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + // in bias, shape is (B, S, N*H), + // in biasqk, shape is (B, H, S, S) + .IsTensor() + .End() + .AddInput("Y") + // in bias, shape is (N*H) + // in biasqk, shape is (B, H, S, S) + .IsTensor() + .End() + // in bias, shape is (B, S, N*H) + // in biasqk, shape is (B, H, S, S) + .AddOutput("Out") + .IsTensor() + .End() + // in bias, it equal to 2 + // in biasqk, it equal to -1 or 0 + .AddAttr("axis") + .IsIntIn({2, -1, 0}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + // -->: (B, S, H, N) -> (B, H, S, N) + // <--: (B, H, S, N) -> (B, S, H, N) + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S) + // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N) + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() // QK(anyvalue, will copy to new op) QKV(1.0) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") // QK(true) QKV(false) + .IsType() + .End(); + + 
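+  // The V3 pattern also matches matmul_v2 for the Q/K/V projections and the
+  // attention matmuls, so register its compat rules as well.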
AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") // QK(true) QKV(false) + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); +} + +int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. + patterns::TrtMultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope); + + multihead_pattern(); + // Create New OpDesc + auto fuse_creater = [&]( + Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, + Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, + Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, + Node* reshape2, Node* reshape2_qkv_out, Node* matmul_qk) { + auto scale_attr = BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha")); + + // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H) + // bias (B * S * 3 * N * H) + bias (3 * N * H) + // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H) + auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable(); + auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable(); + auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable(); + + auto* bq_tensor = + scope->FindVar(eltadd0_b->Name())->GetMutable(); + auto* bk_tensor = + scope->FindVar(eltadd1_b->Name())->GetMutable(); + auto* bv_tensor = + scope->FindVar(eltadd2_b->Name())->GetMutable(); + + auto* wq_data = wq_tensor->mutable_data(platform::CPUPlace()); + auto* wk_data = wk_tensor->mutable_data(platform::CPUPlace()); + auto* wv_data = wv_tensor->mutable_data(platform::CPUPlace()); + auto* bq_data = bq_tensor->mutable_data(platform::CPUPlace()); + auto* bk_data = bk_tensor->mutable_data(platform::CPUPlace()); + auto* bv_data = bv_tensor->mutable_data(platform::CPUPlace()); + + auto combined_w_dims = + phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]}); + + // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. + auto* combined_w_desc = mul0_w->Var(); + combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + combined_w_desc->SetPersistable(true); + + auto* combined_bias_desc = eltadd0_b->Var(); + combined_bias_desc->SetShape({3, bq_tensor->dims()[0]}); + combined_bias_desc->SetPersistable(true); + + framework::LoDTensor tmp_combined_w_tensor; + tmp_combined_w_tensor.Resize(combined_w_dims); + auto* tmp_combined_w_data = + tmp_combined_w_tensor.mutable_data(platform::CPUPlace()); + + std::vector w_vec = {wq_data, wk_data, wv_data}; + int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2]; + // Combine the three fc weights together. 
+ for (int i = 0; i < dims_h; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < dims_w; k++) { + int out_index = i * (3 * dims_w) + j * dims_w + k; + int in_index = i * dims_w + k; + tmp_combined_w_data[out_index] = w_vec[j][in_index]; + } + } + } + + wq_tensor->Resize(combined_w_dims); + auto* new_combined_w_data = + wq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_w_data, tmp_combined_w_data, + sizeof(float) * wq_tensor->numel()); + + scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); + + framework::LoDTensor tmp_combined_bias_tensor; + tmp_combined_bias_tensor.Resize(combined_bias_dims); + auto* tmp_combined_bias_data = + tmp_combined_bias_tensor.mutable_data(platform::CPUPlace()); + + size_t bias_size = bq_tensor->numel(); + memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + bias_size, bk_data, + sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data, + sizeof(float) * bias_size); + + bq_tensor->Resize(combined_bias_dims); + auto* new_combined_bias_data = + bq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_bias_data, tmp_combined_bias_data, + sizeof(float) * bq_tensor->numel()); + + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); + + auto reshape_desc = reshape2->Op(); + int head_number = + BOOST_GET_CONST(std::vector, reshape_desc->GetAttr("shape")).at(2); + + OpDesc multihead_op_desc(mul0->Op()->Block()); + multihead_op_desc.SetType("multihead_matmul"); + + multihead_op_desc.SetInput("Input", {input0->Name()}); + multihead_op_desc.SetInput("W", {mul0_w->Name()}); + multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()}); + multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); + + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(eltadd0_b, multihead); + IR_NODE_LINK_TO(eltadd_qk_b, multihead); + + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, + multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, + multihead_pattern); + + // If weights or biases in qkv's fc are shared by multiple multihead_matmul + // patterns, we do not support this kind of fusion, this pass will not take + // effect. + bool is_fc_params_shared = + mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 || + mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 || + eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1; + if (is_fc_params_shared) { + return; + } + fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, + mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, + reshape2_0, reshape2_qkv_out, matmul_qk); + + std::unordered_set marked_nodes({eltadd0, + eltadd1, + eltadd2, + eltadd1_b, + eltadd2_b, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul1_w, + mul2_w, + reshape2_qkv}); + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the multiheadMatmul pass, The scope should not be null.")); + + int fusion_count = BuildFusionV3(graph, name_scope_, scope); + if (fusion_count > 0) { + bool use_varseqlen = Get("use_varseqlen"); + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); + + if (use_varseqlen && pos_id != "" && mask_id != "") { + if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) { + VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3"; + } else { + PADDLE_THROW(platform::errors::Fatal( + "Use transformer'varseqlen need " + "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); + } + } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3"; + } else { + PADDLE_THROW( + platform::errors::Fatal("Use transformer'varseqlen need config: " + "use_varseqlen, set pos_id, set " + "mask_id. Or not use varseqlen, do not set " + "pos_id, set mask_id. Please " + "reconfig")); + } + graph->Set(kMultiheadMatmulPass, new bool(true)); + } + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_multihead_matmul_fuse_pass, + paddle::framework::ir::TrtMultiHeadMatmulFusePass); + +REGISTER_PASS(trt_multihead_matmul_fuse_pass_v2, + paddle::framework::ir::TrtMultiHeadMatmulV2FusePass); +REGISTER_PASS(trt_multihead_matmul_fuse_pass_v3, + paddle::framework::ir::TrtMultiHeadMatmulV3FusePass); +REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("softmax", 0)); +REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v3) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h new file mode 100644 index 00000000000..467e803b497 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h @@ -0,0 +1,179 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
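+
+// Declarations for the TensorRT-specific multihead_matmul fuse passes and the
+// graph patterns they match.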
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct TrtMultiHeadMatmulPattern : public PatternBase { + TrtMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(mul0_w); + PATTERN_DECL_NODE(mul1_w); + PATTERN_DECL_NODE(mul2_w); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(eltadd2_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_qkv_out); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); +}; + +struct TrtMultiHeadMatmulV3Pattern : public PatternBase { + TrtMultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multihead_matmul_v3") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(mul0_w); + PATTERN_DECL_NODE(mul1_w); + PATTERN_DECL_NODE(mul2_w); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(eltadd2_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + 
PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_qkv_out); + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); +}; + +} // namespace patterns + +class TrtMultiHeadMatmulFusePass : public FusePassBase { + public: + virtual ~TrtMultiHeadMatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"trt_multihead_matmul_fuse"}; +}; + +class TrtMultiHeadMatmulV2FusePass : public FusePassBase { + public: + TrtMultiHeadMatmulV2FusePass(); + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"trt_multihead_matmul_fuse_v2"}; + + private: + int BuildFusionV2(Graph* graph, const std::string& name_scope, + Scope* scope) const; +}; + +class TrtMultiHeadMatmulV3FusePass : public FusePassBase { + public: + TrtMultiHeadMatmulV3FusePass(); + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"trt_multihead_matmul_fuse_v3"}; + + private: + int BuildFusionV3(Graph* graph, const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc new file mode 100644 index 00000000000..53452d4239a --- /dev/null +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -0,0 +1,232 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct TrtSkipLayerNorm : public PatternBase { + TrtSkipLayerNorm(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "skip_layernorm") {} + + PDNode *operator()(PDNode *x, PDNode *y); + + // declare operator node's name + PATTERN_DECL_NODE(elementwise); + PATTERN_DECL_NODE(layer_norm); + // declare variable node's name + PATTERN_DECL_NODE( + elementwise_out); // (elementwise_input_x,elementwise_input_y) -> + // elementwise_out + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); +}; + +PDNode *TrtSkipLayerNorm::operator()(PDNode *x, PDNode *y) { + // Create nodes for elementwise add op. 
+ x->assert_is_op_input("elementwise_add", "X"); + y->assert_is_op_input("elementwise_add", "Y"); + auto *elementwise = + pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); + auto *elementwise_out_var = + pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_only_output_of_op("elementwise_add"); + + // Add links for elementwise_add op. + elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var}); + + // Create nodes for layer_norm op. + elementwise_out_var->AsIntermediate()->assert_is_op_input("layer_norm"); + auto *layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + + auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Y"); + auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto *layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + + // Add links for layer_norm op. + layer_norm + ->LinksFrom( + {elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + return layer_norm_out_var; +} + +} // namespace patterns + +void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("skip_layernorm_fuse", graph); + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode("skip_layernorm_fuse/x") + ->AsInput() + ->assert_is_op_input("elementwise_add", "X") + ->assert_var_not_persistable(); + auto *y = gpd.mutable_pattern() + ->NewNode("skip_layernorm_fuse/y") + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y") + ->assert_var_not_persistable(); + patterns::TrtSkipLayerNorm fused_pattern(gpd.mutable_pattern(), + "skip_layernorm_fuse"); + fused_pattern(x, y); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "skip_layernorm pass in op compat failed."; + return; + } + + VLOG(4) << "handle TrtSkipLayerNorm fuse"; + GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, + fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, + fused_pattern); + + std::unordered_set del_node_set; + + // Create an TrtSkipLayerNorm op node + OpDesc new_desc(elementwise->Op()->Block()); + new_desc.SetType("skip_layernorm"); + + // inputs + 
new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("Y", {subgraph.at(y)->Name()}); + new_desc.SetInput("Scale", {layer_norm_scale->Name()}); + new_desc.SetInput("Bias", {layer_norm_bias->Name()}); + + if (layer_norm->Op()->HasAttr("out_threshold")) { + new_desc.SetAttr("enable_int8", true); + new_desc.SetAttr("out_threshold", + layer_norm->Op()->GetAttr("out_threshold")); + } + + // outputs + new_desc.SetOutput("Out", {layer_norm_out->Name()}); + + // attrs + new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); + new_desc.SetAttr("begin_norm_axis", + layer_norm->Op()->GetAttr("begin_norm_axis")); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. + + del_node_set.insert(elementwise); + del_node_set.insert(layer_norm); + del_node_set.insert(elementwise_out); + del_node_set.insert(layer_norm_mean); + del_node_set.insert(layer_norm_variance); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(subgraph.at(y), fused_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, layer_norm_out); + + found_subgraph_count++; + }; + + gpd(graph, handler); + if (found_subgraph_count > 0) { + bool use_varseqlen = Get("use_varseqlen"); + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); + + if (use_varseqlen && pos_id != "" && mask_id != "") { + if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass)) { + VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass"; + } else { + PADDLE_THROW(platform::errors::Fatal( + "Use transformer'varseqlen need " + "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); + } + } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass"; + } else { + PADDLE_THROW( + platform::errors::Fatal("Use transformer'varseqlen need config: " + "use_varseqlen, set pos_id, set " + "mask_id. Or not use varseqlen, do not set " + "pos_id, set mask_id. Please " + "reconfig")); + } + } + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_skip_layernorm_fuse_pass, + paddle::framework::ir::TrtSkipLayerNormFusePass); +REGISTER_PASS_CAPABILITY(trt_skip_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("layer_norm", 0)); diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h new file mode 100644 index 00000000000..a299493efa0 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +// | | | | +// other_op1 other_op2 other_op1 other_op2 +// | | fuse \ / +// |------elementwise_add -> skip_layernorm +// | | +// layer_norm other_op3 +// | | +// other_op3 +// | +class Graph; + +class TrtSkipLayerNormFusePass : public FusePassBase { + public: + TrtSkipLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + } + + virtual ~TrtSkipLayerNormFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 2336fd1980d..07b7b374859 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -216,8 +216,12 @@ struct Argument { DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine, bool); DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool); - DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool); + DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool); DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool); + DECL_ARGUMENT_FIELD(tensorrt_transformer_posid, TensorRtTransformerPosid, + std::string); + DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid, TensorRtTransformerMaskid, + std::string); DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path, TensorRtShapeRangeInfoPath, std::string); DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape, TensorRtTunedDynamicShape, diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index aafbe57e05f..c5c60564b0f 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -55,9 +55,13 @@ void IRPassManager::CreatePasses(Argument *argument, int pass_num = 0; for (const std::string &pass_name : passes) { auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); - pass->Set("use_oss", new bool(argument->tensorrt_use_oss())); + pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen())); pass->Set("with_interleaved", new bool(argument->tensorrt_with_interleaved())); + pass->Set("tensorrt_transformer_posid", + new std::string(argument->tensorrt_transformer_posid())); + pass->Set("tensorrt_transformer_maskid", + new std::string(argument->tensorrt_transformer_maskid())); pass->Set("disable_logs", new bool(argument->disable_logs())); auto precision_mode = argument->tensorrt_precision_mode(); bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8; diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index b73eb624db8..394ce7799e8 100644 --- 
a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -377,12 +377,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp( Get("workspace_size"), precision_mode, calibrator.get(), Get("gpu_device_id"), min_input_shape, max_input_shape, opt_input_shape, disable_trt_plugin_fp16); - trt_engine->SetUseOSS(Get("use_oss")); + trt_engine->SetUseOSS(Get("use_varseqlen")); trt_engine->SetWithInterleaved(Get("with_interleaved")); + trt_engine->SetTransformerPosid( + Get("tensorrt_transformer_posid")); + trt_engine->SetTransformerMaskid( + Get("tensorrt_transformer_maskid")); trt_engine->SetUseDLA(Get("trt_use_dla")); trt_engine->SetDLACore(Get("trt_dla_core")); trt_engine->SetUseInspector(Get("use_inspector")); - trt_engine->SetWithErnie(graph->Has(framework::ir::kMultiheadMatmulPass)); + trt_engine->SetWithErnie( + graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass)); if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 735e1b7be4c..5bb26d8f080 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -256,8 +256,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(trt_dla_core_); CP_MEMBER(trt_use_static_engine_); CP_MEMBER(trt_use_calib_mode_); - CP_MEMBER(trt_use_oss_); + CP_MEMBER(trt_use_varseqlen_); CP_MEMBER(trt_with_interleaved_); + CP_MEMBER(tensorrt_transformer_posid_); + CP_MEMBER(tensorrt_transformer_maskid_); CP_MEMBER(trt_tuned_dynamic_shape_); CP_MEMBER(trt_allow_build_at_runtime_); CP_MEMBER(collect_shape_range_info_); @@ -546,7 +548,7 @@ void AnalysisConfig::Exp_DisableTensorRtOPs( trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end()); } -void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; } +void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; } // TODO(Superjomn) refactor this, buggy. void AnalysisConfig::Update() { @@ -1034,9 +1036,13 @@ std::string AnalysisConfig::Summary() { ? shape_range_info_path_ : "false"}); - os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"}); + os.InsertRow( + {"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"}); os.InsertRow({"tensorrt_with_interleaved", trt_with_interleaved_ ? "true" : "false"}); + os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_}); + os.InsertRow( + {"tensorrt_transformer_maskid", tensorrt_transformer_maskid_}); os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? 
"true" : "false"}); if (trt_use_dla_) { os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)}); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 09a5bbddba8..b40377855bd 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -853,8 +853,10 @@ void AnalysisPredictor::PrepareArgument() { } argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); - argument_.SetTensorRtUseOSS(config_.trt_use_oss_); + argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_); argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_); + argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_); + argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_); argument_.SetMinInputShape(config_.min_input_shape_); argument_.SetMaxInputShape(config_.max_input_shape_); argument_.SetOptimInputShape(config_.optim_input_shape_); @@ -1803,6 +1805,9 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) USE_TRT_CONVERTER(preln_skip_layernorm) USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(strided_slice) +USE_TRT_CONVERTER(transformer_input_convert) +USE_TRT_CONVERTER(recover_padding) +USE_TRT_CONVERTER(remove_padding) #endif namespace paddle_infer { @@ -1971,6 +1976,20 @@ void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c, #endif } +void InternalUtils::SetTransformerPosid( + paddle_infer::Config *c, const std::string &tensorrt_transformer_posid) { +#ifdef PADDLE_WITH_CUDA + c->tensorrt_transformer_posid_ = tensorrt_transformer_posid; +#endif +} + +void InternalUtils::SetTransformerMaskid( + paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid) { +#ifdef PADDLE_WITH_CUDA + c->tensorrt_transformer_maskid_ = tensorrt_transformer_maskid; +#endif +} + void InternalUtils::SyncStream(paddle_infer::Predictor *p) { #ifdef PADDLE_WITH_CUDA auto *pred = dynamic_cast(p->predictor_.get()); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index af6cf88a322..ab2265bff24 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -618,14 +618,14 @@ struct PD_INFER_DECL AnalysisConfig { /// may be more high-performance. Libnvinfer_plugin.so greater than /// V7.2.1 is needed. /// - void EnableTensorRtOSS(); + void EnableVarseqlen(); /// /// \brief A boolean state telling whether to use the TensorRT OSS. /// /// \return bool Whether to use the TensorRT OSS. 
/// - bool tensorrt_oss_enabled() { return trt_use_oss_; } + bool tensorrt_varseqlen_enabled() { return trt_use_varseqlen_; } /// /// \brief Enable TensorRT DLA @@ -954,8 +954,10 @@ struct PD_INFER_DECL AnalysisConfig { Precision tensorrt_precision_mode_{Precision::kFloat32}; bool trt_use_static_engine_{false}; bool trt_use_calib_mode_{true}; - bool trt_use_oss_{false}; + bool trt_use_varseqlen_{false}; bool trt_with_interleaved_{false}; + std::string tensorrt_transformer_posid_{""}; + std::string tensorrt_transformer_maskid_{""}; bool trt_use_dla_{false}; int trt_dla_core_{0}; std::map> min_input_shape_{}; diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index dc9f7debe5f..711998e9956 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -435,6 +435,12 @@ class PD_INFER_DECL InternalUtils { static void UpdateConfigInterleaved(paddle_infer::Config* c, bool with_interleaved); + static void SetTransformerPosid( + paddle_infer::Config* c, const std::string& tensorrt_transformer_posid); + + static void SetTransformerMaskid( + paddle_infer::Config* c, const std::string& tensorrt_transformer_maskid); + static void SyncStream(paddle_infer::Predictor* pred); static void SyncStream(cudaStream_t stream); template diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index f9ec41f6c83..3b1c84db4c5 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -94,25 +94,25 @@ const std::vector kTRTSubgraphPasses({ "add_support_int8_pass", // // "fc_fuse_pass", // "simplify_with_basic_ops_pass", // - "embedding_eltwise_layernorm_fuse_pass", // + "trt_embedding_eltwise_layernorm_fuse_pass", // "preln_embedding_eltwise_layernorm_fuse_pass", // - "multihead_matmul_fuse_pass_v2", // - "multihead_matmul_fuse_pass_v3", // - "skip_layernorm_fuse_pass", // + "trt_multihead_matmul_fuse_pass_v2", // + "trt_multihead_matmul_fuse_pass_v3", // + "trt_skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", // // "set_transformer_input_convert_pass", // - "conv_bn_fuse_pass", // - "unsqueeze2_eltwise_fuse_pass", // - "trt_squeeze2_matmul_fuse_pass", // - "trt_reshape2_matmul_fuse_pass", // - "trt_flatten2_matmul_fuse_pass", // - "trt_map_matmul_v2_to_mul_pass", // - "trt_map_matmul_v2_to_matmul_pass", // - "trt_map_matmul_to_mul_pass", // - "fc_fuse_pass", // - "conv_elementwise_add_fuse_pass", // - // "remove_padding_recover_padding_pass", // - // "delete_remove_padding_recover_padding_pass", // + "conv_bn_fuse_pass", // + "unsqueeze2_eltwise_fuse_pass", // + "trt_squeeze2_matmul_fuse_pass", // + "trt_reshape2_matmul_fuse_pass", // + "trt_flatten2_matmul_fuse_pass", // + "trt_map_matmul_v2_to_mul_pass", // + "trt_map_matmul_v2_to_matmul_pass", // + "trt_map_matmul_to_mul_pass", // + "fc_fuse_pass", // + "conv_elementwise_add_fuse_pass", // + "remove_padding_recover_padding_pass", // + "delete_remove_padding_recover_padding_pass", // // "yolo_box_fuse_pass", // "tensorrt_subgraph_pass", // "conv_bn_fuse_pass", // diff --git a/paddle/fluid/inference/capi_exp/pd_config.cc b/paddle/fluid/inference/capi_exp/pd_config.cc index d7b07652bab..d290f44d2ee 100644 --- a/paddle/fluid/inference/capi_exp/pd_config.cc +++ b/paddle/fluid/inference/capi_exp/pd_config.cc @@ -303,13 +303,13 @@ void PD_ConfigDisableTensorRtOPs(__pd_keep PD_Config* pd_config, size_t ops_num, config->Exp_DisableTensorRtOPs(ops_list); } 
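(For readers wiring the renamed varseqlen switches up from the C++ API, a minimal sketch follows. It is not part of this patch: InternalUtils is declared in paddle_api.h above and is assumed here to live under paddle_infer::experimental, and the model paths, shapes and tensor names "word_id", "pos_id", "mask_id" are placeholders that must match the actual graph inputs.)

// Sketch only: enabling the varseqlen (padding-removal) path with the renamed API.
#include <map>
#include <string>
#include <vector>
#include "paddle_inference_api.h"

void ConfigureVarseqlen(paddle_infer::Config* config) {
  config->EnableUseGpu(512, 0);
  config->EnableTensorRtEngine(1 << 30, 1, 5,
                               paddle_infer::Config::Precision::kHalf,
                               false, false);
  // Dynamic shape is required by the remove/recover padding plugins.
  // Only one input is listed here; a real config covers every input.
  std::map<std::string, std::vector<int>> min_shape{{"word_id", {1, 1, 1}}};
  std::map<std::string, std::vector<int>> max_shape{{"word_id", {8, 128, 1}}};
  std::map<std::string, std::vector<int>> opt_shape{{"word_id", {4, 64, 1}}};
  config->SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
  config->EnableVarseqlen();  // formerly EnableTensorRtOSS()
  // Tensor names below are placeholders; InternalUtils is assumed to be
  // exposed under paddle_infer::experimental as in paddle_api.h.
  paddle_infer::experimental::InternalUtils::SetTransformerPosid(config, "pos_id");
  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(config, "mask_id");
}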
-void PD_ConfigEnableTensorRtOSS(__pd_keep PD_Config* pd_config) { +void PD_ConfigEnableVarseqlen(__pd_keep PD_Config* pd_config) { CHECK_AND_CONVERT_PD_CONFIG; - config->EnableTensorRtOSS(); + config->EnableVarseqlen(); } PD_Bool PD_ConfigTensorRtOssEnabled(__pd_keep PD_Config* pd_config) { CHECK_AND_CONVERT_PD_CONFIG; - return config->tensorrt_oss_enabled(); + return config->tensorrt_varseqlen_enabled(); } void PD_ConfigEnableTensorRtDla(__pd_keep PD_Config* pd_config, diff --git a/paddle/fluid/inference/capi_exp/pd_config.h b/paddle/fluid/inference/capi_exp/pd_config.h index f6b754cad21..667843520d6 100644 --- a/paddle/fluid/inference/capi_exp/pd_config.h +++ b/paddle/fluid/inference/capi_exp/pd_config.h @@ -432,7 +432,7 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigDisableTensorRtOPs( /// /// \param[in] pd_onfig config /// -PADDLE_CAPI_EXPORT extern void PD_ConfigEnableTensorRtOSS( +PADDLE_CAPI_EXPORT extern void PD_ConfigEnableVarseqlen( __pd_keep PD_Config* pd_config); /// /// \brief A boolean state telling whether to use the TensorRT OSS. diff --git a/paddle/fluid/inference/goapi/config.go b/paddle/fluid/inference/goapi/config.go index 8f9f34c06b4..0aca2a1075f 100644 --- a/paddle/fluid/inference/goapi/config.go +++ b/paddle/fluid/inference/goapi/config.go @@ -500,8 +500,8 @@ func (config *Config) DisableTensorRtOPs(ops []string) { /// may be more high-performance. Libnvinfer_plugin.so greater than /// V7.2.1 is needed. /// -func (config *Config) EnableTensorRtOSS() { - C.PD_ConfigEnableTensorRtOSS(config.c) +func (config *Config) EnableVarseqlen() { + C.PD_ConfigEnableVarseqlen(config.c) } /// diff --git a/paddle/fluid/inference/goapi/config_test.go b/paddle/fluid/inference/goapi/config_test.go index 297841dcbcf..080f2fd0135 100644 --- a/paddle/fluid/inference/goapi/config_test.go +++ b/paddle/fluid/inference/goapi/config_test.go @@ -54,7 +54,7 @@ func TestNewConfig(t *testing.T) { } config.SetTRTDynamicShapeInfo(minInputShape, maxInputShape, optInputShape, false) - config.EnableTensorRtOSS() + config.EnableVarseqlen() t.Logf("TensorrtOssEnabled:%+v", config.TensorrtOssEnabled()) config.EnableTensorRtDLA(0) @@ -138,4 +138,4 @@ func TestONNXRuntime(t *testing.T) { config.SetCpuMathLibraryNumThreads(4) t.Logf("CpuMathLibraryNumThreads:%+v", config.CpuMathLibraryNumThreads()) -} \ No newline at end of file +} diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 1910e2f6eb9..05ab3fb53e5 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -56,7 +56,11 @@ nv_library(tensorrt_converter strided_slice_op.cc preln_skip_layernorm.cc roll_op.cc + transformer_input_convert_op.cc + remove_padding_op.cc + recover_padding_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter) + diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index 7a494860e6f..ffb32bab522 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -30,23 +30,28 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) 
override { -#if IS_TRT_VERSION_GE(6000) VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer"; framework::OpDesc op_desc(op, nullptr); auto word_id_name = op_desc.Input("WordId").front(); - auto pos_id_name = op_desc.Input("PosId").front(); + auto pos_id_name = engine_->tensorrt_transformer_posid(); engine_->Set("ernie_pos_name", new std::string(pos_id_name)); auto sent_id_name = op_desc.Input("SentId").front(); + auto mask_id_name = engine_->tensorrt_transformer_maskid(); auto word_emb_name = op_desc.Input("WordEmbedding").front(); auto pos_emb_name = op_desc.Input("PosEmbedding").front(); auto sent_emb_name = op_desc.Input("SentEmbedding").front(); std::vector id_names; std::vector emb_names; + bool flag_varseqlen = + engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != ""; - if (engine_->use_oss()) { + if (flag_varseqlen) { + engine_->SetITensor("word_id", engine_->GetITensor(word_id_name)); + engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); + engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); id_names = std::vector{word_id_name, pos_id_name, sent_id_name}; emb_names = @@ -106,7 +111,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; bool enable_int8 = op_desc.HasAttr("enable_int8"); - if (engine_->use_oss()) { + if (flag_varseqlen) { int output_fp16 = static_cast((engine_->WithFp16() == 1) ? 1 : 0); if (enable_int8) { output_fp16 = 1; @@ -121,7 +126,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { output_fp16, 1, platform::errors::InvalidArgument( "Only Precision::KHalf(fp16) is supported when infering " - "ernie(bert) model with config.EnableTensorRtOSS(). " + "ernie(bert) model with config.EnableVarseqlen(). " "But Precision::KFloat32 is setted.")); const std::vector fields{ {"bert_embeddings_layernorm_beta", bias, @@ -159,8 +164,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { plugin_inputs.emplace_back( engine_->GetITensor(pos_id_name)); // cu_seqlens, // eval_placeholder_2 - auto max_seqlen_tensor = - engine_->GetITensor(engine_->network()->getInput(3)->getName()); + auto max_seqlen_tensor = engine_->GetITensor(mask_id_name); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor); nvinfer1::Dims shape_dim; @@ -193,8 +197,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale); } if (engine_->with_interleaved()) { - VLOG(4) - << "fused emb_eltwise_layernorm op: use_oss and with_interleaved"; + VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and " + "with_interleaved"; if (!enable_int8) { PADDLE_THROW( platform::errors::Fatal("use with_interleaved must be int8.")); @@ -229,12 +233,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name}, test_mode); } - -#else - PADDLE_THROW(platform::errors::Fatal( - "You are running the TRT Dynamic Shape mode, need to confirm that " - "your TRT version is no less than 6.0")); -#endif } }; diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index a631332dae3..bf3170dacc7 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -250,8 +250,7 @@ class FcOpConverter : public OpConverter { } // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can // not add Shuffle layer in ernie's 
multihead. - if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && - x_dim.d[3] == 1 && x_num_col_dims == 2) { + if (x_dim.nbDims == 4 && x_num_col_dims == 1) { if (enable_int8 || support_int8) { // add conv1x1 layer nvinfer1::DimsHW nv_ksize(1, 1); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index 4b4ad01f567..f06554e7ebb 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -76,12 +76,14 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; auto output_name = op_desc.Output("Out")[0]; - + bool flag_varseqlen = engine_->use_varseqlen() && + engine_->tensorrt_transformer_posid() != "" && + engine_->tensorrt_transformer_maskid() != ""; if (engine_->with_dynamic_shape()) { - if (engine_->use_oss()) { + if (flag_varseqlen) { if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { PADDLE_THROW(platform::errors::Fatal( - "use use_oss must be int8 or half, not float32.")); + "use use_varseqlen must be int8 or half, not float32.")); } nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), @@ -90,7 +92,8 @@ class MultiheadMatMulOpConverter : public OpConverter { static_cast(bias_data), static_cast(bias_t->numel())}; if (engine_->with_interleaved()) { - VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; + VLOG(4) << "fused multihead_matmul op: use_varseqlen and " + "with_interleaved"; if (!op_desc.HasAttr("Input_scale")) { PADDLE_THROW( platform::errors::Fatal("use with_interleaved must be int8.")); @@ -233,9 +236,6 @@ class MultiheadMatMulOpConverter : public OpConverter { BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; } } - - auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); - auto creator = GetPluginRegistry()->getPluginCreator( "CustomQKVToContextPluginDynamic", "2"); assert(creator != nullptr); @@ -272,18 +272,10 @@ class MultiheadMatMulOpConverter : public OpConverter { std::vector plugin_inputs; plugin_inputs.emplace_back(fc_layer->getOutput(0)); - plugin_inputs.emplace_back(mask_tensor); - if (engine_->Has("ernie_pos_name")) { - plugin_inputs.emplace_back(engine_->GetITensor( - engine_->Get("ernie_pos_name"))); - } else { - plugin_inputs.emplace_back(engine_->GetITensor( - engine_->network() - ->getInput(2) - ->getName())); // cu_seqlens, eval_placeholder_2 - } - auto max_seqlen_tensor = - engine_->GetITensor(engine_->network()->getInput(3)->getName()); + plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask")); + plugin_inputs.emplace_back(engine_->GetITensor("pos_id")); + + auto max_seqlen_tensor = engine_->GetITensor("mask_id"); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( engine_, Shuffle, *const_cast(max_seqlen_tensor)); diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 87fdbb71a3f..4ee8db7c69d 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -32,7 +32,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { #if IS_TRT_VERSION_GE(7000) VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer"; - if (!(engine_->use_oss() && engine_->with_interleaved())) { + if (!(engine_->use_varseqlen() && 
engine_->with_interleaved())) { PADDLE_THROW(platform::errors::Fatal( "PrelnErnie: If you want to use oss, must be with interleaved")); } diff --git a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc index 8053135cc45..1e9aec29e34 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc @@ -24,7 +24,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { const framework::Scope& scope, bool test_mode) override { #if IS_TRT_VERSION_GE(7000) VLOG(4) << "convert fused preln_skip_layernorm op to tensorrt layer"; - if (!(engine_->use_oss() && engine_->with_interleaved())) { + if (!(engine_->use_varseqlen() && engine_->with_interleaved())) { PADDLE_THROW(platform::errors::Fatal( "PrelnErnie: If you want to use oss, must be with interleaved")); } @@ -60,7 +60,8 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; - VLOG(4) << "fused preln_skip_layernorm op: use_oss and with_interleaved"; + VLOG(4) + << "fused preln_skip_layernorm op: use_varseqlen and with_interleaved"; auto creator = GetPluginRegistry()->getPluginCreator( "CustomSkipLayerNormPluginDynamic", "4"); diff --git a/paddle/fluid/inference/tensorrt/convert/recover_padding_op.cc b/paddle/fluid/inference/tensorrt/convert/recover_padding_op.cc new file mode 100644 index 00000000000..8f996e1d0f8 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/recover_padding_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Recover padding of the transformer's output.
+ */ +class RecoverPadding : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "Recover padding of the transformer's output: VarSeqlen -> Padding."; + if (!engine_->with_dynamic_shape()) { + PADDLE_THROW(platform::errors::Fatal( + "recover_padding_op: If you want to use transformer, must " + "be with dynamic shape")); + } + + framework::OpDesc op_desc(op, nullptr); + /* + auto x_var_name = op_desc.Input(InputNames()).front(); + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + */ + auto input_name = op_desc.Input("Input").front(); + + VLOG(3) << "recover_padding input: " << input_name; + + std::vector<nvinfer1::ITensor*> plugin_inputs; + plugin_inputs.push_back(engine_->GetITensor(input_name)); + plugin_inputs.push_back(engine_->GetITensor("pos_id")); + plugin_inputs.push_back(engine_->GetITensor("mask_id")); + int input_num = 3; + auto output_name = op_desc.Output("Out").front(); + + plugin::RecoverPaddingPlugin* plugin = new plugin::RecoverPaddingPlugin(); + nvinfer1::ILayer* layer = + engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin); + + RreplenishLayerAndOutput(layer, "recover_padding", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(recover_padding, RecoverPadding); diff --git a/paddle/fluid/inference/tensorrt/convert/remove_padding_op.cc b/paddle/fluid/inference/tensorrt/convert/remove_padding_op.cc new file mode 100644 index 00000000000..49d5edbbd4e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/remove_padding_op.cc @@ -0,0 +1,69 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Remove padding of the transformer's input.
+ */ +class RemovePadding : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "Remove padding of the transformer's input: Padding -> VarSeqlen"; + if (!engine_->with_dynamic_shape()) { + PADDLE_THROW(platform::errors::Fatal( + "remove_padding_op: If you want to use transformer, must " + "be with dynamic shape")); + } + + framework::OpDesc op_desc(op, nullptr); + auto input_name = op_desc.Input("Input").front(); + + std::vector<nvinfer1::ITensor*> plugin_inputs; + plugin_inputs.push_back(engine_->GetITensor(input_name)); + plugin_inputs.push_back(engine_->GetITensor("pos_id")); + plugin_inputs.push_back(engine_->GetITensor("word_id")); + size_t input_num = plugin_inputs.size(); + auto output_name = op_desc.Output("Out").front(); + + plugin::RemovePaddingPlugin* plugin = new plugin::RemovePaddingPlugin(); + nvinfer1::ILayer* layer = + engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin); + + RreplenishLayerAndOutput(layer, "remove_padding_op", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(remove_padding, RemovePadding); diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 831e1173117..6f65e271923 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -52,10 +52,13 @@ class SkipLayerNormOpConverter : public OpConverter { bool enable_int8 = op_desc.HasAttr("enable_int8"); nvinfer1::ILayer* layer = nullptr; - - if (engine_->use_oss()) { + bool flag_varseqlen = engine_->use_varseqlen() && + engine_->tensorrt_transformer_posid() != "" && + engine_->tensorrt_transformer_maskid() != ""; + if (flag_varseqlen) { if (engine_->with_interleaved()) { - VLOG(4) << "fused skip_layernorm op: use_oss and with_interleaved"; + VLOG(4) + << "fused skip_layernorm op: use_varseqlen and with_interleaved"; if (!enable_int8) { PADDLE_THROW( platform::errors::Fatal("use with_interleaved must be int8.")); diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index dea9a1ec3d7..fa6f4889403 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -14,7 +14,6 @@ limitations under the License.
*/ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" -#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h" namespace paddle { namespace inference { @@ -74,47 +73,12 @@ class SliceOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { - if (engine_->use_oss() && engine_->with_ernie() && - input_dims.nbDims == 4) { - std::vector plugin_inputs; - if (engine_->with_interleaved()) { - auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - nvinfer1::Permutation transpose_embed{2, 1, 0, 3}; - shuffler_slice->setSecondTranspose(transpose_embed); - engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0), - out_scale); - shuffler_slice->setName( - ("SpecialSlice_interleaved: transpose: (Output: " + output_name + - ")") - .c_str()); - plugin_inputs.emplace_back(shuffler_slice->getOutput(0)); - } else { - plugin_inputs.emplace_back(input); - } - std::string pos_name; - if (engine_->Has("ernie_pos_name")) { - pos_name = engine_->Get("ernie_pos_name"); - } else { - // hard code for compatibility - pos_name = engine_->network()->getInput(2)->getName(); - } - plugin_inputs.emplace_back( - engine_->GetITensor(pos_name)); // cu_seqlens, eval_placeholder_2 - - // bool ban_fp16 = engine_->disable_trt_plugin_fp16(); - plugin::SpecialSlicePluginDynamic* plugin = - new plugin::SpecialSlicePluginDynamic(); - layer = engine_->AddDynamicPlugin(plugin_inputs.data(), - plugin_inputs.size(), plugin); - } else { - bool with_fp16 = - engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - int decrease_axis = - decrease_axises.size() == 0 ? -1 : decrease_axises[0]; - plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( - starts, ends, axes, decrease_axis, with_fp16); - layer = engine_->AddDynamicPlugin(&input, 1, plugin); - } + bool with_fp16 = + engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0]; + plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( + starts, ends, axes, decrease_axis, with_fp16); + layer = engine_->AddDynamicPlugin(&input, 1, plugin); } else { bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); diff --git a/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc b/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc new file mode 100644 index 00000000000..045a5d163ca --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Convert Transformer Input(pos_id, max_seqlen). + */ +class TransformerInputConvert : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "Convert Transformer Input(pos_id, max_seqlen), use " + "transformer_input_convert_plugin"; + if (!engine_->with_dynamic_shape()) { + PADDLE_THROW(platform::errors::Fatal( + "transformer_input_convert_op: If you want to use transformer, must " + "be with dynamic shape")); + } + + framework::OpDesc op_desc(op, nullptr); + auto input_name = op_desc.Input("Input").front(); + auto* input = engine_->GetITensor(input_name); + int input_num = op_desc.Input("Input").size(); + + // tensorrt_subgraph_pass will rename tensor + // auto pos_id_name = op_desc.Output("PosId").front(); + // auto max_seqlen_name = op_desc.Output("MaxSeqlen").front(); + auto pos_id_name = "pos_id_tensor"; + auto max_seqlen_name = "max_seqlen_tensor"; + + plugin::TransformerInputConvertPlugin* plugin = + new plugin::TransformerInputConvertPlugin(); + nvinfer1::ILayer* layer = + engine_->AddDynamicPlugin(&input, input_num, plugin); + + RreplenishLayerAndOutput(layer, "transformer_input_convert", + {pos_id_name, max_seqlen_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(transformer_input_convert, TransformerInputConvert); diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index f781cd0cb3a..598d751ad5f 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -410,14 +410,19 @@ class TensorRTEngine { suffix_counter += 1; } - void SetUseOSS(bool use_oss) { use_oss_ = use_oss; } + void SetUseOSS(bool use_varseqlen) { use_varseqlen_ = use_varseqlen; } void SetUseDLA(bool use_dla) { use_dla_ = use_dla; } void SetDLACore(int dla_core) { dla_core_ = dla_core; } void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; } void SetWithInterleaved(bool with_interleaved) { with_interleaved_ = with_interleaved; } - + void SetTransformerPosid(std::string tensorrt_transformer_posid) { + tensorrt_transformer_posid_ = tensorrt_transformer_posid; + } + void SetTransformerMaskid(std::string tensorrt_transformer_maskid) { + tensorrt_transformer_maskid_ = tensorrt_transformer_maskid; + } void ClearWeights() { for (auto& weight_pair : weight_map) { weight_pair.second.reset(nullptr); @@ -488,9 +493,15 @@ class TensorRTEngine { return ret; } - bool use_oss() { return use_oss_; } + bool use_varseqlen() { return use_varseqlen_; } bool with_ernie() { return with_ernie_; } bool with_interleaved() { return with_interleaved_; } + std::string tensorrt_transformer_posid() { + return tensorrt_transformer_posid_; + } + std::string tensorrt_transformer_maskid() { + return tensorrt_transformer_maskid_; + } bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; } bool with_dynamic_shape() { return with_dynamic_shape_; } AnalysisConfig::Precision precision() { return precision_; } @@ -612,11 +623,13 @@ class TensorRTEngine { ShapeMapType 
max_input_shape_; ShapeMapType optim_input_shape_; bool disable_trt_plugin_fp16_{false}; - bool use_oss_{false}; + bool use_varseqlen_{false}; bool use_dla_{false}; int dla_core_{0}; bool with_ernie_{false}; bool with_interleaved_{false}; + std::string tensorrt_transformer_posid_; + std::string tensorrt_transformer_maskid_; nvinfer1::ILogger& logger_; // max data size for the buffers. diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 690bc173c77..79a5e7d7a6a 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -125,7 +125,10 @@ struct SimpleOpTypeSetTeller : public Teller { "strided_slice", "fused_preln_embedding_eltwise_layernorm", "roll", - "preln_skip_layernorm"}; + "preln_skip_layernorm", + "transformer_input_convert", + "recover_padding", + "remove_padding"}; std::unordered_set teller_set{ "mul", "matmul", @@ -194,7 +197,10 @@ struct SimpleOpTypeSetTeller : public Teller { "fused_preln_embedding_eltwise_layernorm", "preln_skip_layernorm", "roll", - "multiclass_nms3"}; + "multiclass_nms3", + "transformer_input_convert", + "recover_padding", + "remove_padding"}; }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index ff6a1cd60f7..ee1d6c1dc7d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -4,7 +4,7 @@ nv_library(tensorrt_plugin pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu - hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu + hard_swish_op_plugin.cu stack_op_plugin.cu anchor_generator_op_plugin.cu yolo_box_op_plugin.cu yolo_box_head_op_plugin.cu @@ -14,6 +14,9 @@ nv_library(tensorrt_plugin pool3d_op_plugin.cu deformable_conv_op_plugin.cu matmul_op_int8_plugin.cu + transformer_input_convert_plugin.cu + remove_padding_plugin.cu + recover_padding_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu new file mode 100644 index 00000000000..515e01f4053 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu @@ -0,0 +1,120 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +__global__ void RecoverPaddingKernel(const float* input0, const int32_t* input1, + float* output) { + int word_id = blockIdx.x * gridDim.y + blockIdx.y; + int32_t sequence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; + if (blockIdx.y < sequence_length) { + output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + + threadIdx.x] = + input0[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x + + blockIdx.z * blockDim.x + threadIdx.x]; + } else { + output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + + threadIdx.x] = 0; + } +} + +nvinfer1::DataType RecoverPaddingPlugin::getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +nvinfer1::DimsExprs RecoverPaddingPlugin::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs output_dims{}; + output_dims.nbDims = 3; + const auto* one = exprBuilder.constant(1); + output_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, + *inputs[1].d[0], *one); + output_dims.d[1] = inputs[2].d[1]; + output_dims.d[2] = inputs[0].d[1]; + return output_dims; +} + +bool RecoverPaddingPlugin::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, 3, + platform::errors::InvalidArgument("Must have 3 inputs, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(), + platform::errors::InvalidArgument("Must have 1 output, " + "but got %d output(s).
", + nbOutputs)); + if (pos == 1) { // PosId, MaxSeqlen + return inOut[pos].type == nvinfer1::DataType::kINT32 && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + return inOut[pos].type == nvinfer1::DataType::kFLOAT && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format + // == nvinfer1::TensorFormat::kLINEAR)|| + // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == + // nvinfer1::TensorFormat::kLINEAR)|| + // (inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format == + // nvinfer1::TensorFormat::kCHW32); +} + +void RecoverPaddingPlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + +void RecoverPaddingPlugin::attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + +void RecoverPaddingPlugin::detachFromContext() TRT_NOEXCEPT {} + +void RecoverPaddingPlugin::terminate() TRT_NOEXCEPT {} + +int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto input0_desc = inputDesc[0]; + const auto input1_desc = inputDesc[1]; + const auto input2_desc = inputDesc[2]; + const float* input0 = static_cast(inputs[0]); + const int32_t* input1 = + static_cast(inputs[1]); // pos_id_tensor + float* output = static_cast(outputs[0]); + const int32_t num_threads = 256; + const dim3 num_blocks( + input1_desc.dims.d[0] - 1, input2_desc.dims.d[1], + input0_desc.dims.d[1] / num_threads); // batchs, max sequnce length + // (mask_id.dims.d[1]), + // input.dims.d[1]/256 + RecoverPaddingKernel<<>>(input0, input1, + output); + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h new file mode 100644 index 00000000000..896cd05eef1 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h @@ -0,0 +1,133 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include + +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class RecoverPaddingPlugin : public DynamicPluginTensorRT { + public: + RecoverPaddingPlugin() {} + + RecoverPaddingPlugin(void const* serial_data, size_t serial_length) {} + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + RecoverPaddingPlugin* ptr = new RecoverPaddingPlugin(); + return ptr; + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "recover_padding_plugin"; + } + + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + + int initialize() TRT_NOEXCEPT { return 0; } + void terminate() TRT_NOEXCEPT; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) + TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + protected: + size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; } + + void serialize(void* buffer) const TRT_NOEXCEPT override {} +}; + +class RecoverPaddingPluginCreator : public nvinfer1::IPluginCreator { + public: + RecoverPaddingPluginCreator() {} + const char* getPluginName() const TRT_NOEXCEPT override { + return "recover_padding_plugin"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* plugin_field) + TRT_NOEXCEPT override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, void const* serial_data, + size_t serial_length) TRT_NOEXCEPT override { + RecoverPaddingPlugin* obj = + new RecoverPaddingPlugin(serial_data, serial_length); + obj->setPluginNamespace(name); + return obj; + } + + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; +}; +REGISTER_TRT_PLUGIN_V2(RecoverPaddingPluginCreator); +} // 
namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu new file mode 100644 index 00000000000..84e36a4d5f6 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu @@ -0,0 +1,118 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +__global__ void RemovePaddingKernel(const float* input0, const int32_t* input1, + float* output) { + int word_id = blockIdx.x * gridDim.y + blockIdx.y; + int32_t sequence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; + if (blockIdx.y < sequence_length) { + output[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x + + blockIdx.z * blockDim.x + threadIdx.x] = + input0[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + + threadIdx.x]; + } +} + +nvinfer1::DataType RemovePaddingPlugin::getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +nvinfer1::DimsExprs RemovePaddingPlugin::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs output_dims{}; + output_dims.nbDims = 4; + output_dims.d[0] = inputs[2].d[0]; + output_dims.d[1] = inputs[0].d[2]; + output_dims.d[2] = exprBuilder.constant(1); + output_dims.d[3] = exprBuilder.constant(1); + + return output_dims; +} + +bool RemovePaddingPlugin::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, 3, + platform::errors::InvalidArgument("Must have 3 inputs, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(), + platform::errors::InvalidArgument("Must have 1 output, " + "but got %d output(s).
", + nbOutputs)); + if (pos == 1 || pos == 2) { // pos_id, work_id + return inOut[pos].type == nvinfer1::DataType::kINT32 && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + return inOut[pos].type == nvinfer1::DataType::kFLOAT && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format + // == nvinfer1::TensorFormat::kLINEAR)|| + // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == + // nvinfer1::TensorFormat::kLINEAR)|| + // (inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format == + // nvinfer1::TensorFormat::kCHW32); +} + +void RemovePaddingPlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + +void RemovePaddingPlugin::attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + +void RemovePaddingPlugin::detachFromContext() TRT_NOEXCEPT {} + +void RemovePaddingPlugin::terminate() TRT_NOEXCEPT {} + +int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto input_desc = inputDesc[0]; + const float* input0 = static_cast(inputs[0]); + const int32_t* input1 = + static_cast(inputs[1]); // pos_id_tensor + float* output = static_cast(outputs[0]); + + const auto input0_desc = inputDesc[0]; + + const int32_t num_threads = 256; + const dim3 num_blocks( + input0_desc.dims.d[0], input0_desc.dims.d[1], + input0_desc.dims.d[2] / + num_threads); // batchs, max sequnce length, input.dims.d[2]/256 + + RemovePaddingKernel<<>>(input0, input1, + output); + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h new file mode 100644 index 00000000000..6679f2f0819 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h @@ -0,0 +1,133 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include + +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class RemovePaddingPlugin : public DynamicPluginTensorRT { + public: + RemovePaddingPlugin() {} + + RemovePaddingPlugin(void const* serial_data, size_t serial_length) {} + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + RemovePaddingPlugin* ptr = new RemovePaddingPlugin(); + return ptr; + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "remove_padding_plugin"; + } + + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + + int initialize() TRT_NOEXCEPT { return 0; } + void terminate() TRT_NOEXCEPT; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) + TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + protected: + size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; } + + void serialize(void* buffer) const TRT_NOEXCEPT override {} +}; + +class RemovePaddingPluginCreator : public nvinfer1::IPluginCreator { + public: + RemovePaddingPluginCreator() {} + const char* getPluginName() const TRT_NOEXCEPT override { + return "remove_padding_plugin"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* plugin_field) + TRT_NOEXCEPT override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, void const* serial_data, + size_t serial_length) TRT_NOEXCEPT override { + RemovePaddingPlugin* obj = + new RemovePaddingPlugin(serial_data, serial_length); + obj->setPluginNamespace(name); + return obj; + } + + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; +}; +REGISTER_TRT_PLUGIN_V2(RemovePaddingPluginCreator); +} // namespace plugin 
+} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu deleted file mode 100644 index 324e9c0392c..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -#if IS_TRT_VERSION_GE(6000) -SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {} - -SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data, - size_t serial_length) {} - -SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {} - -nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const - TRT_NOEXCEPT { - return new SpecialSlicePluginDynamic(); -} - -const char* SpecialSlicePluginDynamic::getPluginType() const TRT_NOEXCEPT { - return "special_slice_plugin"; -} - -int SpecialSlicePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -int SpecialSlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; } - -size_t SpecialSlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - size_t serialize_size = 0; - return serialize_size; -} - -void SpecialSlicePluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {} - -nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions( - int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, - nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { - nvinfer1::DimsExprs output(inputs[0]); - output.nbDims++; - for (int i = output.nbDims - 1; i > 1; i--) { - output.d[i] = inputs[0].d[i - 1]; - } - auto one = expr_builder.constant(1); - output.d[1] = one; - output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB, - *inputs[1].d[0], *one); - // remove padding 1 - output.nbDims -= 2; - - return output; -} - -void SpecialSlicePluginDynamic::configurePlugin( - const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {} - -size_t SpecialSlicePluginDynamic::getWorkspaceSize( - const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -void SpecialSlicePluginDynamic::destroy() TRT_NOEXCEPT { delete this; } - -void SpecialSlicePluginDynamic::terminate() TRT_NOEXCEPT {} - -bool SpecialSlicePluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs, - int nb_outputs) TRT_NOEXCEPT { - if (pos == 0) // slice tensor - return (desc[pos].type == nvinfer1::DataType::kHALF && - desc[pos].format == - nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type == - // 
nvinfer1::DataType::kFLOAT); - - if (pos == 1) // cu_seqlen - return (desc[pos].type == nvinfer1::DataType::kINT32 && - desc[pos].format == nvinfer1::TensorFormat::kLINEAR); - - return (desc[pos].type == nvinfer1::DataType::kHALF && - desc[pos].format == - nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type == - // nvinfer1::DataType::kFLOAT); -} - -nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType* input_types, - int nb_inputs) const TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( - "The index should be equal to 0")); - return input_types[0]; -} - -template -__global__ void SpecialSliceKernel(const T* slice_input, - const int32_t* cu_seqlens, T* output) { - const int hidden = blockDim.x * gridDim.x; - const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x; - const int batch_id = blockIdx.y; - - output[batch_id * hidden + hidden_id] = - slice_input[cu_seqlens[batch_id] * hidden + hidden_id]; -} - -int SpecialSlicePluginDynamic::enqueue( - const nvinfer1::PluginTensorDesc* input_desc, - const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs, - void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { - auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1) - auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1) - - PADDLE_ENFORCE_EQ( - input_desc[0].type, nvinfer1::DataType::kHALF, - platform::errors::InvalidArgument("Type of input should be half.")); - - const int32_t hidden = input_dims.d[1]; - PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument( - "hidden should be multiple of 128.")); - - constexpr int num_threads = 128; - const half* slice_input = static_cast(inputs[0]); - const int32_t* cu_seqlens = static_cast(inputs[1]); - half* output = static_cast(outputs[0]); - - const int32_t num_blocks_x = hidden / num_threads; - const int32_t num_blocks_y = out_dims.d[0]; // batchs - const dim3 num_blocks(num_blocks_x, num_blocks_y); // blocks - - SpecialSliceKernel<<>>( - slice_input, cu_seqlens, output); - return cudaGetLastError() != cudaSuccess; -} - -SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {} - -const char* SpecialSlicePluginDynamicCreator::getPluginName() const - TRT_NOEXCEPT { - return "special_slice_plugin"; -} - -const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const - TRT_NOEXCEPT { - return "1"; -} - -const nvinfer1::PluginFieldCollection* -SpecialSlicePluginDynamicCreator::getFieldNames() TRT_NOEXCEPT { - return &field_collection_; -} - -nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT { - return new SpecialSlicePluginDynamic(); -} - -nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin( - const char* name, const void* serial_data, - size_t serial_length) TRT_NOEXCEPT { - auto plugin = new SpecialSlicePluginDynamic(serial_data, serial_length); - return plugin; -} - -void SpecialSlicePluginDynamicCreator::setPluginNamespace( - const char* lib_namespace) TRT_NOEXCEPT { - plugin_namespace_ = lib_namespace; -} - -const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const - TRT_NOEXCEPT { - return plugin_namespace_.c_str(); -} - -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h 
b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h deleted file mode 100644 index c3521e4ed63..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -#if IS_TRT_VERSION_GE(6000) -class SpecialSlicePluginDynamic : public DynamicPluginTensorRT { - public: - SpecialSlicePluginDynamic(); - SpecialSlicePluginDynamic(void const* serial_data, size_t serial_length); - ~SpecialSlicePluginDynamic(); - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, - cudaStream_t stream) TRT_NOEXCEPT override; - - nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - const char* getPluginType() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - int initialize() TRT_NOEXCEPT override; - void terminate() TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void* buffer) const TRT_NOEXCEPT override; - void destroy() TRT_NOEXCEPT override; - - private: - int axis_; - int num_stack_; -}; - -class SpecialSlicePluginDynamicCreator : public nvinfer1::IPluginCreator { - public: - SpecialSlicePluginDynamicCreator(); - const char* getPluginName() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; - const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override; - nvinfer1::IPluginV2* createPlugin(const char* name, - const nvinfer1::PluginFieldCollection* fc) - TRT_NOEXCEPT override; - nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serial_data, - size_t serial_length) TRT_NOEXCEPT override; - void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override; - const char* getPluginNamespace() const 
TRT_NOEXCEPT override; - - private: - std::string plugin_namespace_; - nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; - std::vector plugin_attributes_; -}; -REGISTER_TRT_PLUGIN_V2(SpecialSlicePluginDynamicCreator); -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu new file mode 100644 index 00000000000..a7fff027816 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu @@ -0,0 +1,110 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +__global__ void TransformerInputConvertKernel(const int64_t* input, + int32_t* output0) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + __shared__ int32_t shared_data; + if (threadIdx.x == static_cast<int32_t>(input[tid])) { + atomicAdd(&shared_data, 1); + } + output0[0] = 0; + output0[blockIdx.x + 1] = shared_data; + __syncthreads(); + for (int i = 0; i < blockDim.x; ++i) { + output0[i + 1] += output0[i]; + } +} + +nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + return nvinfer1::DataType::kINT32; +} + +nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs output_dims{}; + output_dims.nbDims = 1; + if (outputIndex == 0) { // PosId + const auto* one = exprBuilder.constant(1); + output_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[0].d[0], *one); + } else { // MaxSeqlen + output_dims.d[0] = inputs[0].d[1]; + } + return output_dims; +} + +bool TransformerInputConvertPlugin::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument("Must have 1 input, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(), + platform::errors::InvalidArgument("Must have 2 outputs, " + "but got %d output(s). 
", + nbOutputs)); + if (pos == 0) { // input + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } else { // output0, output1 + return inOut[pos].type == nvinfer1::DataType::kINT32 && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } +} + +void TransformerInputConvertPlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + +void TransformerInputConvertPlugin::attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + +void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {} + +void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {} + +int TransformerInputConvertPlugin::enqueue( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { + const auto input_desc = inputDesc[0]; + const int64_t* input = static_cast<const int64_t*>(inputs[0]); + int32_t* output0 = static_cast<int32_t*>(outputs[0]); // PosId + // int32_t* output1 = static_cast<int32_t*>(outputs[1]); // MaxSeqlen + + const int32_t num_blocks = input_desc.dims.d[0]; // batches + const int32_t num_threads = input_desc.dims.d[1]; // max sequence length + + TransformerInputConvertKernel<<<num_blocks, num_threads, 0, stream>>>( + input, output0); + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h b/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h new file mode 100644 index 00000000000..87dc876fa9c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h @@ -0,0 +1,134 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include + +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class TransformerInputConvertPlugin : public DynamicPluginTensorRT { + public: + TransformerInputConvertPlugin() {} + + TransformerInputConvertPlugin(void const* serial_data, size_t serial_length) { + } + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + TransformerInputConvertPlugin* ptr = new TransformerInputConvertPlugin(); + return ptr; + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "transformer_input_convert_plugin"; + } + + int getNbOutputs() const TRT_NOEXCEPT override { return 2; } + + int initialize() TRT_NOEXCEPT { return 0; } + void terminate() TRT_NOEXCEPT; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) + TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + protected: + size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; } + + void serialize(void* buffer) const TRT_NOEXCEPT override {} +}; + +class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator { + public: + TransformerInputConvertPluginCreator() {} + const char* getPluginName() const TRT_NOEXCEPT override { + return "transformer_input_convert_plugin"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* plugin_field) + TRT_NOEXCEPT override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, void const* serial_data, + size_t serial_length) TRT_NOEXCEPT override { + TransformerInputConvertPlugin* obj = + new TransformerInputConvertPlugin(serial_data, serial_length); + obj->setPluginNamespace(name); + return obj; + } + + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + 
nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; +}; +REGISTER_TRT_PLUGIN_V2(TransformerInputConvertPluginCreator); +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/analyzer_capi_exp_gpu_tester.cc b/paddle/fluid/inference/tests/api/analyzer_capi_exp_gpu_tester.cc index dcda34c64da..d11d09458e4 100644 --- a/paddle/fluid/inference/tests/api/analyzer_capi_exp_gpu_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_capi_exp_gpu_tester.cc @@ -65,7 +65,7 @@ TEST(PD_Config, gpu_interface) { &min_shape_ptr, &max_shape_ptr, &opt_shape_ptr, FALSE); PD_ConfigDisableTensorRtOPs(config, 1, &ops_name); - PD_ConfigEnableTensorRtOSS(config); + PD_ConfigEnableVarseqlen(config); bool oss_enabled = PD_ConfigTensorRtOssEnabled(config); EXPECT_TRUE(oss_enabled); diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc index 1058a5b5ec6..262b7269cb3 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc @@ -210,7 +210,11 @@ std::shared_ptr InitPredictor() { config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, opt_input_shape); // erinie varlen must be used with oss - config.EnableTensorRtOSS(); + config.EnableVarseqlen(); + paddle_infer::experimental::InternalUtils::SetTransformerPosid(&config, + input_name2); + paddle_infer::experimental::InternalUtils::SetTransformerMaskid(&config, + input_name3); return paddle_infer::CreatePredictor(config); } diff --git a/paddle/fluid/inference/tests/infer_ut/test_ernie_xnli_int8.cc b/paddle/fluid/inference/tests/infer_ut/test_ernie_xnli_int8.cc index 4e924e31979..53edc554eba 100644 --- a/paddle/fluid/inference/tests/infer_ut/test_ernie_xnli_int8.cc +++ b/paddle/fluid/inference/tests/infer_ut/test_ernie_xnli_int8.cc @@ -68,7 +68,7 @@ std::shared_ptr InitPredictor() { config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, opt_input_shape); // erinie varlen must be used with oss - config.EnableTensorRtOSS(); + config.EnableVarseqlen(); return CreatePredictor(config); } diff --git a/paddle/fluid/inference/utils/table_printer_tester.cc b/paddle/fluid/inference/utils/table_printer_tester.cc index f56d2527d73..8faac79c517 100644 --- a/paddle/fluid/inference/utils/table_printer_tester.cc +++ b/paddle/fluid/inference/utils/table_printer_tester.cc @@ -43,7 +43,7 @@ TEST(table_printer, output) { table.InsertRow({"trt_precision", "fp32"}); table.InsertRow({"enable_dynamic_shape", "true"}); table.InsertRow({"DisableTensorRtOPs", "{}"}); - table.InsertRow({"EnableTensorRtOSS", "ON"}); + table.InsertRow({"EnableVarseqlen", "ON"}); table.InsertRow({"tensorrt_dla_enabled", "ON"}); table.InsetDivider(); diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 94478148407..d4c19364d48 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -657,8 +657,9 @@ void BindAnalysisConfig(py::module *m) { py::arg("disable_trt_plugin_fp16") = false) .def("tensorrt_dynamic_shape_enabled", &AnalysisConfig::tensorrt_dynamic_shape_enabled) - .def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS) - .def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled) + .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen) + .def("tensorrt_varseqlen_enabled", + 
&AnalysisConfig::tensorrt_varseqlen_enabled) .def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo) .def("shape_range_info_path", &AnalysisConfig::shape_range_info_path) .def("shape_range_info_collected", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py index 20d9b9d972d..88045324b38 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py @@ -42,7 +42,7 @@ class InferencePassTest(unittest.TestCase): self.enable_mkldnn = False self.enable_mkldnn_bfloat16 = False self.enable_trt = False - self.enable_tensorrt_oss = True + self.enable_tensorrt_varseqlen = True self.trt_parameters = None self.dynamic_shape_params = None self.enable_lite = False @@ -134,8 +134,8 @@ class InferencePassTest(unittest.TestCase): self.dynamic_shape_params.max_input_shape, self.dynamic_shape_params.optim_input_shape, self.dynamic_shape_params.disable_trt_plugin_fp16) - if self.enable_tensorrt_oss: - config.enable_tensorrt_oss() + if self.enable_tensorrt_varseqlen: + config.enable_tensorrt_varseqlen() elif use_mkldnn: config.enable_mkldnn() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py b/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py index 1ca7799963b..9ea72043379 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py @@ -46,7 +46,7 @@ class QuantDequantTest(unittest.TestCase): self.enable_mkldnn = False self.enable_mkldnn_bfloat16 = False self.enable_trt = False - self.enable_tensorrt_oss = True + self.enable_tensorrt_varseqlen = True self.trt_parameters = None self.dynamic_shape_params = None self.enable_lite = False @@ -184,8 +184,8 @@ class QuantDequantTest(unittest.TestCase): self.dynamic_shape_params.max_input_shape, self.dynamic_shape_params.optim_input_shape, self.dynamic_shape_params.disable_trt_plugin_fp16) - if self.enable_tensorrt_oss: - config.enable_tensorrt_oss() + if self.enable_tensorrt_varseqlen: + config.enable_tensorrt_varseqlen() elif use_mkldnn: config.enable_mkldnn() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py index ed993ffce7d..8540555497d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py @@ -179,7 +179,7 @@ def multiclass_nms(bboxes, class TensorRTMultiClassNMS3Test(InferencePassTest): def setUp(self): self.enable_trt = True - self.enable_tensorrt_oss = True + self.enable_tensorrt_varseqlen = True self.precision = AnalysisConfig.Precision.Float32 self.serialize = False self.bs = 1 @@ -291,8 +291,8 @@ class TensorRTMultiClassNMS3Test(InferencePassTest): self.background = 7 self.run_test() - def test_disable_oss(self): - self.diable_tensorrt_oss = False + def test_disable_varseqlen(self): + self.diable_tensorrt_varseqlen = False self.run_test() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py index 045261fabb0..5c9ad5de5a7 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py +++ 
b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py @@ -25,7 +25,7 @@ from paddle.fluid.core import AnalysisConfig class TensorRTMultiClassNMSTest(InferencePassTest): def setUp(self): self.enable_trt = True - self.enable_tensorrt_oss = True + self.enable_tensorrt_varseqlen = True self.precision = AnalysisConfig.Precision.Float32 self.serialize = False self.bs = 1 @@ -135,8 +135,8 @@ class TensorRTMultiClassNMSTest(InferencePassTest): self.background = 7 self.run_test() - def test_disable_oss(self): - self.diable_tensorrt_oss = False + def test_disable_varseqlen(self): + self.diable_tensorrt_varseqlen = False self.run_test() -- GitLab