From 4d77214434cfc1bc0600661d6a72cfb36f4a616f Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 27 Sep 2022 15:58:38 +0800 Subject: [PATCH] [Paddle Inference]support n lookup_tables fuse to embeddinglayernorm(3) (#46243) * [Paddle Inference]support n lookup_tables fuse to embeddinglayernorm(3) --- paddle/fluid/framework/ir/CMakeLists.txt | 7 +- ...n_embedding_eltwise_layernorm_fuse_pass.cc | 19 +- ...t_embedding_eltwise_layernorm_fuse_pass.cc | 19 -- .../fluid/inference/api/analysis_predictor.cc | 5 +- .../inference/api/paddle_pass_builder.cc | 23 +- .../inference/tensorrt/convert/CMakeLists.txt | 7 +- .../tensorrt/convert/emb_eltwise_layernorm.cc | 266 ++++++++---------- .../convert/preln_emb_eltwise_layernorm.cc | 234 ++++++++------- .../inference/tensorrt/plugin/CMakeLists.txt | 5 + .../tensorrt/plugin/common/bertCommon.h | 8 +- .../tensorrt/plugin/common/serialize.h | 2 +- ...any_emb_Layernorm_varseqlen_kernelHFace.cu | 3 +- ...any_emb_Layernorm_varseqlen_kernelMTron.cu | 3 +- .../many_emb_layernorm_varseqlen_plugin.cu | 18 +- 14 files changed, 286 insertions(+), 333 deletions(-) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 455417b521d..08d5e23b6f4 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -119,10 +119,8 @@ target_link_libraries(generate_pass pass_desc_proto) if(WITH_TENSORRT) pass_library(trt_map_matmul_to_mul_pass inference) - pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference) pass_library(trt_multihead_matmul_fuse_pass inference) pass_library(trt_skip_layernorm_fuse_pass inference) - pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) pass_library(preln_skip_layernorm_fuse_pass inference) pass_library(set_transformer_input_convert_pass inference) pass_library(remove_padding_recover_padding_pass inference) @@ -130,6 +128,11 @@ if(WITH_TENSORRT) pass_library(layernorm_shift_partition_fuse_pass inference) endif() +if(WITH_TENSORRT AND NOT WIN32) + pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference) + pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) +endif() + if(WITH_GPU OR WITH_ROCM) pass_library(cudnn_placement_pass base DEPS placement_pass_base) pass_library(embedding_eltwise_layernorm_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc index ddcde5014a4..5281f27ff1c 100644 --- a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc @@ -154,7 +154,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( /*const Scope* scope*/) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); - + std::string pos_id = Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); std::vector>> start_pattern_in_nodes; std::vector start_pattern_out_node; std::vector> start_pattern_remove_nodes; @@ -331,17 +332,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm"); new_op_desc.SetInput("Ids", ids); new_op_desc.SetInput("Embs", embs); - new_op_desc.SetInput("WordId", {ids[0]}); - new_op_desc.SetInput("PosId", {ids[1]}); - if (ids.size() > 2) { - new_op_desc.SetInput("SentId", {ids[2]}); - } - - new_op_desc.SetInput("WordEmbedding", {embs[0]}); - new_op_desc.SetInput("PosEmbedding", {embs[1]}); - if (embs.size() > 2) { - new_op_desc.SetInput("SentEmbedding", {embs[2]}); - } + new_op_desc.SetInput("PosId", {pos_id}); + new_op_desc.SetInput("MaskId", {mask_id}); new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()}); new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()}); @@ -441,8 +433,6 @@ PrelnEmbeddingEltwiseLayerNormFusePass:: } void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { - FusePassBase::Init(name_scope_, graph); - bool enable_int8 = Get("enable_int8"); bool use_varseqlen = Get("use_varseqlen"); bool with_interleaved = Get("with_interleaved"); @@ -458,6 +448,7 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { "please reconfig."; return; } + FusePassBase::Init(name_scope_, graph); int fusion_count = PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_); if (fusion_count > 0) { diff --git a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc index 2b0a5e5c93f..f870796a4c1 100644 --- a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc @@ -326,33 +326,14 @@ int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion( embs.push_back(inner_pattern_ins[js[iter]].second->Name()); } - // todo: support any inputs with lookup_table_v2 - if (ids.size() < 3) { - VLOG(3) << "trt_embedding_eltwise_layernorm_fuse_pass only support >=3 " - "inputs with lookup_table_v2"; - return fusion_count; - } OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block()); new_op_desc.SetType("fused_embedding_eltwise_layernorm"); new_op_desc.SetInput("Ids", ids); new_op_desc.SetInput("Embs", embs); - new_op_desc.SetInput("WordId", {ids[0]}); if (use_varseqlen && pos_id != "" && mask_id != "") { new_op_desc.SetInput("PosId", {pos_id}); new_op_desc.SetInput("MaskId", {mask_id}); - } else { - new_op_desc.SetInput("PosId", {ids[1]}); - } - if (ids.size() > 2) { - new_op_desc.SetInput("SentId", {ids[2]}); } - - new_op_desc.SetInput("WordEmbedding", {embs[0]}); - new_op_desc.SetInput("PosEmbedding", {embs[1]}); - if (embs.size() > 2) { - new_op_desc.SetInput("SentEmbedding", {embs[2]}); - } - new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()}); new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()}); new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()}); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index ae34fd52341..1df1425de24 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2143,7 +2143,6 @@ USE_TRT_CONVERTER(instance_norm); USE_TRT_CONVERTER(layer_norm); USE_TRT_CONVERTER(gelu); USE_TRT_CONVERTER(multihead_matmul); -USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(skip_layernorm); USE_TRT_CONVERTER(slice); USE_TRT_CONVERTER(scale); @@ -2172,7 +2171,11 @@ USE_TRT_CONVERTER(conv3d_transpose); USE_TRT_CONVERTER(mish); USE_TRT_CONVERTER(deformable_conv); USE_TRT_CONVERTER(pool3d) +#ifdef _WIN32 +#else USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) +USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); +#endif USE_TRT_CONVERTER(preln_skip_layernorm) USE_TRT_CONVERTER(preln_residual_bias) USE_TRT_CONVERTER(c_allreduce_sum) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 1f1f86e70c9..222c90703a5 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -95,17 +95,22 @@ const std::vector kTRTSubgraphPasses({ "delete_quant_dequant_linear_op_pass", // "add_support_int8_pass", // // "fc_fuse_pass", // - "simplify_with_basic_ops_pass", // + "simplify_with_basic_ops_pass", // + +#if defined _WIN32 +#else "trt_embedding_eltwise_layernorm_fuse_pass", // "preln_embedding_eltwise_layernorm_fuse_pass", // - "delete_c_identity_op_pass", // - "trt_multihead_matmul_fuse_pass_v2", // - "trt_multihead_matmul_fuse_pass_v3", // - "vit_attention_fuse_pass", // - "trt_skip_layernorm_fuse_pass", // - "preln_skip_layernorm_fuse_pass", // - "preln_residual_bias_fuse_pass", // - "layernorm_shift_partition_fuse_pass", // +#endif + + "delete_c_identity_op_pass", // + "trt_multihead_matmul_fuse_pass_v2", // + "trt_multihead_matmul_fuse_pass_v3", // + "vit_attention_fuse_pass", // + "trt_skip_layernorm_fuse_pass", // + "preln_skip_layernorm_fuse_pass", // + "preln_residual_bias_fuse_pass", // + "layernorm_shift_partition_fuse_pass", // // "set_transformer_input_convert_pass", // "conv_bn_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 3a2fb526078..9bc0ad9114b 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -30,7 +30,6 @@ list( transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc - emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc @@ -57,7 +56,6 @@ list( bilinear_interp_v2_op.cc pool3d_op.cc deformable_conv_op.cc - preln_emb_eltwise_layernorm.cc strided_slice_op.cc preln_skip_layernorm.cc roll_op.cc @@ -80,6 +78,11 @@ list( layernorm_shift_partition_op.cc generic_and_custom_plugin_creater.cc) +if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) + list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc + preln_emb_eltwise_layernorm.cc) +endif() + if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) endif() diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index fbf49ece755..24dbd8a0e17 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" #include "paddle/phi/core/ddim.h" namespace paddle { @@ -36,30 +37,64 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { bool test_mode) override { VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer"; + // get the presistable var's data + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; + + auto GetFp16Weight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor); + return weight; + }; + + auto GetFp32Weight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor); + return weight; + }; + framework::OpDesc op_desc(op, nullptr); - auto word_id_name = op_desc.Input("WordId").front(); auto pos_id_name = engine_->tensorrt_transformer_posid(); - engine_->Set("ernie_pos_name", new std::string(pos_id_name)); - - auto sent_id_name = op_desc.Input("SentId").front(); auto mask_id_name = engine_->tensorrt_transformer_maskid(); - auto word_emb_name = op_desc.Input("WordEmbedding").front(); - auto pos_emb_name = op_desc.Input("PosEmbedding").front(); - auto sent_emb_name = op_desc.Input("SentEmbedding").front(); - - std::vector id_names; - std::vector emb_names; bool flag_varseqlen = engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != ""; + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + int hidden = 0; + // Declare inputs + std::vector input_ids; + + // Declare inputs_weight + std::vector input_embs; + std::vector emb_sizes; + TensorRTEngine::Weight weight; + framework::DDim emb_dims; + framework::DDim bias_dims, scale_dims; + TensorRTEngine::Weight bias_weight, scale_weight; + + int64_t bias_size = phi::product(bias_dims); + int64_t scale_size = phi::product(scale_dims); + nvinfer1::ILayer* layer = nullptr; + bool enable_int8 = op_desc.HasAttr("enable_int8"); + + std::vector id_names = op_desc.Input("Ids"); + std::vector emb_names = op_desc.Input("Embs"); + int input_num = id_names.size(); if (flag_varseqlen) { - engine_->SetITensor("word_id", engine_->GetITensor(word_id_name)); engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); - id_names = - std::vector{word_id_name, pos_id_name, sent_id_name}; - emb_names = - std::vector{word_emb_name, pos_emb_name, sent_emb_name}; auto mask_id_tensor = engine_->GetITensor("mask_id"); auto mask_dims = mask_id_tensor->getDimensions(); @@ -72,16 +107,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { } auto* shape_tensor = Shape(mask_id_tensor); - std::vector start_vec_tensor; std::vector size_vec_tensor; + std::vector start_vec_tensor; for (int i = 0; i < mask_dims.nbDims; i++) { - start_vec_tensor.push_back(Add1DConstantLayer(0)); size_vec_tensor.push_back(Add1DConstantLayer(1)); + start_vec_tensor.push_back(Add1DConstantLayer(0)); } size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1); - - auto start_tensor = Concat(start_vec_tensor); auto size_tensor = Concat(size_vec_tensor); + auto start_tensor = Concat(start_vec_tensor); auto slice_layer = TRT_ENGINE_ADD_LAYER(engine_, @@ -109,110 +143,32 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { .c_str()); engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f); engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0)); - } else { - id_names = op_desc.Input("Ids"); - emb_names = op_desc.Input("Embs"); - } - - int input_num = id_names.size(); - - // Declare inputs - std::vector input_ids; - for (int i = 0; i < input_num; i++) { - input_ids.push_back(engine_->GetITensor(id_names[i])); - } - - // input_embs[0]: word_embedding - // input_embs[1]: pos_embedding - // input_embs[2]: sent_embedding - std::vector input_embs; - std::vector emb_sizes; - // get the presistable var's data - auto GetWeight = [&](const std::string& var_name, - framework::DDim* dim) -> TensorRTEngine::Weight { - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - *dim = temp_tensor->dims(); - auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); - return weight; - }; - - auto GetFp16Weight = [&](const std::string& var_name, - framework::DDim* dim) -> TensorRTEngine::Weight { - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - *dim = temp_tensor->dims(); - auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor); - return weight; - }; - - auto GetFp32Weight = [&](const std::string& var_name, - framework::DDim* dim) -> TensorRTEngine::Weight { - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - *dim = temp_tensor->dims(); - auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor); - return weight; - }; - bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - int hidden = 0; - for (int i = 0; i < input_num; i++) { - framework::DDim emb_dims; - TensorRTEngine::Weight weight; - if (flag_varseqlen) { + for (int i = 0; i < input_num; i++) { + auto input_tensor = engine_->GetITensor(id_names[i]); weight = GetWeight(emb_names[i], &emb_dims); - } else { - if (with_fp16) { - weight = GetFp16Weight(emb_names[i], &emb_dims); + if (id_names[i] == pos_id_name) { + input_ids.insert(input_ids.begin(), input_tensor); + input_embs.insert(input_embs.begin(), weight.get()); + emb_sizes.insert(emb_sizes.begin(), weight.get().count); } else { - weight = GetFp32Weight(emb_names[i], &emb_dims); + input_ids.push_back(input_tensor); + input_embs.push_back(weight.get()); + emb_sizes.push_back(weight.get().count); } + hidden = emb_dims[1]; } - input_embs.push_back(weight.get()); - emb_sizes.push_back(weight.get().count); - PADDLE_ENFORCE_EQ( - emb_dims.size(), - 2, - platform::errors::InvalidArgument( - "The fused EmbEltwiseLayerNorm's emb should be 2 dims.")); - hidden = emb_dims[1]; - } - - framework::DDim bias_dims, scale_dims; - TensorRTEngine::Weight bias_weight, scale_weight; - if (flag_varseqlen) { bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); - } else { - if (with_fp16) { - bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims); - scale_weight = - GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims); - } else { - bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims); - scale_weight = - GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims); - } - } - - int64_t bias_size = phi::product(bias_dims); - int64_t scale_size = phi::product(scale_dims); - nvinfer1::ILayer* layer = nullptr; - bool enable_int8 = op_desc.HasAttr("enable_int8"); + bias_size = phi::product(bias_dims); + scale_size = phi::product(scale_dims); + // other_id(except pos_id) + engine_->SetITensor("word_id", input_ids[1]); - if (flag_varseqlen) { int output_fp16 = static_cast((engine_->WithFp16() == 1) ? 1 : 0); if (enable_int8) { output_fp16 = 1; } - PADDLE_ENFORCE_EQ( - input_num, - 3, - platform::errors::InvalidArgument( - "When using oss and var-len, embedding_eltwise_layernorm op" - "should have 3 inputs only, but got %d.", - input_num)); PADDLE_ENFORCE_EQ( output_fp16, 1, @@ -220,29 +176,27 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { "Only Precision::KHalf(fp16) is supported when infering " "ernie(bert) model with config.EnableVarseqlen(). " "But Precision::KFloat32 is setted.")); - const std::vector fields{ - {"bert_embeddings_layernorm_beta", - bias_weight.get().values, - GetPluginFieldType(bias_weight.get().type), - static_cast(bias_size)}, - {"bert_embeddings_layernorm_gamma", - scale_weight.get().values, - GetPluginFieldType(scale_weight.get().type), - static_cast(scale_size)}, - {"bert_embeddings_word_embeddings", - input_embs[0].values, - GetPluginFieldType(input_embs[0].type), - static_cast(emb_sizes[0])}, - {"bert_embeddings_token_type_embeddings", - input_embs[2].values, - GetPluginFieldType(input_embs[2].type), - static_cast(emb_sizes[2])}, - {"bert_embeddings_position_embeddings", - input_embs[1].values, - GetPluginFieldType(input_embs[1].type), - static_cast(emb_sizes[1])}, - {"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1}, - }; + + std::vector fields; + std::vector temp_fields_keys; + fields.emplace_back("bert_embeddings_layernorm_beta", + bias_weight.get().values, + GetPluginFieldType(bias_weight.get().type), + static_cast(bias_size)); + fields.emplace_back("bert_embeddings_layernorm_gamma", + scale_weight.get().values, + GetPluginFieldType(scale_weight.get().type), + static_cast(scale_size)); + fields.emplace_back( + "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1); + for (int i = 0; i < input_num; ++i) { + temp_fields_keys.push_back("bert_embeddings_word_embeddings_" + + std::to_string(i)); + fields.emplace_back(temp_fields_keys.rbegin()->c_str(), + input_embs[i].values, + GetPluginFieldType(input_embs[i].type), + static_cast(emb_sizes[i])); + } nvinfer1::PluginFieldCollection* plugin_ptr = static_cast( @@ -251,27 +205,19 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { plugin_ptr->nbFields = static_cast(fields.size()); plugin_ptr->fields = fields.data(); - std::vector plugin_inputs; - plugin_inputs.emplace_back( - engine_->GetITensor(word_id_name)); // word_embedding, - // eval_placeholder_0 - plugin_inputs.emplace_back( - engine_->GetITensor(sent_id_name)); // sent_embedding, - // eval_placeholder_1 - plugin_inputs.emplace_back( - engine_->GetITensor(pos_id_name)); // cu_seqlens, - // eval_placeholder_2 + std::vector plugin_inputs = input_ids; plugin_inputs.emplace_back(engine_->GetITensor( "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 auto creator = GetPluginRegistry()->getPluginCreator( - "CustomEmbLayerNormPluginDynamic", "2"); - + "ManyEmbLayerNormPluginDynamic", "2"); auto plugin_obj = - creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr); + creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); + auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); - plugin_layer->setName(("CustomEmbLayerNormPluginDynamic_V2(Output: " + + + plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " + op_desc.Output("Out")[0] + ")") .c_str()); free(plugin_ptr); @@ -302,11 +248,33 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { layer = plugin_layer; auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, - "CustomEmbLayerNormPluginDynamic_V2", + "ManyEmbLayerNormPluginDynamic_V2", {output_name, std::string("qkv_plugin_mask")}, test_mode); } } else { + for (int i = 0; i < input_num; i++) { + if (with_fp16) { + weight = GetFp16Weight(emb_names[i], &emb_dims); + } else { + weight = GetFp32Weight(emb_names[i], &emb_dims); + } + input_ids.push_back(engine_->GetITensor(id_names[i])); + input_embs.push_back(weight.get()); + emb_sizes.push_back(weight.get().count); + hidden = emb_dims[1]; + } + if (with_fp16) { + bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = + GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims); + } else { + bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = + GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims); + } + bias_size = phi::product(bias_dims); + scale_size = phi::product(scale_dims); float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")); plugin::DynamicPluginTensorRT* plugin = nullptr; std::vector input_embs_data; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 8079915b9e0..47992536c48 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -10,7 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/utils.h" #include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" namespace paddle { namespace framework { @@ -32,6 +34,15 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { bool test_mode) override { #if IS_TRT_VERSION_GE(7000) VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer"; + // get the presistable var's data + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; auto pos_id_name = engine_->tensorrt_transformer_posid(); auto mask_id_name = engine_->tensorrt_transformer_maskid(); @@ -50,126 +61,48 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { PADDLE_THROW( platform::errors::Fatal("use with_interleaved must be int8.")); } - auto word_id_name = op_desc.Input("WordId").front(); - engine_->Set("ernie_pos_name", new std::string(pos_id_name)); - - auto sent_id_name = op_desc.Input("SentId").front(); - auto word_emb_name = op_desc.Input("WordEmbedding").front(); - auto pos_emb_name = op_desc.Input("PosEmbedding").front(); - auto sent_emb_name = op_desc.Input("SentEmbedding").front(); - - engine_->SetITensor("word_id", engine_->GetITensor(word_id_name)); - engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); - engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); - - std::vector emb_names; - emb_names = - std::vector{word_emb_name, pos_emb_name, sent_emb_name}; + // Declare inputs + std::vector input_ids; - int input_num = emb_names.size(); - - // input_embs[0]: word_embedding - // input_embs[1]: pos_embedding - // input_embs[2]: sent_embedding - std::vector input_embs; + // Declare inputs_weight + std::vector input_embs; std::vector emb_sizes; - - // get the presistable var's data - auto get_persistable_data = [&](const std::string& var_name, - framework::DDim* dims) -> float* { - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - (*dims) = temp_tensor->dims(); - - auto* temp_data = const_cast(static_cast( - engine_->GetFp32TrtWeight(var_name, *temp_tensor).get().values)); - return temp_data; - }; - - for (int i = 0; i < input_num; i++) { - framework::DDim emb_dims; - float* emb_data = get_persistable_data(emb_names[i], &emb_dims); - int64_t emb_size = phi::product(emb_dims); - input_embs.push_back(emb_data); - emb_sizes.push_back(emb_size); - PADDLE_ENFORCE_EQ( - emb_dims.size(), - 2, - platform::errors::InvalidArgument( - "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims.")); - } - + TensorRTEngine::Weight weight; + framework::DDim emb_dims; framework::DDim bias_dims, scale_dims; + TensorRTEngine::Weight bias_weight, scale_weight; - auto* bias = - get_persistable_data(op_desc.Input("Bias").front(), &bias_dims); - auto* scale = - get_persistable_data(op_desc.Input("Scale").front(), &scale_dims); int64_t bias_size = phi::product(bias_dims); int64_t scale_size = phi::product(scale_dims); - int output_int8 = 1; - PADDLE_ENFORCE_EQ( - input_num, - 3, - platform::errors::InvalidArgument( - "When using oss and var-len, embedding_eltwise_layernorm op" - "should have 3 inputs only, but got %d.", - input_num)); - const std::vector fields{ - {"bert_embeddings_layernorm_beta", - bias, - nvinfer1::PluginFieldType::kFLOAT32, - static_cast(bias_size)}, - {"bert_embeddings_layernorm_gamma", - scale, - nvinfer1::PluginFieldType::kFLOAT32, - static_cast(scale_size)}, - {"bert_embeddings_word_embeddings", - input_embs[0], - nvinfer1::PluginFieldType::kFLOAT32, - static_cast(emb_sizes[0])}, - {"bert_embeddings_token_type_embeddings", - input_embs[2], - nvinfer1::PluginFieldType::kFLOAT32, - static_cast(emb_sizes[2])}, - {"bert_embeddings_position_embeddings", - input_embs[1], - nvinfer1::PluginFieldType::kFLOAT32, - static_cast(emb_sizes[1])}, - {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1}, - }; + std::vector id_names = op_desc.Input("Ids"); + std::vector emb_names = op_desc.Input("Embs"); + int input_num = id_names.size(); - nvinfer1::PluginFieldCollection* plugin_ptr = - static_cast( - malloc(sizeof(*plugin_ptr) + - fields.size() * sizeof(nvinfer1::PluginField))); - plugin_ptr->nbFields = static_cast(fields.size()); - plugin_ptr->fields = fields.data(); + engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); + engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); - std::vector plugin_inputs; - plugin_inputs.emplace_back( - engine_->GetITensor(word_id_name)); // word_embedding, - // eval_placeholder_0 - plugin_inputs.emplace_back( - engine_->GetITensor(sent_id_name)); // sent_embedding, - // eval_placeholder_1 - plugin_inputs.emplace_back( - engine_->GetITensor(pos_id_name)); // cu_seqlens, - // eval_placeholder_2 auto mask_id_tensor = engine_->GetITensor("mask_id"); auto mask_dims = mask_id_tensor->getDimensions(); auto slice_start_dims = mask_dims; - auto slice_size_dims = mask_dims; auto slice_stride_dims = mask_dims; for (int i = 0; i < mask_dims.nbDims; i++) { slice_start_dims.d[i] = 0; - slice_size_dims.d[i] = 1; slice_stride_dims.d[i] = 1; } - slice_size_dims.d[1] = mask_dims.d[1]; - auto* slice_size_tensor = Add1DConstantLayer(slice_size_dims); + + auto* shape_tensor = Shape(mask_id_tensor); + std::vector size_vec_tensor; + std::vector start_vec_tensor; + for (int i = 0; i < mask_dims.nbDims; i++) { + size_vec_tensor.push_back(Add1DConstantLayer(1)); + start_vec_tensor.push_back(Add1DConstantLayer(0)); + } + size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1); + auto size_tensor = Concat(size_vec_tensor); + auto start_tensor = Concat(start_vec_tensor); + auto slice_layer = TRT_ENGINE_ADD_LAYER(engine_, Slice, @@ -177,11 +110,11 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { slice_start_dims, slice_start_dims, slice_stride_dims); // unuseful slice_start_dims - slice_layer->setInput(2, *slice_size_tensor); - slice_layer->setName( - ("PrelnEmbeltwise_slice_layer (Output: slice_max_seqlen " + - op_desc.Output("Out")[0] + ")") - .c_str()); + slice_layer->setInput(1, *start_tensor); + slice_layer->setInput(2, *size_tensor); + slice_layer->setName(("Embeltwise_slice_layer (Output: slice_max_seqlen " + + op_desc.Output("Out")[0] + ")") + .c_str()); engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f); auto* reshape_layer = @@ -190,24 +123,87 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { shape_dim.nbDims = 1; shape_dim.d[0] = -1; reshape_layer->setReshapeDimensions(shape_dim); - reshape_layer->setName( - ("PrelnEmbeltwise_reshape_layer (Output: max_seqlen " + - op_desc.Output("Out")[0] + ")") - .c_str()); + reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " + + op_desc.Output("Out")[0] + ")") + .c_str()); engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f); engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0)); - plugin_inputs.emplace_back( - reshape_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + + for (int i = 0; i < input_num; i++) { + auto input_tensor = engine_->GetITensor(id_names[i]); + weight = GetWeight(emb_names[i], &emb_dims); + if (id_names[i] == pos_id_name) { + input_ids.insert(input_ids.begin(), input_tensor); + input_embs.insert(input_embs.begin(), weight.get()); + emb_sizes.insert(emb_sizes.begin(), weight.get().count); + } else { + input_ids.push_back(input_tensor); + input_embs.push_back(weight.get()); + emb_sizes.push_back(weight.get().count); + } + } + bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); + bias_size = phi::product(bias_dims); + scale_size = phi::product(scale_dims); + // other_id(except pos_id) + engine_->SetITensor("word_id", input_ids[1]); + + int output_fp16 = static_cast((engine_->WithFp16() == 1) ? 1 : 0); + if (enable_int8) { + output_fp16 = 1; + } + PADDLE_ENFORCE_EQ( + output_fp16, + 1, + platform::errors::InvalidArgument( + "Only Precision::KHalf(fp16) is supported when infering " + "ernie(bert) model with config.EnableVarseqlen(). " + "But Precision::KFloat32 is setted.")); + + std::vector fields; + std::vector temp_fields_keys; + fields.emplace_back("bert_embeddings_layernorm_beta", + bias_weight.get().values, + GetPluginFieldType(bias_weight.get().type), + static_cast(bias_size)); + fields.emplace_back("bert_embeddings_layernorm_gamma", + scale_weight.get().values, + GetPluginFieldType(scale_weight.get().type), + static_cast(scale_size)); + fields.emplace_back( + "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1); + for (int i = 0; i < input_num; ++i) { + temp_fields_keys.push_back("bert_embeddings_word_embeddings_" + + std::to_string(i)); + fields.emplace_back(temp_fields_keys.rbegin()->c_str(), + input_embs[i].values, + GetPluginFieldType(input_embs[i].type), + static_cast(emb_sizes[i])); + } + + nvinfer1::PluginFieldCollection* plugin_ptr = + static_cast( + malloc(sizeof(*plugin_ptr) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_ptr->nbFields = static_cast(fields.size()); + plugin_ptr->fields = fields.data(); + + std::vector plugin_inputs = input_ids; + plugin_inputs.emplace_back(engine_->GetITensor( + "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 auto creator = GetPluginRegistry()->getPluginCreator( - "CustomEmbLayerNormPluginDynamic", "3"); + "ManyEmbLayerNormPluginDynamic", "3"); auto plugin_obj = - creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr); + creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); + auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); - plugin_layer->setName(("CustomPrelnEmbLayerNormPluginDynamic_V3(Output: " + - op_desc.Output("Out_0")[0] + ")") + + plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V3(Output: " + + op_desc.Output("Out")[0] + ")") .c_str()); free(plugin_ptr); float out_0_scale = @@ -226,7 +222,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { engine_->SetITensor(op_desc.Output("Out_0")[0], shuffler_embed_out0->getOutput(0)); shuffler_embed_out0->setName( - ("shuffler_after_CustomPrelnEmbLayerNormPluginDynamic_V3(Output_0: " + + ("shuffler_after_ManyEmbLayerNormPluginDynamic_V3(Output_0: " + op_desc.Output("Out_0")[0] + ")") .c_str()); @@ -240,7 +236,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { engine_->SetITensor(op_desc.Output("Out_1")[0], shuffler_embed_out1->getOutput(0)); shuffler_embed_out1->setName( - ("shuffler_after_CustomPrelnEmbLayerNormPluginDynamic_V3(Output_1: " + + ("shuffler_after_ManyEmbLayerNormPluginDynamic_V3(Output_1: " + op_desc.Output("Out_1")[0] + ")") .c_str()); diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 9fe02cd731d..b091ef42d8c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -34,6 +34,11 @@ list( fused_token_prune_op_plugin.cu layernorm_shift_partition_op.cu generic_plugin.cu) +if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) + list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu + many_emb_Layernorm_varseqlen_kernelMTron.cu + many_emb_Layernorm_varseqlen_kernelHFace.cu) +endif() if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND TRT_FILES spmm_plugin.cu) diff --git a/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h index 3783b197ae0..fc07cdededc 100644 --- a/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h +++ b/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h @@ -96,8 +96,8 @@ template inline void serFromDev(char** buffer, const T* data, size_t nbElem) { const size_t len = sizeof(T) * nbElem; cudaMemcpy( - buffer, static_cast(data), len, cudaMemcpyDeviceToHost); - buffer += len; + *buffer, static_cast(data), len, cudaMemcpyDeviceToHost); + *buffer += len; } template @@ -174,8 +174,8 @@ struct WeightsWithOwnership : public nvinfer1::Weights { const auto nbBytes = getWeightsSize(*this, type); auto destBuf = new char[nbBytes]; this->values = destBuf; - std::copy_n(srcBuf, nbBytes, destBuf); - srcBuf += nbBytes; + std::copy_n(*srcBuf, nbBytes, destBuf); + *srcBuf += nbBytes; } }; diff --git a/paddle/fluid/inference/tensorrt/plugin/common/serialize.h b/paddle/fluid/inference/tensorrt/plugin/common/serialize.h index b51528cb5ab..39dc7fdf502 100644 --- a/paddle/fluid/inference/tensorrt/plugin/common/serialize.h +++ b/paddle/fluid/inference/tensorrt/plugin/common/serialize.h @@ -59,7 +59,7 @@ template <> struct Serializer { static size_t serialized_size(const char* value) { return strlen(value) + 1; } static void serialize(void** buffer, const char* value) { - ::snprintf(static_cast(*buffer), value); + ::strcpy(static_cast(*buffer), value); // NOLINT reinterpret_cast(*buffer) += strlen(value) + 1; } static void deserialize(void const** buffer, diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu index 366acbc11e0..b89ac08404b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu @@ -25,8 +25,7 @@ #include "common/common.cuh" #include "common/plugin.h" #include "common/serialize.h" -// #include -// "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu index d33a3772139..198d7c57b67 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu @@ -25,8 +25,7 @@ #include "common/common.cuh" #include "common/plugin.h" #include "common/serialize.h" -// #include -// "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu index 17c46bf0cf6..8ad149bd959 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu @@ -61,8 +61,8 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase( assert(beta.count == gamma.count); mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT); mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT); - copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev); - copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev); + copyToDevice(&mGamma, sizeof(float) * mGamma.count, &mGammaDev); + copyToDevice(&mBeta, sizeof(float) * mBeta.count, &mBetaDev); for (size_t i = 0; i < mIdsEmb_.size(); ++i) { assert(mIdsEmb_[i].count % mLd == 0); mIdsVocabSize.push_back(int32_t(mIdsEmb_[i].count / mLd)); @@ -96,8 +96,8 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase( mIdsVocabSize.push_back(tem); } char const* d = static_cast(data); - mBeta.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT); - mGamma.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT); + mBeta.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < nbLookupTables_; ++i) { nvinfer1::Weights pre_tem_weight; pre_tem_weight.type = mType; @@ -565,10 +565,10 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept { } char* d = static_cast(buffer); size_t const wordSize = getElementSize(mType); - serFromDev(d, mBetaDev.get(), mLd); - serFromDev(d, mGammaDev.get(), mLd); + serFromDev(&d, mBetaDev.get(), mLd); + serFromDev(&d, mGammaDev.get(), mLd); for (size_t i = 0; i < mIdsEmbDev.size(); ++i) { - serFromDev(d, + serFromDev(&d, static_cast(mIdsEmbDev[i]), mLd * mIdsVocabSize[i] * wordSize); } @@ -673,7 +673,7 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin( nvinfer1::Weights beta; nvinfer1::Weights gamma; std::vector IdsEmb; - bool output_fp16 = initializeFields(fc, beta, gamma, IdsEmb); + bool output_fp16 = initializeFields(fc, &beta, &gamma, &IdsEmb); TRANSFORMER_DEBUG_MSG("Building the Plugin..."); EmbLayerNormVarSeqlenPluginHFace* p = new EmbLayerNormVarSeqlenPluginHFace( name, @@ -691,7 +691,7 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginMTronCreator::createPlugin( nvinfer1::Weights beta; nvinfer1::Weights gamma; std::vector IdsEmb; - bool output_fp16 = initializeFields(fc, beta, gamma, IdsEmb); + bool output_fp16 = initializeFields(fc, &beta, &gamma, &IdsEmb); TRANSFORMER_DEBUG_MSG("Building the Plugin..."); EmbLayerNormVarSeqlenPluginMTron* p = new EmbLayerNormVarSeqlenPluginMTron( name, -- GitLab