Unverified commit 4d772144, authored by Wangzheee, committed by GitHub

[Paddle Inference]support n lookup_tables fuse to embeddinglayernorm(3) (#46243)

* [Paddle Inference]support n lookup_tables fuse to embeddinglayernorm(3)
Parent ba1bbe8e
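For orientation before the diff: the old fuse passes and TensorRT converters hard-coded exactly three lookup_table inputs (WordId/PosId/SentId with matching embedding tables), while this change feeds the generic Ids/Embs lists into a ManyEmbLayerNormPluginDynamic plugin so that n lookup_table results are summed element-wise and layer-normalized in one fused op. Below is a minimal CPU sketch of that fused computation, assuming row-major (vocab x hidden) tables; it is illustrative only, not Paddle or TensorRT code.

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// embs[i] is a (vocab_i x hidden) row-major table; ids[i] holds seq_len tokens.
std::vector<float> EmbEltwiseLayerNorm(
    const std::vector<std::vector<float>>& embs,
    const std::vector<std::vector<int>>& ids,
    const std::vector<float>& gamma,  // layernorm scale
    const std::vector<float>& beta,   // layernorm bias
    std::size_t hidden, float eps = 1e-5f) {
  const std::size_t n = ids.size();  // number of fused lookup tables
  const std::size_t seq_len = ids[0].size();
  std::vector<float> out(seq_len * hidden);  // zero-initialized
  for (std::size_t t = 0; t < seq_len; ++t) {
    float* row = &out[t * hidden];
    // Sum the n embedding rows selected by each id input (the "eltwise" part).
    for (std::size_t i = 0; i < n; ++i) {
      assert(ids[i].size() == seq_len);
      const float* src = &embs[i][ids[i][t] * hidden];
      for (std::size_t h = 0; h < hidden; ++h) row[h] += src[h];
    }
    // LayerNorm over the hidden dimension.
    float mean = 0.f, var = 0.f;
    for (std::size_t h = 0; h < hidden; ++h) mean += row[h];
    mean /= hidden;
    for (std::size_t h = 0; h < hidden; ++h)
      var += (row[h] - mean) * (row[h] - mean);
    var /= hidden;
    const float inv_std = 1.f / std::sqrt(var + eps);
    for (std::size_t h = 0; h < hidden; ++h)
      row[h] = gamma[h] * (row[h] - mean) * inv_std + beta[h];
  }
  return out;
}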
@@ -119,10 +119,8 @@ target_link_libraries(generate_pass pass_desc_proto)
 if(WITH_TENSORRT)
   pass_library(trt_map_matmul_to_mul_pass inference)
-  pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(trt_multihead_matmul_fuse_pass inference)
   pass_library(trt_skip_layernorm_fuse_pass inference)
-  pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_skip_layernorm_fuse_pass inference)
   pass_library(set_transformer_input_convert_pass inference)
   pass_library(remove_padding_recover_padding_pass inference)
@@ -130,6 +128,11 @@ if(WITH_TENSORRT)
   pass_library(layernorm_shift_partition_fuse_pass inference)
 endif()
+if(WITH_TENSORRT AND NOT WIN32)
+  pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
+  pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
+endif()
 if(WITH_GPU OR WITH_ROCM)
   pass_library(cudnn_placement_pass base DEPS placement_pass_base)
   pass_library(embedding_eltwise_layernorm_fuse_pass inference)
......
@@ -154,7 +154,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     /*const Scope* scope*/) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
+  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
+  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
   std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes;
   std::vector<Node*> start_pattern_out_node;
   std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
@@ -331,17 +332,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
     new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
     new_op_desc.SetInput("Ids", ids);
     new_op_desc.SetInput("Embs", embs);
-    new_op_desc.SetInput("WordId", {ids[0]});
-    new_op_desc.SetInput("PosId", {ids[1]});
-    if (ids.size() > 2) {
-      new_op_desc.SetInput("SentId", {ids[2]});
-    }
-    new_op_desc.SetInput("WordEmbedding", {embs[0]});
-    new_op_desc.SetInput("PosEmbedding", {embs[1]});
-    if (embs.size() > 2) {
-      new_op_desc.SetInput("SentEmbedding", {embs[2]});
-    }
+    new_op_desc.SetInput("PosId", {pos_id});
+    new_op_desc.SetInput("MaskId", {mask_id});
     new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
     new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
@@ -441,8 +433,6 @@ PrelnEmbeddingEltwiseLayerNormFusePass::
 }
 void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
-  FusePassBase::Init(name_scope_, graph);
   bool enable_int8 = Get<bool>("enable_int8");
   bool use_varseqlen = Get<bool>("use_varseqlen");
   bool with_interleaved = Get<bool>("with_interleaved");
@@ -458,6 +448,7 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
                "please reconfig.";
     return;
   }
+  FusePassBase::Init(name_scope_, graph);
   int fusion_count =
       PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
   if (fusion_count > 0) {
......
@@ -326,33 +326,14 @@ int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(
       embs.push_back(inner_pattern_ins[js[iter]].second->Name());
     }
+    // todo: support any inputs with lookup_table_v2
+    if (ids.size() < 3) {
+      VLOG(3) << "trt_embedding_eltwise_layernorm_fuse_pass only support >=3 "
+                 "inputs with lookup_table_v2";
+      return fusion_count;
+    }
     OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
     new_op_desc.SetType("fused_embedding_eltwise_layernorm");
     new_op_desc.SetInput("Ids", ids);
     new_op_desc.SetInput("Embs", embs);
-    new_op_desc.SetInput("WordId", {ids[0]});
     if (use_varseqlen && pos_id != "" && mask_id != "") {
       new_op_desc.SetInput("PosId", {pos_id});
       new_op_desc.SetInput("MaskId", {mask_id});
-    } else {
-      new_op_desc.SetInput("PosId", {ids[1]});
-    }
-    if (ids.size() > 2) {
-      new_op_desc.SetInput("SentId", {ids[2]});
     }
-    new_op_desc.SetInput("WordEmbedding", {embs[0]});
-    new_op_desc.SetInput("PosEmbedding", {embs[1]});
-    if (embs.size() > 2) {
-      new_op_desc.SetInput("SentEmbedding", {embs[2]});
-    }
     new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
     new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
     new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
......
@@ -2143,7 +2143,6 @@ USE_TRT_CONVERTER(instance_norm);
 USE_TRT_CONVERTER(layer_norm);
 USE_TRT_CONVERTER(gelu);
 USE_TRT_CONVERTER(multihead_matmul);
-USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
 USE_TRT_CONVERTER(skip_layernorm);
 USE_TRT_CONVERTER(slice);
 USE_TRT_CONVERTER(scale);
@@ -2172,7 +2171,11 @@ USE_TRT_CONVERTER(conv3d_transpose);
 USE_TRT_CONVERTER(mish);
 USE_TRT_CONVERTER(deformable_conv);
 USE_TRT_CONVERTER(pool3d)
+#ifdef _WIN32
+#else
 USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
+USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
+#endif
 USE_TRT_CONVERTER(preln_skip_layernorm)
 USE_TRT_CONVERTER(preln_residual_bias)
 USE_TRT_CONVERTER(c_allreduce_sum)
......
@@ -95,17 +95,22 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "delete_quant_dequant_linear_op_pass",  //
       "add_support_int8_pass",                //
       // "fc_fuse_pass",                      //
       "simplify_with_basic_ops_pass",         //
+#if defined _WIN32
+#else
       "trt_embedding_eltwise_layernorm_fuse_pass",    //
       "preln_embedding_eltwise_layernorm_fuse_pass",  //
+#endif
       "delete_c_identity_op_pass",            //
       "trt_multihead_matmul_fuse_pass_v2",    //
       "trt_multihead_matmul_fuse_pass_v3",    //
       "vit_attention_fuse_pass",              //
       "trt_skip_layernorm_fuse_pass",         //
       "preln_skip_layernorm_fuse_pass",       //
       "preln_residual_bias_fuse_pass",        //
       "layernorm_shift_partition_fuse_pass",  //
       // "set_transformer_input_convert_pass",  //
       "conv_bn_fuse_pass",                    //
       "unsqueeze2_eltwise_fuse_pass",         //
......
@@ -30,7 +30,6 @@ list(
     transpose_op.cc
     flatten_op.cc
     flatten_contiguous_range_op.cc
-    emb_eltwise_layernorm.cc
    skip_layernorm.cc
     scale_op.cc
     slice_op.cc
@@ -57,7 +56,6 @@ list(
     bilinear_interp_v2_op.cc
     pool3d_op.cc
     deformable_conv_op.cc
-    preln_emb_eltwise_layernorm.cc
     strided_slice_op.cc
     preln_skip_layernorm.cc
     roll_op.cc
@@ -80,6 +78,11 @@ list(
     layernorm_shift_partition_op.cc
     generic_and_custom_plugin_creater.cc)
+if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
+  list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
+       preln_emb_eltwise_layernorm.cc)
+endif()
 if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
   list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
 endif()
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
 #include "paddle/phi/core/ddim.h"
 namespace paddle {
@@ -36,30 +37,64 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
                   bool test_mode) override {
     VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
+    // get the presistable var's data
+    auto GetWeight = [&](const std::string& var_name,
+                         framework::DDim* dim) -> TensorRTEngine::Weight {
+      auto* temp_var = scope.FindVar(var_name);
+      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+      *dim = temp_tensor->dims();
+      auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
+      return weight;
+    };
+    auto GetFp16Weight = [&](const std::string& var_name,
+                             framework::DDim* dim) -> TensorRTEngine::Weight {
+      auto* temp_var = scope.FindVar(var_name);
+      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+      *dim = temp_tensor->dims();
+      auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor);
+      return weight;
+    };
+    auto GetFp32Weight = [&](const std::string& var_name,
+                             framework::DDim* dim) -> TensorRTEngine::Weight {
+      auto* temp_var = scope.FindVar(var_name);
+      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+      *dim = temp_tensor->dims();
+      auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor);
+      return weight;
+    };
     framework::OpDesc op_desc(op, nullptr);
-    auto word_id_name = op_desc.Input("WordId").front();
     auto pos_id_name = engine_->tensorrt_transformer_posid();
-    engine_->Set("ernie_pos_name", new std::string(pos_id_name));
-    auto sent_id_name = op_desc.Input("SentId").front();
     auto mask_id_name = engine_->tensorrt_transformer_maskid();
-    auto word_emb_name = op_desc.Input("WordEmbedding").front();
-    auto pos_emb_name = op_desc.Input("PosEmbedding").front();
-    auto sent_emb_name = op_desc.Input("SentEmbedding").front();
-    std::vector<std::string> id_names;
-    std::vector<std::string> emb_names;
     bool flag_varseqlen =
         engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != "";
+    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
+    int hidden = 0;
+    // Declare inputs
+    std::vector<nvinfer1::ITensor*> input_ids;
+    // Declare inputs_weight
+    std::vector<nvinfer1::Weights> input_embs;
+    std::vector<int> emb_sizes;
+    TensorRTEngine::Weight weight;
+    framework::DDim emb_dims;
+    framework::DDim bias_dims, scale_dims;
+    TensorRTEngine::Weight bias_weight, scale_weight;
+    int64_t bias_size = phi::product(bias_dims);
+    int64_t scale_size = phi::product(scale_dims);
+    nvinfer1::ILayer* layer = nullptr;
+    bool enable_int8 = op_desc.HasAttr("enable_int8");
+    std::vector<std::string> id_names = op_desc.Input("Ids");
+    std::vector<std::string> emb_names = op_desc.Input("Embs");
+    int input_num = id_names.size();
     if (flag_varseqlen) {
-      engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
       engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
       engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
-      id_names =
-          std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
-      emb_names =
-          std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
       auto mask_id_tensor = engine_->GetITensor("mask_id");
       auto mask_dims = mask_id_tensor->getDimensions();
@@ -72,16 +107,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       }
       auto* shape_tensor = Shape(mask_id_tensor);
-      std::vector<nvinfer1::ITensor*> start_vec_tensor;
       std::vector<nvinfer1::ITensor*> size_vec_tensor;
+      std::vector<nvinfer1::ITensor*> start_vec_tensor;
       for (int i = 0; i < mask_dims.nbDims; i++) {
-        start_vec_tensor.push_back(Add1DConstantLayer(0));
         size_vec_tensor.push_back(Add1DConstantLayer(1));
+        start_vec_tensor.push_back(Add1DConstantLayer(0));
       }
       size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
-      auto start_tensor = Concat(start_vec_tensor);
       auto size_tensor = Concat(size_vec_tensor);
+      auto start_tensor = Concat(start_vec_tensor);
       auto slice_layer =
           TRT_ENGINE_ADD_LAYER(engine_,
@@ -109,110 +143,32 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
               .c_str());
       engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
       engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
-    } else {
-      id_names = op_desc.Input("Ids");
-      emb_names = op_desc.Input("Embs");
-    }
-    int input_num = id_names.size();
-    // Declare inputs
-    std::vector<nvinfer1::ITensor*> input_ids;
-    for (int i = 0; i < input_num; i++) {
-      input_ids.push_back(engine_->GetITensor(id_names[i]));
-    }
-    // input_embs[0]: word_embedding
-    // input_embs[1]: pos_embedding
-    // input_embs[2]: sent_embedding
-    std::vector<nvinfer1::Weights> input_embs;
-    std::vector<int> emb_sizes;
-    // get the presistable var's data
-    auto GetWeight = [&](const std::string& var_name,
-                         framework::DDim* dim) -> TensorRTEngine::Weight {
-      auto* temp_var = scope.FindVar(var_name);
-      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
-      *dim = temp_tensor->dims();
-      auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
-      return weight;
-    };
-    auto GetFp16Weight = [&](const std::string& var_name,
-                             framework::DDim* dim) -> TensorRTEngine::Weight {
-      auto* temp_var = scope.FindVar(var_name);
-      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
-      *dim = temp_tensor->dims();
-      auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor);
-      return weight;
-    };
-    auto GetFp32Weight = [&](const std::string& var_name,
-                             framework::DDim* dim) -> TensorRTEngine::Weight {
-      auto* temp_var = scope.FindVar(var_name);
-      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
-      *dim = temp_tensor->dims();
-      auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor);
-      return weight;
-    };
-    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-    int hidden = 0;
-    for (int i = 0; i < input_num; i++) {
-      framework::DDim emb_dims;
-      TensorRTEngine::Weight weight;
-      if (flag_varseqlen) {
+      for (int i = 0; i < input_num; i++) {
+        auto input_tensor = engine_->GetITensor(id_names[i]);
         weight = GetWeight(emb_names[i], &emb_dims);
-      } else {
-        if (with_fp16) {
-          weight = GetFp16Weight(emb_names[i], &emb_dims);
+        if (id_names[i] == pos_id_name) {
+          input_ids.insert(input_ids.begin(), input_tensor);
+          input_embs.insert(input_embs.begin(), weight.get());
+          emb_sizes.insert(emb_sizes.begin(), weight.get().count);
        } else {
-          weight = GetFp32Weight(emb_names[i], &emb_dims);
+          input_ids.push_back(input_tensor);
+          input_embs.push_back(weight.get());
+          emb_sizes.push_back(weight.get().count);
         }
-        hidden = emb_dims[1];
       }
-      input_embs.push_back(weight.get());
-      emb_sizes.push_back(weight.get().count);
-      PADDLE_ENFORCE_EQ(
-          emb_dims.size(),
-          2,
-          platform::errors::InvalidArgument(
-              "The fused EmbEltwiseLayerNorm's emb should be 2 dims."));
-      hidden = emb_dims[1];
-    }
-    framework::DDim bias_dims, scale_dims;
-    TensorRTEngine::Weight bias_weight, scale_weight;
-    if (flag_varseqlen) {
       bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
       scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
-    } else {
-      if (with_fp16) {
-        bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims);
-        scale_weight =
-            GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims);
-      } else {
-        bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims);
-        scale_weight =
-            GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims);
-      }
-    }
-    int64_t bias_size = phi::product(bias_dims);
-    int64_t scale_size = phi::product(scale_dims);
-    nvinfer1::ILayer* layer = nullptr;
-    bool enable_int8 = op_desc.HasAttr("enable_int8");
-    if (flag_varseqlen) {
+      bias_size = phi::product(bias_dims);
+      scale_size = phi::product(scale_dims);
+      // other_id(except pos_id)
+      engine_->SetITensor("word_id", input_ids[1]);
       int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
       if (enable_int8) {
         output_fp16 = 1;
       }
-      PADDLE_ENFORCE_EQ(
-          input_num,
-          3,
-          platform::errors::InvalidArgument(
-              "When using oss and var-len, embedding_eltwise_layernorm op"
-              "should have 3 inputs only, but got %d.",
-              input_num));
       PADDLE_ENFORCE_EQ(
           output_fp16,
           1,
@@ -220,29 +176,27 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
              "Only Precision::KHalf(fp16) is supported when infering "
              "ernie(bert) model with config.EnableVarseqlen(). "
              "But Precision::KFloat32 is setted."));
-      const std::vector<nvinfer1::PluginField> fields{
-          {"bert_embeddings_layernorm_beta",
-           bias_weight.get().values,
-           GetPluginFieldType(bias_weight.get().type),
-           static_cast<int32_t>(bias_size)},
-          {"bert_embeddings_layernorm_gamma",
-           scale_weight.get().values,
-           GetPluginFieldType(scale_weight.get().type),
-           static_cast<int32_t>(scale_size)},
-          {"bert_embeddings_word_embeddings",
-           input_embs[0].values,
-           GetPluginFieldType(input_embs[0].type),
-           static_cast<int32_t>(emb_sizes[0])},
-          {"bert_embeddings_token_type_embeddings",
-           input_embs[2].values,
-           GetPluginFieldType(input_embs[2].type),
-           static_cast<int32_t>(emb_sizes[2])},
-          {"bert_embeddings_position_embeddings",
-           input_embs[1].values,
-           GetPluginFieldType(input_embs[1].type),
-           static_cast<int32_t>(emb_sizes[1])},
-          {"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1},
-      };
+      std::vector<nvinfer1::PluginField> fields;
+      std::vector<std::string> temp_fields_keys;
+      fields.emplace_back("bert_embeddings_layernorm_beta",
+                          bias_weight.get().values,
+                          GetPluginFieldType(bias_weight.get().type),
+                          static_cast<int32_t>(bias_size));
+      fields.emplace_back("bert_embeddings_layernorm_gamma",
+                          scale_weight.get().values,
+                          GetPluginFieldType(scale_weight.get().type),
+                          static_cast<int32_t>(scale_size));
+      fields.emplace_back(
+          "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1);
+      for (int i = 0; i < input_num; ++i) {
+        temp_fields_keys.push_back("bert_embeddings_word_embeddings_" +
+                                   std::to_string(i));
+        fields.emplace_back(temp_fields_keys.rbegin()->c_str(),
+                            input_embs[i].values,
+                            GetPluginFieldType(input_embs[i].type),
+                            static_cast<int32_t>(emb_sizes[i]));
+      }
       nvinfer1::PluginFieldCollection* plugin_ptr =
           static_cast<nvinfer1::PluginFieldCollection*>(
@@ -251,27 +205,19 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       plugin_ptr->nbFields = static_cast<int>(fields.size());
       plugin_ptr->fields = fields.data();
-      std::vector<nvinfer1::ITensor*> plugin_inputs;
-      plugin_inputs.emplace_back(
-          engine_->GetITensor(word_id_name));  // word_embedding,
-                                               // eval_placeholder_0
-      plugin_inputs.emplace_back(
-          engine_->GetITensor(sent_id_name));  // sent_embedding,
-                                               // eval_placeholder_1
-      plugin_inputs.emplace_back(
-          engine_->GetITensor(pos_id_name));  // cu_seqlens,
-                                              // eval_placeholder_2
+      std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
       plugin_inputs.emplace_back(engine_->GetITensor(
           "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
       auto creator = GetPluginRegistry()->getPluginCreator(
-          "CustomEmbLayerNormPluginDynamic", "2");
+          "ManyEmbLayerNormPluginDynamic", "2");
       auto plugin_obj =
-          creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr);
+          creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
       auto plugin_layer = engine_->network()->addPluginV2(
           plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
-      plugin_layer->setName(("CustomEmbLayerNormPluginDynamic_V2(Output: " +
+      plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " +
                              op_desc.Output("Out")[0] + ")")
                                 .c_str());
       free(plugin_ptr);
@@ -302,11 +248,33 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
         layer = plugin_layer;
         auto output_name = op_desc.Output("Out")[0];
         RreplenishLayerAndOutput(layer,
-                                 "CustomEmbLayerNormPluginDynamic_V2",
+                                 "ManyEmbLayerNormPluginDynamic_V2",
                                  {output_name, std::string("qkv_plugin_mask")},
                                  test_mode);
       }
     } else {
+      for (int i = 0; i < input_num; i++) {
+        if (with_fp16) {
+          weight = GetFp16Weight(emb_names[i], &emb_dims);
+        } else {
+          weight = GetFp32Weight(emb_names[i], &emb_dims);
+        }
+        input_ids.push_back(engine_->GetITensor(id_names[i]));
+        input_embs.push_back(weight.get());
+        emb_sizes.push_back(weight.get().count);
+        hidden = emb_dims[1];
+      }
+      if (with_fp16) {
+        bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims);
+        scale_weight =
+            GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims);
+      } else {
+        bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims);
+        scale_weight =
+            GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims);
+      }
+      bias_size = phi::product(bias_dims);
+      scale_size = phi::product(scale_dims);
      float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
      plugin::DynamicPluginTensorRT* plugin = nullptr;
      std::vector<void*> input_embs_data;
......
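A detail worth calling out in the converter change above: the plugin fields are now built in a loop, one "bert_embeddings_word_embeddings_<i>" entry per fused table, and temp_fields_keys exists solely to keep those generated name strings alive, because nvinfer1::PluginField stores a raw const char* and does not copy the name. Here is a standalone sketch of that pattern with a stand-in field struct; it is illustrative, not the converter itself.

#include <cstdint>
#include <string>
#include <vector>

struct PluginField {  // stand-in for nvinfer1::PluginField
  const char* name;
  const void* data;
  int32_t length;
};

void BuildEmbeddingFields(const std::vector<std::vector<float>>& embs,
                          std::vector<PluginField>* fields,
                          std::vector<std::string>* keys) {
  // Prevent reallocation of *keys so every c_str() stays valid even for
  // short strings that would otherwise live inline (SSO) in the vector.
  keys->reserve(keys->size() + embs.size());
  for (std::size_t i = 0; i < embs.size(); ++i) {
    // *keys owns the name; the field only borrows the pointer.
    keys->push_back("bert_embeddings_word_embeddings_" + std::to_string(i));
    fields->push_back({keys->back().c_str(),
                       embs[i].data(),
                       static_cast<int32_t>(embs[i].size())});
  }
}

The reserve() call is a cheap way to avoid depending on the generated names being long enough to escape small-string optimization, which is what keeps the borrowed pointers stable while the key vector grows.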
@@ -10,7 +10,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/utils.h"
 #include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
 namespace paddle {
 namespace framework {
@@ -32,6 +34,15 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
                   bool test_mode) override {
 #if IS_TRT_VERSION_GE(7000)
     VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
+    // get the presistable var's data
+    auto GetWeight = [&](const std::string& var_name,
+                         framework::DDim* dim) -> TensorRTEngine::Weight {
+      auto* temp_var = scope.FindVar(var_name);
+      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+      *dim = temp_tensor->dims();
+      auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
+      return weight;
+    };
     auto pos_id_name = engine_->tensorrt_transformer_posid();
     auto mask_id_name = engine_->tensorrt_transformer_maskid();
@@ -50,126 +61,48 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
       PADDLE_THROW(
           platform::errors::Fatal("use with_interleaved must be int8."));
     }
-    auto word_id_name = op_desc.Input("WordId").front();
-    engine_->Set("ernie_pos_name", new std::string(pos_id_name));
-    auto sent_id_name = op_desc.Input("SentId").front();
-    auto word_emb_name = op_desc.Input("WordEmbedding").front();
-    auto pos_emb_name = op_desc.Input("PosEmbedding").front();
-    auto sent_emb_name = op_desc.Input("SentEmbedding").front();
-    engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
-    engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
-    engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
-    std::vector<std::string> emb_names;
-    emb_names =
-        std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
-    int input_num = emb_names.size();
-    // input_embs[0]: word_embedding
-    // input_embs[1]: pos_embedding
-    // input_embs[2]: sent_embedding
-    std::vector<float*> input_embs;
+    // Declare inputs
+    std::vector<nvinfer1::ITensor*> input_ids;
+    // Declare inputs_weight
+    std::vector<nvinfer1::Weights> input_embs;
     std::vector<int> emb_sizes;
-    // get the presistable var's data
-    auto get_persistable_data = [&](const std::string& var_name,
-                                    framework::DDim* dims) -> float* {
-      auto* temp_var = scope.FindVar(var_name);
-      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
-      (*dims) = temp_tensor->dims();
-      auto* temp_data = const_cast<float*>(static_cast<const float*>(
-          engine_->GetFp32TrtWeight(var_name, *temp_tensor).get().values));
-      return temp_data;
-    };
-    for (int i = 0; i < input_num; i++) {
-      framework::DDim emb_dims;
-      float* emb_data = get_persistable_data(emb_names[i], &emb_dims);
-      int64_t emb_size = phi::product(emb_dims);
-      input_embs.push_back(emb_data);
-      emb_sizes.push_back(emb_size);
-      PADDLE_ENFORCE_EQ(
-          emb_dims.size(),
-          2,
-          platform::errors::InvalidArgument(
-              "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
-    }
+    TensorRTEngine::Weight weight;
+    framework::DDim emb_dims;
     framework::DDim bias_dims, scale_dims;
-    auto* bias =
-        get_persistable_data(op_desc.Input("Bias").front(), &bias_dims);
-    auto* scale =
-        get_persistable_data(op_desc.Input("Scale").front(), &scale_dims);
+    TensorRTEngine::Weight bias_weight, scale_weight;
     int64_t bias_size = phi::product(bias_dims);
     int64_t scale_size = phi::product(scale_dims);
-    int output_int8 = 1;
-    PADDLE_ENFORCE_EQ(
-        input_num,
-        3,
-        platform::errors::InvalidArgument(
-            "When using oss and var-len, embedding_eltwise_layernorm op"
-            "should have 3 inputs only, but got %d.",
-            input_num));
-    const std::vector<nvinfer1::PluginField> fields{
-        {"bert_embeddings_layernorm_beta",
-         bias,
-         nvinfer1::PluginFieldType::kFLOAT32,
-         static_cast<int32_t>(bias_size)},
-        {"bert_embeddings_layernorm_gamma",
-         scale,
-         nvinfer1::PluginFieldType::kFLOAT32,
-         static_cast<int32_t>(scale_size)},
-        {"bert_embeddings_word_embeddings",
-         input_embs[0],
-         nvinfer1::PluginFieldType::kFLOAT32,
-         static_cast<int32_t>(emb_sizes[0])},
-        {"bert_embeddings_token_type_embeddings",
-         input_embs[2],
-         nvinfer1::PluginFieldType::kFLOAT32,
-         static_cast<int32_t>(emb_sizes[2])},
-        {"bert_embeddings_position_embeddings",
-         input_embs[1],
-         nvinfer1::PluginFieldType::kFLOAT32,
-         static_cast<int32_t>(emb_sizes[1])},
-        {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
-    };
-    nvinfer1::PluginFieldCollection* plugin_ptr =
-        static_cast<nvinfer1::PluginFieldCollection*>(
-            malloc(sizeof(*plugin_ptr) +
-                   fields.size() * sizeof(nvinfer1::PluginField)));
-    plugin_ptr->nbFields = static_cast<int>(fields.size());
-    plugin_ptr->fields = fields.data();
-    std::vector<nvinfer1::ITensor*> plugin_inputs;
-    plugin_inputs.emplace_back(
-        engine_->GetITensor(word_id_name));  // word_embedding,
-                                             // eval_placeholder_0
-    plugin_inputs.emplace_back(
-        engine_->GetITensor(sent_id_name));  // sent_embedding,
-                                             // eval_placeholder_1
-    plugin_inputs.emplace_back(
-        engine_->GetITensor(pos_id_name));  // cu_seqlens,
-                                            // eval_placeholder_2
+    std::vector<std::string> id_names = op_desc.Input("Ids");
+    std::vector<std::string> emb_names = op_desc.Input("Embs");
+    int input_num = id_names.size();
+    engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
+    engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
     auto mask_id_tensor = engine_->GetITensor("mask_id");
     auto mask_dims = mask_id_tensor->getDimensions();
     auto slice_start_dims = mask_dims;
-    auto slice_size_dims = mask_dims;
     auto slice_stride_dims = mask_dims;
     for (int i = 0; i < mask_dims.nbDims; i++) {
       slice_start_dims.d[i] = 0;
-      slice_size_dims.d[i] = 1;
       slice_stride_dims.d[i] = 1;
     }
-    slice_size_dims.d[1] = mask_dims.d[1];
-    auto* slice_size_tensor = Add1DConstantLayer(slice_size_dims);
+    auto* shape_tensor = Shape(mask_id_tensor);
+    std::vector<nvinfer1::ITensor*> size_vec_tensor;
+    std::vector<nvinfer1::ITensor*> start_vec_tensor;
+    for (int i = 0; i < mask_dims.nbDims; i++) {
+      size_vec_tensor.push_back(Add1DConstantLayer(1));
+      start_vec_tensor.push_back(Add1DConstantLayer(0));
+    }
+    size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
+    auto size_tensor = Concat(size_vec_tensor);
+    auto start_tensor = Concat(start_vec_tensor);
     auto slice_layer =
         TRT_ENGINE_ADD_LAYER(engine_,
                              Slice,
@@ -177,11 +110,11 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
                              slice_start_dims,
                              slice_start_dims,
                              slice_stride_dims);  // unuseful slice_start_dims
-    slice_layer->setInput(2, *slice_size_tensor);
-    slice_layer->setName(
-        ("PrelnEmbeltwise_slice_layer (Output: slice_max_seqlen " +
-         op_desc.Output("Out")[0] + ")")
-            .c_str());
+    slice_layer->setInput(1, *start_tensor);
+    slice_layer->setInput(2, *size_tensor);
+    slice_layer->setName(("Embeltwise_slice_layer (Output: slice_max_seqlen " +
+                          op_desc.Output("Out")[0] + ")")
+                             .c_str());
     engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
     auto* reshape_layer =
@@ -190,24 +123,87 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
     shape_dim.nbDims = 1;
     shape_dim.d[0] = -1;
     reshape_layer->setReshapeDimensions(shape_dim);
-    reshape_layer->setName(
-        ("PrelnEmbeltwise_reshape_layer (Output: max_seqlen " +
-         op_desc.Output("Out")[0] + ")")
-            .c_str());
+    reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " +
+                            op_desc.Output("Out")[0] + ")")
+                               .c_str());
     engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
     engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
-    plugin_inputs.emplace_back(
-        reshape_layer->getOutput(0));  // max_seqlen, eval_placeholder_3
+    for (int i = 0; i < input_num; i++) {
+      auto input_tensor = engine_->GetITensor(id_names[i]);
+      weight = GetWeight(emb_names[i], &emb_dims);
+      if (id_names[i] == pos_id_name) {
+        input_ids.insert(input_ids.begin(), input_tensor);
+        input_embs.insert(input_embs.begin(), weight.get());
+        emb_sizes.insert(emb_sizes.begin(), weight.get().count);
+      } else {
+        input_ids.push_back(input_tensor);
+        input_embs.push_back(weight.get());
+        emb_sizes.push_back(weight.get().count);
+      }
+    }
+    bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
+    scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
+    bias_size = phi::product(bias_dims);
+    scale_size = phi::product(scale_dims);
+    // other_id(except pos_id)
+    engine_->SetITensor("word_id", input_ids[1]);
+    int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
+    if (enable_int8) {
+      output_fp16 = 1;
+    }
+    PADDLE_ENFORCE_EQ(
+        output_fp16,
+        1,
+        platform::errors::InvalidArgument(
+            "Only Precision::KHalf(fp16) is supported when infering "
+            "ernie(bert) model with config.EnableVarseqlen(). "
+            "But Precision::KFloat32 is setted."));
+    std::vector<nvinfer1::PluginField> fields;
+    std::vector<std::string> temp_fields_keys;
+    fields.emplace_back("bert_embeddings_layernorm_beta",
+                        bias_weight.get().values,
+                        GetPluginFieldType(bias_weight.get().type),
+                        static_cast<int32_t>(bias_size));
+    fields.emplace_back("bert_embeddings_layernorm_gamma",
+                        scale_weight.get().values,
+                        GetPluginFieldType(scale_weight.get().type),
+                        static_cast<int32_t>(scale_size));
+    fields.emplace_back(
+        "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1);
+    for (int i = 0; i < input_num; ++i) {
+      temp_fields_keys.push_back("bert_embeddings_word_embeddings_" +
+                                 std::to_string(i));
+      fields.emplace_back(temp_fields_keys.rbegin()->c_str(),
+                          input_embs[i].values,
+                          GetPluginFieldType(input_embs[i].type),
+                          static_cast<int32_t>(emb_sizes[i]));
+    }
+    nvinfer1::PluginFieldCollection* plugin_ptr =
+        static_cast<nvinfer1::PluginFieldCollection*>(
+            malloc(sizeof(*plugin_ptr) +
+                   fields.size() * sizeof(nvinfer1::PluginField)));
+    plugin_ptr->nbFields = static_cast<int>(fields.size());
+    plugin_ptr->fields = fields.data();
+    std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
+    plugin_inputs.emplace_back(engine_->GetITensor(
+        "max_seqlen_tensor"));  // max_seqlen, eval_placeholder_3
     auto creator = GetPluginRegistry()->getPluginCreator(
-        "CustomEmbLayerNormPluginDynamic", "3");
+        "ManyEmbLayerNormPluginDynamic", "3");
     auto plugin_obj =
-        creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr);
+        creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
     auto plugin_layer = engine_->network()->addPluginV2(
         plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
-    plugin_layer->setName(("CustomPrelnEmbLayerNormPluginDynamic_V3(Output: " +
-                           op_desc.Output("Out_0")[0] + ")")
+    plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V3(Output: " +
+                           op_desc.Output("Out")[0] + ")")
                               .c_str());
     free(plugin_ptr);
     float out_0_scale =
@@ -226,7 +222,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
     engine_->SetITensor(op_desc.Output("Out_0")[0],
                         shuffler_embed_out0->getOutput(0));
     shuffler_embed_out0->setName(
-        ("shuffler_after_CustomPrelnEmbLayerNormPluginDynamic_V3(Output_0: " +
+        ("shuffler_after_ManyEmbLayerNormPluginDynamic_V3(Output_0: " +
         op_desc.Output("Out_0")[0] + ")")
            .c_str());
@@ -240,7 +236,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
     engine_->SetITensor(op_desc.Output("Out_1")[0],
                         shuffler_embed_out1->getOutput(0));
     shuffler_embed_out1->setName(
-        ("shuffler_after_CustomPrelnEmbLayerNormPluginDynamic_V3(Output_1: " +
+        ("shuffler_after_ManyEmbLayerNormPluginDynamic_V3(Output_1: " +
         op_desc.Output("Out_1")[0] + ")")
            .c_str());
......
@@ -34,6 +34,11 @@ list(
     fused_token_prune_op_plugin.cu
     layernorm_shift_partition_op.cu
     generic_plugin.cu)
+if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
+  list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
+       many_emb_Layernorm_varseqlen_kernelMTron.cu
+       many_emb_Layernorm_varseqlen_kernelHFace.cu)
+endif()
 if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
   list(APPEND TRT_FILES spmm_plugin.cu)
......
@@ -96,8 +96,8 @@ template <typename T>
 inline void serFromDev(char** buffer, const T* data, size_t nbElem) {
   const size_t len = sizeof(T) * nbElem;
   cudaMemcpy(
-      buffer, static_cast<const void*>(data), len, cudaMemcpyDeviceToHost);
-  buffer += len;
+      *buffer, static_cast<const void*>(data), len, cudaMemcpyDeviceToHost);
+  *buffer += len;
 }
 template <typename T>
@@ -174,8 +174,8 @@ struct WeightsWithOwnership : public nvinfer1::Weights {
     const auto nbBytes = getWeightsSize(*this, type);
     auto destBuf = new char[nbBytes];
     this->values = destBuf;
-    std::copy_n(srcBuf, nbBytes, destBuf);
-    srcBuf += nbBytes;
+    std::copy_n(*srcBuf, nbBytes, destBuf);
+    *srcBuf += nbBytes;
   }
 };
......
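The plugin.h hunk above fixes serialization-cursor bugs: the byte cursor must be dereferenced both for the copy and for the advance, otherwise the bytes land on the pointer variable itself and the caller's cursor never moves, so every field writes over the same offset. Below is a minimal, host-only demonstration of the corrected shape; it is illustrative, not the Paddle helper itself.

#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>

// Broken variant (the pre-fix shape): the helper receives char** but uses it
// without dereferencing, so memcpy writes over the pointer variable and the
// "advance" steps the char**, leaving the caller's cursor untouched.
//
//   void SerBytesBroken(char** buffer, const char* data, std::size_t len) {
//     std::memcpy(buffer, data, len);  // clobbers the pointer, not the buffer
//     buffer += len;                   // local char** arithmetic; lost on return
//   }

// Fixed shape: dereference for the write, dereference for the advance.
void SerBytes(char** buffer, const char* data, std::size_t len) {
  std::memcpy(*buffer, data, len);  // write at the current cursor position
  *buffer += len;                   // advance the caller's cursor
}

int main() {
  char buf[8] = {};
  char* cursor = buf;
  SerBytes(&cursor, "abcd", 4);
  SerBytes(&cursor, "efgh", 4);  // appends at offset 4 instead of overwriting
  std::cout << std::string(buf, 8) << "\n";  // prints "abcdefgh"
  return 0;
}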
@@ -59,7 +59,7 @@ template <>
 struct Serializer<const char*> {
   static size_t serialized_size(const char* value) { return strlen(value) + 1; }
   static void serialize(void** buffer, const char* value) {
-    ::snprintf(static_cast<char*>(*buffer), value);
+    ::strcpy(static_cast<char*>(*buffer), value);  // NOLINT
     reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
   }
   static void deserialize(void const** buffer,
......
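On the serialize.h fix above: the removed call passed only a destination and a value to ::snprintf, which does not match snprintf's (dest, size, format, ...) signature, and treating caller data as a format string would be unsafe in any case. strcpy is acceptable here because serialized_size() already reserves strlen(value) + 1 bytes for the field. If a remaining-capacity count were threaded through the serializer, a bounded variant might look like the sketch below; the `remaining` parameter is an assumption for illustration, not part of the project's API.

#include <cstddef>
#include <cstdio>
#include <cstring>

void SerializeCStringBounded(void** buffer, const char* value,
                             std::size_t remaining) {
  // "%s" keeps user data from being interpreted as a format string, and
  // `remaining` bounds the write including the terminating NUL.
  std::snprintf(static_cast<char*>(*buffer), remaining, "%s", value);
  reinterpret_cast<char*&>(*buffer) += std::strlen(value) + 1;  // advance cursor
}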
@@ -25,8 +25,7 @@
 #include "common/common.cuh"
 #include "common/plugin.h"
 #include "common/serialize.h"
-// #include
-// "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
 namespace paddle {
 namespace inference {
......
@@ -25,8 +25,7 @@
 #include "common/common.cuh"
 #include "common/plugin.h"
 #include "common/serialize.h"
-// #include
-// "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
 namespace paddle {
 namespace inference {
......
@@ -61,8 +61,8 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
   assert(beta.count == gamma.count);
   mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT);
   mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT);
-  copyToDevice(mGamma, sizeof(float) * mGamma.count, mGammaDev);
-  copyToDevice(mBeta, sizeof(float) * mBeta.count, mBetaDev);
+  copyToDevice(&mGamma, sizeof(float) * mGamma.count, &mGammaDev);
+  copyToDevice(&mBeta, sizeof(float) * mBeta.count, &mBetaDev);
   for (size_t i = 0; i < mIdsEmb_.size(); ++i) {
     assert(mIdsEmb_[i].count % mLd == 0);
     mIdsVocabSize.push_back(int32_t(mIdsEmb_[i].count / mLd));
@@ -96,8 +96,8 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
     mIdsVocabSize.push_back(tem);
   }
   char const* d = static_cast<char const*>(data);
-  mBeta.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT);
-  mGamma.convertAndCopy(d, mLd, nvinfer1::DataType::kFLOAT);
+  mBeta.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT);
+  mGamma.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT);
   for (int32_t i = 0; i < nbLookupTables_; ++i) {
     nvinfer1::Weights pre_tem_weight;
     pre_tem_weight.type = mType;
@@ -565,10 +565,10 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
   }
   char* d = static_cast<char*>(buffer);
   size_t const wordSize = getElementSize(mType);
-  serFromDev(d, mBetaDev.get(), mLd);
-  serFromDev(d, mGammaDev.get(), mLd);
+  serFromDev(&d, mBetaDev.get(), mLd);
+  serFromDev(&d, mGammaDev.get(), mLd);
   for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
-    serFromDev(d,
+    serFromDev(&d,
                static_cast<char*>(mIdsEmbDev[i]),
                mLd * mIdsVocabSize[i] * wordSize);
   }
@@ -673,7 +673,7 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
   nvinfer1::Weights beta;
   nvinfer1::Weights gamma;
   std::vector<nvinfer1::Weights> IdsEmb;
-  bool output_fp16 = initializeFields(fc, beta, gamma, IdsEmb);
+  bool output_fp16 = initializeFields(fc, &beta, &gamma, &IdsEmb);
   TRANSFORMER_DEBUG_MSG("Building the Plugin...");
   EmbLayerNormVarSeqlenPluginHFace* p = new EmbLayerNormVarSeqlenPluginHFace(
       name,
@@ -691,7 +691,7 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginMTronCreator::createPlugin(
   nvinfer1::Weights beta;
   nvinfer1::Weights gamma;
   std::vector<nvinfer1::Weights> IdsEmb;
-  bool output_fp16 = initializeFields(fc, beta, gamma, IdsEmb);
+  bool output_fp16 = initializeFields(fc, &beta, &gamma, &IdsEmb);
   TRANSFORMER_DEBUG_MSG("Building the Plugin...");
   EmbLayerNormVarSeqlenPluginMTron* p = new EmbLayerNormVarSeqlenPluginMTron(
       name,
......