未验证 提交 22bfa579 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] General optimization for no_varlen embedding layernorm (#48580)

* general optimization no_varlen embedding layernorm
上级 8c416653
...@@ -140,7 +140,7 @@ if(WITH_TENSORRT) ...@@ -140,7 +140,7 @@ if(WITH_TENSORRT)
pass_library(preln_layernorm_x_fuse_pass inference) pass_library(preln_layernorm_x_fuse_pass inference)
endif() endif()
if(WITH_TENSORRT AND NOT WIN32) if(WITH_TENSORRT)
pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference) pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
endif() endif()
......
...@@ -1170,14 +1170,14 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const { ...@@ -1170,14 +1170,14 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
"preln_embedding_eltwise_layernorm_fuse_" "preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen")); "pass. please use no_varseqlen"));
} }
} else if (!use_varseqlen && pos_id == "" && mask_id == "") { } else if (!use_varseqlen && pos_id == "") {
VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass"; VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: " platform::errors::Fatal("Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set " "use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set " "mask_id. Or not use varseqlen, do not set "
"pos_id, set mask_id. Please " "pos_id. Please "
"reconfig")); "reconfig"));
} }
graph->Set(kMultiheadMatmulPass, new bool(true)); graph->Set(kMultiheadMatmulPass, new bool(true));
......
...@@ -2338,11 +2338,8 @@ USE_TRT_CONVERTER(conv3d_transpose); ...@@ -2338,11 +2338,8 @@ USE_TRT_CONVERTER(conv3d_transpose);
USE_TRT_CONVERTER(mish); USE_TRT_CONVERTER(mish);
USE_TRT_CONVERTER(deformable_conv); USE_TRT_CONVERTER(deformable_conv);
USE_TRT_CONVERTER(pool3d) USE_TRT_CONVERTER(pool3d)
#ifdef _WIN32
#else
USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
#endif
USE_TRT_CONVERTER(preln_skip_layernorm) USE_TRT_CONVERTER(preln_skip_layernorm)
USE_TRT_CONVERTER(preln_residual_bias) USE_TRT_CONVERTER(preln_residual_bias)
USE_TRT_CONVERTER(c_allreduce_sum) USE_TRT_CONVERTER(c_allreduce_sum)
......
...@@ -96,13 +96,8 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -96,13 +96,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"add_support_int8_pass", // "add_support_int8_pass", //
// "fc_fuse_pass", // // "fc_fuse_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
#if defined _WIN32
#else
"trt_embedding_eltwise_layernorm_fuse_pass", // "trt_embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", // "preln_embedding_eltwise_layernorm_fuse_pass", //
#endif
"delete_c_identity_op_pass", // "delete_c_identity_op_pass", //
"trt_multihead_matmul_fuse_pass_v2", // "trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", // "trt_multihead_matmul_fuse_pass_v3", //
...@@ -116,7 +111,6 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -116,7 +111,6 @@ const std::vector<std::string> kTRTSubgraphPasses({
"preln_residual_bias_fuse_pass", // "preln_residual_bias_fuse_pass", //
"preln_layernorm_x_fuse_pass", // "preln_layernorm_x_fuse_pass", //
"reverse_roll_fuse_pass", // "reverse_roll_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", // "trt_squeeze2_matmul_fuse_pass", //
......
...@@ -94,7 +94,7 @@ list( ...@@ -94,7 +94,7 @@ list(
fused_lookup_tables_op.cc fused_lookup_tables_op.cc
expand_v2_op.cc) expand_v2_op.cc)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
preln_emb_eltwise_layernorm.cc) preln_emb_eltwise_layernorm.cc)
endif() endif()
......
...@@ -13,7 +13,7 @@ limitations under the License. */ ...@@ -13,7 +13,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/utils.h" #include "paddle/fluid/inference/tensorrt/convert/utils.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
...@@ -36,7 +36,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -36,7 +36,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
const framework::Scope& scope, const framework::Scope& scope,
bool test_mode) override { bool test_mode) override {
VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer"; VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
// get the presistable var's data // get the presistable var's data
auto GetWeight = [&](const std::string& var_name, auto GetWeight = [&](const std::string& var_name,
framework::DDim* dim) -> TensorRTEngine::Weight { framework::DDim* dim) -> TensorRTEngine::Weight {
...@@ -47,32 +46,13 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -47,32 +46,13 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
return weight; return weight;
}; };
auto GetFp16Weight = [&](const std::string& var_name,
framework::DDim* dim) -> TensorRTEngine::Weight {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
*dim = temp_tensor->dims();
auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor);
return weight;
};
auto GetFp32Weight = [&](const std::string& var_name,
framework::DDim* dim) -> TensorRTEngine::Weight {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
*dim = temp_tensor->dims();
auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor);
return weight;
};
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto pos_id_name = engine_->tensorrt_transformer_posid(); auto pos_id_name = engine_->tensorrt_transformer_posid();
auto mask_id_name = engine_->tensorrt_transformer_maskid(); auto mask_id_name = engine_->tensorrt_transformer_maskid();
bool flag_varseqlen = bool flag_varseqlen =
engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != ""; engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != "";
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); // bool with_fp16 = engine_->WithFp16() &&
int hidden = 0; // !engine_->disable_trt_plugin_fp16(); int hidden = 0; Declare inputs
// Declare inputs
std::vector<nvinfer1::ITensor*> input_ids; std::vector<nvinfer1::ITensor*> input_ids;
// Declare inputs_weight // Declare inputs_weight
...@@ -95,55 +75,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -95,55 +75,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
if (flag_varseqlen) { if (flag_varseqlen) {
engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
auto mask_id_tensor = engine_->GetITensor("mask_id");
auto mask_dims = mask_id_tensor->getDimensions();
auto slice_start_dims = mask_dims;
auto slice_stride_dims = mask_dims;
for (int i = 0; i < mask_dims.nbDims; i++) {
slice_start_dims.d[i] = 0;
slice_stride_dims.d[i] = 1;
}
auto* shape_tensor = Shape(mask_id_tensor);
std::vector<nvinfer1::ITensor*> size_vec_tensor;
std::vector<nvinfer1::ITensor*> start_vec_tensor;
for (int i = 0; i < mask_dims.nbDims; i++) {
size_vec_tensor.push_back(Add1DConstantLayer(1));
start_vec_tensor.push_back(Add1DConstantLayer(0));
}
size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1);
auto size_tensor = Concat(size_vec_tensor);
auto start_tensor = Concat(start_vec_tensor);
auto slice_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*mask_id_tensor,
slice_start_dims,
slice_start_dims,
slice_stride_dims); // unuseful slice_start_dims
slice_layer->setInput(1, *start_tensor);
slice_layer->setInput(2, *size_tensor);
slice_layer->setName(
("Embeltwise_slice_layer (Output: slice_max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f);
auto* reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
reshape_layer->setReshapeDimensions(shape_dim);
reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f);
engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0));
for (int i = 0; i < input_num; i++) { for (int i = 0; i < input_num; i++) {
auto input_tensor = engine_->GetITensor(id_names[i]); auto input_tensor = engine_->GetITensor(id_names[i]);
weight = GetWeight(emb_names[i], &emb_dims); weight = GetWeight(emb_names[i], &emb_dims);
...@@ -156,7 +87,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -156,7 +87,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_embs.push_back(weight.get()); input_embs.push_back(weight.get());
emb_sizes.push_back(weight.get().count); emb_sizes.push_back(weight.get().count);
} }
hidden = emb_dims[1];
} }
bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
...@@ -206,26 +136,29 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -206,26 +136,29 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_ptr->fields = fields.data(); plugin_ptr->fields = fields.data();
std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids; std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
plugin_inputs.emplace_back(engine_->GetITensor( plugin_inputs.emplace_back(
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 engine_->GetITensor("mask_id")); // input mask_id
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "1"); "ManyEmbLayerNormVarlenPluginDynamic", "1");
auto plugin_obj = auto plugin_obj = creator->createPlugin(
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); "ManyEmbLayerNormVarlenPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2( auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " + plugin_layer->setName(("ManyEmbLayerNormVarlenPluginDynamicV1(Output: " +
op_desc.Output("Out")[0] + ")") op_desc.Output("Out")[0] + ")")
.c_str()); .c_str());
free(plugin_ptr); free(plugin_ptr);
if (enable_int8) { if (enable_int8) {
float out_scale = float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_scale); engine_->SetTensorDynamicRange(plugin_layer->getOutput(0),
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale); out_scale); // output
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1),
out_scale); // mask
engine_->SetTensorDynamicRange(plugin_layer->getOutput(2),
out_scale); // max seqlen
} }
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and " VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and "
...@@ -249,54 +182,82 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -249,54 +182,82 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, RreplenishLayerAndOutput(layer,
"ManyEmbLayerNormPluginDynamic_V1", "ManyEmbLayerNormPluginDynamic_V1",
{output_name, std::string("qkv_plugin_mask")}, {output_name,
std::string("qkv_plugin_mask"),
std::string("max_seqlen_tensor")},
test_mode); test_mode);
} }
} else { } else {
for (int i = 0; i < input_num; i++) { for (int i = 0; i < input_num; i++) {
if (with_fp16) { auto input_tensor = engine_->GetITensor(id_names[i]);
weight = GetFp16Weight(emb_names[i], &emb_dims); weight = GetWeight(emb_names[i], &emb_dims);
} else { input_ids.push_back(input_tensor);
weight = GetFp32Weight(emb_names[i], &emb_dims);
}
input_ids.push_back(engine_->GetITensor(id_names[i]));
input_embs.push_back(weight.get()); input_embs.push_back(weight.get());
emb_sizes.push_back(weight.get().count); emb_sizes.push_back(weight.get().count);
hidden = emb_dims[1]; // hidden = emb_dims[1];
}
if (with_fp16) {
bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims);
scale_weight =
GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims);
} else {
bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims);
scale_weight =
GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims);
} }
bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
bias_size = phi::product(bias_dims); bias_size = phi::product(bias_dims);
scale_size = phi::product(scale_dims); scale_size = phi::product(scale_dims);
float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
plugin::DynamicPluginTensorRT* plugin = nullptr; int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
std::vector<void*> input_embs_data; if (enable_int8) {
for (size_t i = 0; i < input_embs.size(); ++i) { output_fp16 = 1;
input_embs_data.push_back(const_cast<void*>(
reinterpret_cast<const void*>(input_embs[i].values)));
} }
plugin = new plugin::EmbEltwiseLayernormPluginDynamic(
input_embs_data, std::vector<nvinfer1::PluginField> fields;
const_cast<void*>(static_cast<const void*>(bias_weight.get().values)), std::vector<std::string> temp_fields_keys;
const_cast<void*>( fields.emplace_back("bert_embeddings_layernorm_beta",
static_cast<const void*>(scale_weight.get().values)), bias_weight.get().values,
emb_sizes, GetPluginFieldType(bias_weight.get().type),
bias_size, static_cast<int32_t>(bias_size));
scale_size, fields.emplace_back("bert_embeddings_layernorm_gamma",
hidden, scale_weight.get().values,
eps, GetPluginFieldType(scale_weight.get().type),
with_fp16); static_cast<int32_t>(scale_size));
layer = engine_->AddDynamicPlugin(input_ids.data(), input_num, plugin); fields.emplace_back(
"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1);
for (int i = 0; i < input_num; ++i) {
temp_fields_keys.push_back("bert_embeddings_word_embeddings_" +
std::to_string(i));
fields.emplace_back(temp_fields_keys.rbegin()->c_str(),
input_embs[i].values,
GetPluginFieldType(input_embs[i].type),
static_cast<int32_t>(emb_sizes[i]));
}
nvinfer1::PluginFieldCollection* plugin_ptr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_ptr) +
fields.size() * sizeof(nvinfer1::PluginField)));
plugin_ptr->nbFields = static_cast<int>(fields.size());
plugin_ptr->fields = fields.data();
std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids;
auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "1");
auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("ManyEmbLayerNormPluginDynamicV1(Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
free(plugin_ptr);
if (enable_int8) {
float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(plugin_layer->getOutput(0),
out_scale); // output
}
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput( RreplenishLayerAndOutput(
layer, "emb_eltwise_layernorm", {output_name}, test_mode); layer, "ManyEmbLayerNormPluginDynamicV1", {output_name}, test_mode);
} }
} }
}; };
......
...@@ -194,10 +194,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -194,10 +194,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "2"); "ManyEmbLayerNormVarlenPluginDynamic", "2");
auto plugin_obj = auto plugin_obj = creator->createPlugin(
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); "ManyEmbLayerNormVarlenPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2( auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
......
...@@ -11,7 +11,6 @@ list( ...@@ -11,7 +11,6 @@ list(
group_norm_op_plugin.cu group_norm_op_plugin.cu
layer_norm_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu instance_norm_op_plugin.cu
emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu qkv_to_context_plugin.cu
skip_layernorm_op_plugin.cu skip_layernorm_op_plugin.cu
hard_swish_op_plugin.cu hard_swish_op_plugin.cu
...@@ -38,12 +37,14 @@ list( ...@@ -38,12 +37,14 @@ list(
merge_layernorm_op_plugin.cu merge_layernorm_op_plugin.cu
skip_merge_layernorm_op_plugin.cu skip_merge_layernorm_op_plugin.cu
generic_plugin.cu generic_plugin.cu
lookup_table.cu) lookup_table.cu
many_emb_layernorm_plugin.cu
many_emb_Layernorm_kernel.cu)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
many_emb_Layernorm_varseqlen_kernelMTron.cu many_emb_Layernorm_varseqlen_kernel_mtron.cu
many_emb_Layernorm_varseqlen_kernelHFace.cu) many_emb_Layernorm_varseqlen_kernel_hface.cu)
endif() endif()
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <type_traits>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic shape plugin requires TRT version 6.0 or later.
#if IS_TRT_VERSION_GE(6000)
// Copies the device pointers (embeddings, scale, bias) held by `anthor` into
// this instance and makes emb_ptr_tensor_ share storage with the source's
// tensor. in_ptr_tensor_ is only resized to one slot per embedding input.
//
// `anthor` must be an EmbEltwiseLayernormPluginDynamicImpl<T> with the same
// element type T; anything else is rejected instead of being dereferenced.
// If the source has not been initialized, nothing is copied.
template <typename T>
void EmbEltwiseLayernormPluginDynamicImpl<T>::shareGPUData(
    const EmbEltwiseLayernormPluginDynamicImplBase *anthor) {
  const auto *source =
      dynamic_cast<const EmbEltwiseLayernormPluginDynamicImpl<T> *>(anthor);
  // dynamic_cast yields nullptr for a null argument or for an implementation
  // instantiated with a different T; fail with a message rather than crash.
  PADDLE_ENFORCE_NOT_NULL(
      source,
      platform::errors::InvalidArgument(
          "The EmbEltwiseLayernorm plugin can only share GPU data with an "
          "implementation of the same data type, but got nullptr or a "
          "mismatched implementation."));
  if (!source->is_initialized_) {
    return;
  }

  embs_gpu_ = source->embs_gpu_;
  scale_gpu_ = source->scale_gpu_;
  bias_gpu_ = source->bias_gpu_;

  int input_num = embs_.size();
  in_ptr_tensor_.Resize({input_num});
  emb_ptr_tensor_.ShareDataWith(source->emb_ptr_tensor_);
}
// Uploads the host-side weights to the GPU the first time it is called.
//
// Every non-null embedding table, the bias and the scale are copied into
// buffers obtained from cudaMalloc. The embedding device pointers are then
// written into emb_ptr_tensor_ on the current device. Calls made while
// is_initialized_ is set do nothing. Always returns 0.
template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
  if (is_initialized_) {
    return 0;
  }

  int input_num = embs_.size();

  // One device buffer per embedding table; slots without a host table are
  // skipped.
  embs_gpu_.resize(embs_.size());
  for (int idx = 0; idx < input_num; ++idx) {
    T *host_emb = embs_[idx];
    if (!host_emb) {
      continue;
    }
    auto count = emb_sizes_[idx];
    cudaMalloc(&embs_gpu_[idx], sizeof(T) * count);
    cudaMemcpy(
        embs_gpu_[idx], host_emb, count * sizeof(T), cudaMemcpyHostToDevice);
  }

  if (bias_) {
    cudaMalloc(&bias_gpu_, sizeof(T) * bias_size_);
    cudaMemcpy(
        bias_gpu_, bias_, bias_size_ * sizeof(T), cudaMemcpyHostToDevice);
  }

  if (scale_) {
    cudaMalloc(&scale_gpu_, sizeof(T) * scale_size_);
    cudaMemcpy(
        scale_gpu_, scale_, scale_size_ * sizeof(T), cudaMemcpyHostToDevice);
  }

  // Size both pointer tensors to one slot per embedding input, then publish
  // the embedding device pointers into emb_ptr_tensor_.
  in_ptr_tensor_.Resize({input_num});
  emb_ptr_tensor_.Resize({input_num});
  cudaGetDevice(&device_id_);
  auto emb_table_d =
      emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
  cudaMemcpy(emb_table_d,
             embs_gpu_.data(),
             sizeof(uintptr_t) * input_num,
             cudaMemcpyHostToDevice);

  is_initialized_ = true;
  return 0;
}
// Frees every device buffer referenced by this instance and nulls the
// pointers so a repeated call is harmless.
//
// is_initialized_ is cleared as well: initialize() returns early while that
// flag is set, so leaving it set after the buffers are gone would make a later
// initialize() a no-op and leave the plugin with null device pointers.
template <typename T>
void EmbEltwiseLayernormPluginDynamicImpl<T>::terminate() {
  for (T *&emb_gpu : embs_gpu_) {
    if (emb_gpu) {
      cudaFree(emb_gpu);
      emb_gpu = nullptr;
    }
  }

  if (bias_gpu_) {
    cudaFree(bias_gpu_);
    bias_gpu_ = nullptr;
  }

  if (scale_gpu_) {
    cudaFree(scale_gpu_);
    scale_gpu_ = nullptr;
  }

  // Allow initialize() to upload the weights again after a teardown.
  is_initialized_ = false;
}
// Runs the fused embedding + eltwise + layernorm functor on `stream`.
//
// The batch size and sequence length are read from the first two dims of the
// first input. The `inputs` pointer array is copied asynchronously into
// in_ptr_tensor_, and the output descriptor type is checked against T before
// the functor is invoked. Returns non-zero when cudaGetLastError() reports an
// error, zero otherwise.
template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
    const void *const *inputs,
    void *const *outputs,
    void *workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  const auto &ids_shape = input_desc[0].dims;
  int batch = ids_shape.d[0];
  int seq_len = ids_shape.d[1];
  int input_num = embs_.size();

  cudaGetDevice(&device_id_);
  auto in_ptr_gpu_d =
      in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
  auto emb_ptr_gpu_d =
      emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));

  // Ship the array of input pointers to the device so the kernel can index it.
  const void *host_input_ptrs = reinterpret_cast<const void *>(inputs);
  cudaMemcpyAsync(in_ptr_gpu_d,
                  host_input_ptrs,
                  sizeof(uintptr_t) * input_num,
                  cudaMemcpyHostToDevice,
                  stream);

  // The output tensor type must match the element type this impl was built
  // for.
  auto out_type = output_desc[0].type;
  const bool impl_is_fp32 = std::is_same<T, float>::value;
  const bool impl_is_fp16 = std::is_same<T, half>::value;
  if (impl_is_fp32) {
    PADDLE_ENFORCE_EQ(
        out_type == nvinfer1::DataType::kFLOAT,
        true,
        platform::errors::InvalidArgument(
            "The EmbEltwiseLayernorm Plugin only support fp32 input."));
  } else if (impl_is_fp16) {
    PADDLE_ENFORCE_EQ(
        out_type == nvinfer1::DataType::kHALF,
        true,
        platform::errors::InvalidArgument(
            "The EmbEltwiseLayernorm Plugin only support fp16 input."));
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "Unsupport data type, the out type of EmbEltwiseLayernorm should be "
        "float or half."));
  }

  auto *output_d = reinterpret_cast<T *>(outputs[0]);
  operators::math::EmbEltwiseLayerNormFunctor<T> fused_functor;
  fused_functor(batch,
                seq_len,
                hidden_size_,
                in_ptr_gpu_d,
                scale_gpu_,
                bias_gpu_,
                emb_ptr_gpu_d,
                output_d,
                eps_,
                input_num,
                stream);
  return cudaGetLastError() != cudaSuccess;
}
// Explicit instantiations of the implementation template. The float variant is
// always built; the half variant is only built when TRT_PLUGIN_FP16_AVALIABLE
// is defined.
template class EmbEltwiseLayernormPluginDynamicImpl<float>;
#ifdef TRT_PLUGIN_FP16_AVALIABLE
template class EmbEltwiseLayernormPluginDynamicImpl<half>;
#endif
// Initializes the type-specific implementation and returns its status code
// (0 on success) instead of discarding it and always reporting success.
int EmbEltwiseLayernormPluginDynamic::initialize() TRT_NOEXCEPT {
  return impl_->initialize();
}
// Forwards plugin teardown to the type-specific implementation object.
void EmbEltwiseLayernormPluginDynamic::terminate() TRT_NOEXCEPT {
impl_->terminate();
}
// Describes the shape of the plugin's single output.
//
// The result is rank 3: the first two dims are taken from the first input and
// the last dim is the constant hidden_size_. Any output_index other than 0 is
// rejected.
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {  // NOLINT
  PADDLE_ENFORCE_EQ(output_index,
                    0,
                    platform::errors::InvalidArgument(
                        "There is only one output of the EmbEltwiseLayernorm, "
                        "so the index should be zero,"
                        "but it's (%d)",
                        output_index));

  const nvinfer1::DimsExprs &first_input = inputs[0];
  nvinfer1::DimsExprs out_dims;
  out_dims.nbDims = 3;
  out_dims.d[0] = first_input.d[0];
  out_dims.d[1] = first_input.d[1];
  out_dims.d[2] = expr_builder.constant(hidden_size_);
  return out_dims;
}
// Tells TensorRT whether the tensor at position `pos` of `in_out` (inputs
// first, then the single output) has an acceptable format and type.
//
// Accepted combination:
//   - every tensor uses the linear format;
//   - every input is INT32, and each input after the first has the same first
//     two dims as the tensor before it;
//   - the output is HALF when with_fp16_ is set, FLOAT otherwise.
//
// `in_out` must be non-null, `nb_outputs` must be 1 and `pos` must be less
// than nb_inputs + nb_outputs; violations raise InvalidArgument.
bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  // The message previously named the swish plugin (copy-paste slip).
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of EmbEltwiseLayernorm plugin should not be nullptr."));
  PADDLE_ENFORCE_EQ(nb_outputs,
                    1,
                    platform::errors::InvalidArgument(
                        "The EmbEltwiseLayerNorm's output should be one, "
                        "but it's (%d) outputs.",
                        nb_outputs));
  int all_nums = nb_inputs + nb_outputs;
  PADDLE_ENFORCE_LT(
      pos,
      all_nums,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        all_nums));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }

  // First input: only the type is constrained.
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kINT32;
  }

  // Remaining inputs: INT32 and the same first two dims as the previous one.
  if (pos < all_nums - 1) {
    const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
    return desc.type == nvinfer1::DataType::kINT32 &&
           desc.dims.d[0] == prev.dims.d[0] &&
           desc.dims.d[1] == prev.dims.d[1];
  }

  // Only the output position (all_nums - 1) is left once pos < all_nums holds.
  const nvinfer1::DataType expected_out_type =
      with_fp16_ ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
  return desc.type == expected_out_type;
}
// Reports the element type of the plugin's single output: HALF when
// with_fp16_ is set, FLOAT otherwise. Any index other than 0 is rejected.
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      index,
      0,
      platform::errors::InvalidArgument(
          "The EmbEltwiseLayernorm Plugin only has one output, so the "
          "index value should be 0, but get %d.",
          index));
  return with_fp16_ ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
}
// Executes the plugin by delegating to the type-specific implementation and
// returns its status (0 on success, non-zero on a CUDA error).
//
// The implementation already ends with a cudaGetLastError() query, and that
// call resets the recorded error to cudaSuccess. Querying it again here, as
// the code used to do while ignoring the implementation's result, would report
// success for any non-sticky failure, so the implementation's status is
// returned directly instead.
int EmbEltwiseLayernormPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
    const void *const *inputs,
    void *const *outputs,
    void *workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  return impl_->enqueue(
      input_desc, output_desc, inputs, outputs, workspace, stream);
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
// Type-erased interface of the EmbEltwiseLayernorm plugin backend. Concrete
// implementations are templated on the element type and derive from this
// class so they can be held through a single base pointer.
class EmbEltwiseLayernormPluginDynamicImplBase {
 public:
  EmbEltwiseLayernormPluginDynamicImplBase() = default;
  virtual ~EmbEltwiseLayernormPluginDynamicImplBase() = default;

  // Prepares the implementation for execution; returns an int status.
  virtual int initialize() = 0;

  // Releases whatever initialize() set up.
  virtual void terminate() = 0;

  // Runs the plugin on `stream` for the given tensor descriptors and buffers;
  // returns an int status.
  virtual int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                      const nvinfer1::PluginTensorDesc* outputDesc,
                      const void* const* inputs,
                      void* const* outputs,
                      void* workspace,
                      cudaStream_t stream) = 0;

  // Adopts the GPU-side data held by another implementation object.
  virtual void shareGPUData(
      const EmbEltwiseLayernormPluginDynamicImplBase* anthor) = 0;
};
// Backend of the EmbEltwiseLayernorm plugin for one element type T.
//
// The constructor only records the host-side weight pointers and sizes; the
// pointers are stored as given and are not copied or freed by this class.
// The destructor releases nothing.
//
// Every virtual function inherited from the base is marked `override` so a
// signature drift between base and implementation is a compile error instead
// of a silently unrelated function.
template <typename T>
class EmbEltwiseLayernormPluginDynamicImpl
    : public EmbEltwiseLayernormPluginDynamicImplBase {
 public:
  // input_embs / emb_sizes: host pointer and element count of each embedding
  //   table, index-aligned.
  // bias / scale, bias_size / scale_size: host pointers and element counts of
  //   the layernorm parameters.
  // hidden_size, eps: stored as hidden_size_ and eps_.
  explicit EmbEltwiseLayernormPluginDynamicImpl(std::vector<T*> input_embs,
                                                T* bias,
                                                T* scale,
                                                std::vector<int> emb_sizes,
                                                int bias_size,
                                                int scale_size,
                                                int hidden_size,
                                                float eps)
      : embs_(input_embs),
        bias_(bias),
        scale_(scale),
        emb_sizes_(emb_sizes),
        bias_size_(bias_size),
        scale_size_(scale_size),
        hidden_size_(hidden_size),
        eps_(eps) {}

  ~EmbEltwiseLayernormPluginDynamicImpl() override {}

  int initialize() override;
  void terminate() override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  void shareGPUData(
      const EmbEltwiseLayernormPluginDynamicImplBase* anthor) override;

 private:
  // Host-side weights as passed to the constructor.
  std::vector<T*> embs_;
  T* bias_{nullptr};
  T* scale_{nullptr};

  // data on devices
  T* bias_gpu_{nullptr};
  T* scale_gpu_{nullptr};
  std::vector<T*> embs_gpu_;

  // Element counts matching embs_, bias_ and scale_.
  std::vector<int> emb_sizes_;
  int bias_size_;
  int scale_size_;
  int hidden_size_;
  float eps_;

  // Tensors holding the input pointers and the embedding pointers.
  phi::DenseTensor in_ptr_tensor_, emb_ptr_tensor_;
  int device_id_{0};
  bool is_initialized_{false};
};
class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
public:
explicit EmbEltwiseLayernormPluginDynamic(std::vector<void*> input_embs,
void* bias,
void* scale,
std::vector<int> emb_sizes,
int bias_size,
int scale_size,
int hidden_size,
float eps,
bool with_fp16)
: embs_(input_embs),
bias_(bias),
scale_(scale),
emb_sizes_(emb_sizes),
bias_size_(bias_size),
scale_size_(scale_size),
hidden_size_(hidden_size),
eps_(eps),
own_host_buff_(false) {
with_fp16_ = with_fp16;
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp16";
instantiateImpl<half>();
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.EnableTensorRtEngine(1 << 30, 1, 5, "
"AnalysisConfig::Precision::kFloat32, false, false) "));
#endif
} else {
VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp32";
instantiateImpl<float>();
}
}
EmbEltwiseLayernormPluginDynamic(void const* serial_data,
size_t serial_length)
: own_host_buff_(true) {
// the first var is with_fp16, we will use it.
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
DeserializeValue(&serial_data, &serial_length, &emb_sizes_);
DeserializeValue(&serial_data, &serial_length, &bias_size_);
DeserializeValue(&serial_data, &serial_length, &scale_size_);
embs_.resize(emb_sizes_.size());
if (with_fp16_) {
for (size_t i = 0; i < emb_sizes_.size(); i++) {
auto size = emb_sizes_[i];
auto ptr = new half[size];
memcpy(ptr, serial_data, sizeof(half) * size);
embs_[i] = ptr;
reinterpret_cast<char const*&>(serial_data) += size * sizeof(half);
serial_length -= size * sizeof(half);
}
if (bias_size_) {
bias_ = new half[bias_size_];
memcpy(bias_, serial_data, sizeof(half) * bias_size_);
}
reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(half);
serial_length -= bias_size_ * sizeof(half);
if (scale_size_) {
scale_ = new half[scale_size_];
memcpy(scale_, serial_data, sizeof(half) * scale_size_);
}
reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(half);
serial_length -= scale_size_ * sizeof(half);
} else {
for (size_t i = 0; i < emb_sizes_.size(); i++) {
auto size = emb_sizes_[i];
auto ptr = new float[size];
memcpy(ptr, serial_data, sizeof(float) * size);
embs_[i] = ptr;
reinterpret_cast<char const*&>(serial_data) += size * sizeof(float);
serial_length -= size * sizeof(float);
}
if (bias_size_) {
bias_ = new float[bias_size_];
memcpy(bias_, serial_data, sizeof(float) * bias_size_);
}
reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
serial_length -= bias_size_ * sizeof(float);
if (scale_size_) {
scale_ = new float[scale_size_];
memcpy(scale_, serial_data, sizeof(float) * scale_size_);
}
reinterpret_cast<char const*&>(serial_data) +=
scale_size_ * sizeof(float);
serial_length -= scale_size_ * sizeof(float);
}
DeserializeValue(&serial_data, &serial_length, &hidden_size_);
DeserializeValue(&serial_data, &serial_length, &eps_);
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
instantiateImpl<half>();
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.EnableTensorRtEngine(1 << 30, 1, 5, "
"AnalysisConfig::Precision::kFloat32, false, false) "));
#endif
} else {
instantiateImpl<float>();
}
}
// Creates a copy of this plugin. The copy is built through the
// pointer-taking constructor (which sets own_host_buff_ to false), so it
// borrows this object's weight buffers instead of duplicating them, and it
// then shares the already-prepared GPU data of this instance's impl.
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
  auto* replica = new EmbEltwiseLayernormPluginDynamic(embs_,
                                                       bias_,
                                                       scale_,
                                                       emb_sizes_,
                                                       bias_size_,
                                                       scale_size_,
                                                       hidden_size_,
                                                       eps_,
                                                       with_fp16_);
  replica->shareGPUData(this);
  return replica;
}
// Plugin type string; it is the same literal that
// EmbEltwiseLayernormPluginDynamicCreator::getPluginName() returns.
const char* getPluginType() const TRT_NOEXCEPT override {
return "fused_embedding_eltwise_layernorm_plugin";
}
// The plugin produces exactly one output tensor.
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
// initialize()/terminate() are only declared here; their definitions live
// outside this header section.
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
// Returns the number of bytes serialize() writes: the serialized header
// fields (with_fp16_, emb_sizes_, bias_size_, scale_size_), the raw
// embedding/bias/scale payload, and the trailing hidden_size_ and eps_.
//
// The total is accumulated in size_t. The previous int accumulator would
// overflow once the combined weight payload exceeded INT_MAX bytes, even
// though the function's return type is size_t.
size_t getSerializationSize() const TRT_NOEXCEPT override {
  // Payload elements are stored as half when with_fp16_ is set, else float.
  const size_t elem_bytes = with_fp16_ ? sizeof(half) : sizeof(float);

  size_t total = 0;
  total += SerializedSize(with_fp16_);
  total += SerializedSize(emb_sizes_);
  for (int count : emb_sizes_) {
    total += static_cast<size_t>(count) * elem_bytes;
  }
  total += (static_cast<size_t>(bias_size_) +
            static_cast<size_t>(scale_size_)) *
           elem_bytes;
  total += SerializedSize(bias_size_);
  total += SerializedSize(scale_size_);
  total += SerializedSize(hidden_size_);
  total += SerializedSize(eps_);
  return total;
}
// Writes the plugin state into `buffer` in the exact order the deserializing
// constructor reads it back: with_fp16_, emb_sizes_, bias_size_,
// scale_size_, every embedding table, bias, scale, hidden_size_, eps_.
// Weight elements are emitted one by one as half or float depending on
// with_fp16_.
void serialize(void* buffer) const TRT_NOEXCEPT override {
  // with_fp16_ goes first so the reader knows how to interpret the payload.
  SerializeValue(&buffer, with_fp16_);
  SerializeValue(&buffer, emb_sizes_);
  SerializeValue(&buffer, bias_size_);
  SerializeValue(&buffer, scale_size_);

  const size_t table_count = emb_sizes_.size();
  if (with_fp16_) {
    for (size_t t = 0; t < table_count; ++t) {
      half* table = reinterpret_cast<half*>(embs_[t]);
      auto count = emb_sizes_[t];
      for (int k = 0; k < count; ++k) {
        SerializeValue(&buffer, table[k]);
      }
    }
    half* bias = reinterpret_cast<half*>(bias_);
    for (int k = 0; k < bias_size_; ++k) {
      SerializeValue(&buffer, bias[k]);
    }
    half* scale = reinterpret_cast<half*>(scale_);
    for (int k = 0; k < scale_size_; ++k) {
      SerializeValue(&buffer, scale[k]);
    }
  } else {
    for (size_t t = 0; t < table_count; ++t) {
      float* table = reinterpret_cast<float*>(embs_[t]);
      auto count = emb_sizes_[t];
      for (int k = 0; k < count; ++k) {
        SerializeValue(&buffer, table[k]);
      }
    }
    float* bias = reinterpret_cast<float*>(bias_);
    for (int k = 0; k < bias_size_; ++k) {
      SerializeValue(&buffer, bias[k]);
    }
    float* scale = reinterpret_cast<float*>(scale_);
    for (int k = 0; k < scale_size_; ++k) {
      SerializeValue(&buffer, scale[k]);
    }
  }

  SerializeValue(&buffer, hidden_size_);
  SerializeValue(&buffer, eps_);
}
// Shape inference for the output tensor; defined outside this header section.
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
// Reports whether the type/format at position `pos` of `in_out` is accepted;
// defined outside this header section.
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT override;
// Intentionally empty: this plugin does nothing at configuration time.
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT override {}
// No scratch workspace is requested.
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const TRT_NOEXCEPT override {
return 0;
}
// Executes the plugin on `stream`; defined outside this header section.
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
// Data type of output `index`; defined outside this header section.
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const
TRT_NOEXCEPT override;
// Releases the plugin. Weight buffers are freed only when this instance
// allocated them itself (own_host_buff_, set by the deserializing
// constructor); they are deleted through the element type they were
// allocated with. The impl object and the plugin itself are always deleted.
void destroy() TRT_NOEXCEPT override {
  if (own_host_buff_ && with_fp16_) {
    for (void* table : embs_) {
      delete[] reinterpret_cast<half*>(table);
    }
    delete[] reinterpret_cast<half*>(bias_);
    delete[] reinterpret_cast<half*>(scale_);
  } else if (own_host_buff_) {
    for (void* table : embs_) {
      delete[] reinterpret_cast<float*>(table);
    }
    delete[] reinterpret_cast<float*>(bias_);
    delete[] reinterpret_cast<float*>(scale_);
  }
  delete impl_;
  delete this;
}
private:
// One type-erased pointer per embedding table; the elements are half when
// with_fp16_ is set and float otherwise (see the casts in instantiateImpl).
std::vector<void*> embs_;
// Type-erased bias / scale arrays; left as nullptr by the deserializing
// constructor when the corresponding size is zero.
void* bias_{nullptr};
void* scale_{nullptr};
// Element count of each table in embs_ (same index).
std::vector<int> emb_sizes_;
// Element counts of bias_ and scale_.
int bias_size_;
int scale_size_;
int hidden_size_;
float eps_;
// True when this instance allocated embs_/bias_/scale_ itself (deserializing
// constructor); destroy() frees the buffers only in that case.
bool own_host_buff_{false};
// Precision-specific implementation created by instantiateImpl<U>().
EmbEltwiseLayernormPluginDynamicImplBase* impl_{nullptr};
// Forwards to the impl so this plugin's impl shares the GPU data held by the
// impl of `anthor` (used by clone()). Both impl_ pointers are dereferenced
// without a null check.
void shareGPUData(const EmbEltwiseLayernormPluginDynamic* anthor) {
impl_->shareGPUData(anthor->impl_);
}
template <typename U>
void instantiateImpl() {
std::vector<U*> embs;
embs.resize(embs_.size());
for (size_t i = 0; i < embs_.size(); ++i) {
embs[i] = reinterpret_cast<U*>(embs_[i]);
}
impl_ = new EmbEltwiseLayernormPluginDynamicImpl<U>(
embs,
reinterpret_cast<U*>(bias_),
reinterpret_cast<U*>(scale_),
emb_sizes_,
bias_size_,
scale_size_,
hidden_size_,
eps_);
}
};
// Plugin creator registered with TensorRT for
// "fused_embedding_eltwise_layernorm_plugin" (version "1"). Only
// deserialization is supported: createPlugin() always returns nullptr and
// deserializePlugin() rebuilds the plugin from a serialized byte stream.
class EmbEltwiseLayernormPluginDynamicCreator
    : public nvinfer1::IPluginCreator {
 public:
  EmbEltwiseLayernormPluginDynamicCreator() {}

  const char* getPluginName() const TRT_NOEXCEPT override {
    return "fused_embedding_eltwise_layernorm_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  // Returns the (empty) field collection; no creation-time fields are
  // declared by this creator.
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  // Creation from a field collection is not supported.
  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override {
    return nullptr;
  }

  // Rebuilds a plugin instance from serialized data; `name` is unused.
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new EmbEltwiseLayernormPluginDynamic(serial_data, serial_length);
  }

  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  // Value-initialized so getFieldNames() never exposes indeterminate
  // nbFields/fields members; the constructor does not populate it.
  nvinfer1::PluginFieldCollection field_collection_{};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(EmbEltwiseLayernormPluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cassert>
#include <cstring>
#include <iostream>
#include <vector>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Fused lookup + layer norm for two embedding tables.
//
// Each thread block handles one position,
// seqPos = blockIdx.y * gridDim.x + blockIdx.x: it looks up one row of width
// `ld` in each table, sums the rows into `output`, and then calls layerNorm()
// on that row with `beta`/`gamma`.
//
// The launch must provide dynamic shared memory for at least two int32_t
// values (word_id[0..1]).
//
// Invalid IDs (negative or >= the table's IdsSize) are reported once via
// printf and the WHOLE block exits without touching `output` for that
// position. Previously only thread 0 returned: the remaining threads read
// uninitialized shared memory as table indices and entered the block-wide
// reduction without thread 0.
template <typename T, unsigned TPB>
__global__ void embLayerNormKernel_2(int32_t ld,
                                     int32_t const* inputIds0,
                                     int32_t const* inputIds1,
                                     float const* beta,
                                     float const* gamma,
                                     T const* mIdsEmbDev0,
                                     T const* mIdsEmbDev1,
                                     int32_t IdsSize0,
                                     int32_t IdsSize1,
                                     T* output) {
  T const rld = T(1.f) / T(ld);
  cub::Sum pairSum;
  int32_t const seqPos = blockIdx.y * gridDim.x + blockIdx.x;

  extern __shared__ int32_t word_id[];
  if (threadIdx.x == 0) {
    int32_t const id0 = inputIds0[seqPos];
    int32_t const id1 = inputIds1[seqPos];
    bool const valid =
        id0 >= 0 && id0 < IdsSize0 && id1 >= 0 && id1 < IdsSize1;
    if (!valid) {
      printf(
          "Error!!!!!!(embLayerNormVarPlugin): ID cannot be lookup "
          "table: ID < 0 or ID > max ");
    }
    // Valid IDs are never negative, so -1 in slot 0 marks the position as
    // invalid for every thread of the block.
    word_id[0] = valid ? id0 : -1;
    word_id[1] = id1;
  }
  __syncthreads();
  if (word_id[0] < 0) {
    return;  // uniform exit: no thread reaches the reduction below
  }

  // offset into embeddings is given by wordId * hidden_size
  int32_t const offset0 = word_id[0] * ld;
  int32_t const offset1 = word_id[1] * ld;
  // the output offset is given by b * (S*hidden_size) + s * hidden_size
  int32_t const outOffset = seqPos * ld;

  kvp<T> threadData(0, 0);
  for (int32_t it = threadIdx.x; it < ld; it += TPB) {
    T val = mIdsEmbDev0[offset0 + it];
    val += mIdsEmbDev1[offset1 + it];
    output[outOffset + it] = val;
    // Accumulate val/ld and val*val/ld for the normalization step.
    T const rldval = rld * val;
    threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
  }

  // layer norm on the sum
  layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
// Fused lookup + layer norm for three embedding tables.
//
// Each thread block handles one position,
// seqPos = blockIdx.y * gridDim.x + blockIdx.x: it looks up one row of width
// `ld` in each table, sums the rows into `output`, and then calls layerNorm()
// on that row with `beta`/`gamma`.
//
// The launch must provide dynamic shared memory for at least three int32_t
// values (word_id[0..2]).
//
// Invalid IDs (negative or >= the table's IdsSize) are reported once via
// printf and the WHOLE block exits without touching `output` for that
// position. Previously only thread 0 returned: the remaining threads read
// uninitialized shared memory as table indices and entered the block-wide
// reduction without thread 0.
template <typename T, unsigned TPB>
__global__ void embLayerNormKernel_3(int32_t ld,
                                     int32_t const* inputIds0,
                                     int32_t const* inputIds1,
                                     int32_t const* inputIds2,
                                     float const* beta,
                                     float const* gamma,
                                     T const* mIdsEmbDev0,
                                     T const* mIdsEmbDev1,
                                     T const* mIdsEmbDev2,
                                     int32_t IdsSize0,
                                     int32_t IdsSize1,
                                     int32_t IdsSize2,
                                     T* output) {
  T const rld = T(1.f) / T(ld);
  cub::Sum pairSum;
  int32_t const seqPos = blockIdx.y * gridDim.x + blockIdx.x;

  extern __shared__ int32_t word_id[];
  if (threadIdx.x == 0) {
    int32_t const id0 = inputIds0[seqPos];
    int32_t const id1 = inputIds1[seqPos];
    int32_t const id2 = inputIds2[seqPos];
    bool const valid = id0 >= 0 && id0 < IdsSize0 &&  //
                       id1 >= 0 && id1 < IdsSize1 &&  //
                       id2 >= 0 && id2 < IdsSize2;
    if (!valid) {
      printf(
          "Error!!!!!!(embLayerNormVarPlugin): ID cannot be lookup "
          "table: ID < 0 or ID > max ");
    }
    // Valid IDs are never negative, so -1 in slot 0 marks the position as
    // invalid for every thread of the block.
    word_id[0] = valid ? id0 : -1;
    word_id[1] = id1;
    word_id[2] = id2;
  }
  __syncthreads();
  if (word_id[0] < 0) {
    return;  // uniform exit: no thread reaches the reduction below
  }

  // offset into embeddings is given by wordId * hidden_size
  int32_t const offset0 = word_id[0] * ld;
  int32_t const offset1 = word_id[1] * ld;
  int32_t const offset2 = word_id[2] * ld;
  // the output offset is given by b * (S*hidden_size) + s * hidden_size
  int32_t const outOffset = seqPos * ld;

  kvp<T> threadData(0, 0);
  for (int32_t it = threadIdx.x; it < ld; it += TPB) {
    T val = mIdsEmbDev0[offset0 + it];
    val += mIdsEmbDev1[offset1 + it];
    val += mIdsEmbDev2[offset2 + it];
    output[outOffset + it] = val;
    // Accumulate val/ld and val*val/ld for the normalization step.
    T const rldval = rld * val;
    threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
  }

  // layer norm on the sum
  layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
// Fused lookup + layer norm for four embedding tables.
//
// Each thread block handles one position,
// seqPos = blockIdx.y * gridDim.x + blockIdx.x: it looks up one row of width
// `ld` in each table, sums the rows into `output`, and then calls layerNorm()
// on that row with `beta`/`gamma`.
//
// The launch must provide dynamic shared memory for at least four int32_t
// values (word_id[0..3]).
//
// Invalid IDs (negative or >= the table's IdsSize) are reported once via
// printf and the WHOLE block exits without touching `output` for that
// position. Previously only thread 0 returned: the remaining threads read
// uninitialized shared memory as table indices and entered the block-wide
// reduction without thread 0.
template <typename T, unsigned TPB>
__global__ void embLayerNormKernel_4(int32_t ld,
                                     int32_t const* inputIds0,
                                     int32_t const* inputIds1,
                                     int32_t const* inputIds2,
                                     int32_t const* inputIds3,
                                     float const* beta,
                                     float const* gamma,
                                     T const* mIdsEmbDev0,
                                     T const* mIdsEmbDev1,
                                     T const* mIdsEmbDev2,
                                     T const* mIdsEmbDev3,
                                     int32_t IdsSize0,
                                     int32_t IdsSize1,
                                     int32_t IdsSize2,
                                     int32_t IdsSize3,
                                     T* output) {
  T const rld = T(1.f) / T(ld);
  cub::Sum pairSum;
  int32_t const seqPos = blockIdx.y * gridDim.x + blockIdx.x;

  extern __shared__ int32_t word_id[];
  if (threadIdx.x == 0) {
    int32_t const id0 = inputIds0[seqPos];
    int32_t const id1 = inputIds1[seqPos];
    int32_t const id2 = inputIds2[seqPos];
    int32_t const id3 = inputIds3[seqPos];
    bool const valid = id0 >= 0 && id0 < IdsSize0 &&  //
                       id1 >= 0 && id1 < IdsSize1 &&  //
                       id2 >= 0 && id2 < IdsSize2 &&  //
                       id3 >= 0 && id3 < IdsSize3;
    if (!valid) {
      printf(
          "Error!!!!!!(embLayerNormVarPlugin): ID cannot be lookup "
          "table: ID < 0 or ID > max ");
    }
    // Valid IDs are never negative, so -1 in slot 0 marks the position as
    // invalid for every thread of the block.
    word_id[0] = valid ? id0 : -1;
    word_id[1] = id1;
    word_id[2] = id2;
    word_id[3] = id3;
  }
  __syncthreads();
  if (word_id[0] < 0) {
    return;  // uniform exit: no thread reaches the reduction below
  }

  // offset into embeddings is given by wordId * hidden_size
  int32_t const offset0 = word_id[0] * ld;
  int32_t const offset1 = word_id[1] * ld;
  int32_t const offset2 = word_id[2] * ld;
  int32_t const offset3 = word_id[3] * ld;
  // the output offset is given by b * (S*hidden_size) + s * hidden_size
  int32_t const outOffset = seqPos * ld;

  kvp<T> threadData(0, 0);
  for (int32_t it = threadIdx.x; it < ld; it += TPB) {
    T val = mIdsEmbDev0[offset0 + it];
    val += mIdsEmbDev1[offset1 + it];
    val += mIdsEmbDev2[offset2 + it];
    val += mIdsEmbDev3[offset3 + it];
    output[outOffset + it] = val;
    // Accumulate val/ld and val*val/ld for the normalization step.
    T const rldval = rld * val;
    threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
  }

  // layer norm on the sum
  layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
// Host-side launcher for embLayerNormKernel_2.
//
// Launches an (S, B) grid of 256-thread blocks on `stream`, with
// nbLookupTables * sizeof(int32_t) bytes of dynamic shared memory, and
// returns the value of cudaPeekAtLastError() converted to int32_t.
template <typename T>
int32_t embSkipLayerNorm_2(cudaStream_t stream,
                           int32_t ld,
                           int32_t B,
                           int32_t S,
                           int const* inputIds0,
                           int const* inputIds1,
                           int32_t nbLookupTables,
                           float const* beta,
                           float const* gamma,
                           T const* mIdsEmbDev0,
                           T const* mIdsEmbDev1,
                           int32_t IdsSize0,
                           int32_t IdsSize1,
                           T* output) {
  constexpr int32_t kThreadsPerBlock = 256;
  // One int32_t slot of shared memory per lookup table.
  size_t const sharedBytes = nbLookupTables * sizeof(int32_t);
  dim3 const threads(kThreadsPerBlock, 1, 1);
  dim3 const blocks(S, B, 1);

  embLayerNormKernel_2<T, kThreadsPerBlock>
      <<<blocks, threads, sharedBytes, stream>>>(ld,
                                                 inputIds0,
                                                 inputIds1,
                                                 beta,
                                                 gamma,
                                                 mIdsEmbDev0,
                                                 mIdsEmbDev1,
                                                 IdsSize0,
                                                 IdsSize1,
                                                 output);
  return cudaPeekAtLastError();
}
// Host-side launcher for embLayerNormKernel_3.
//
// Launches an (S, B) grid of 256-thread blocks on `stream`, with
// nbLookupTables * sizeof(int32_t) bytes of dynamic shared memory, and
// returns the value of cudaPeekAtLastError() converted to int32_t.
template <typename T>
int32_t embSkipLayerNorm_3(cudaStream_t stream,
                           int32_t ld,
                           int32_t B,
                           int32_t S,
                           int const* inputIds0,
                           int const* inputIds1,
                           int const* inputIds2,
                           int32_t nbLookupTables,
                           float const* beta,
                           float const* gamma,
                           T const* mIdsEmbDev0,
                           T const* mIdsEmbDev1,
                           T const* mIdsEmbDev2,
                           int32_t IdsSize0,
                           int32_t IdsSize1,
                           int32_t IdsSize2,
                           T* output) {
  constexpr int32_t kThreadsPerBlock = 256;
  // One int32_t slot of shared memory per lookup table.
  size_t const sharedBytes = nbLookupTables * sizeof(int32_t);
  dim3 const threads(kThreadsPerBlock, 1, 1);
  dim3 const blocks(S, B, 1);

  embLayerNormKernel_3<T, kThreadsPerBlock>
      <<<blocks, threads, sharedBytes, stream>>>(ld,
                                                 inputIds0,
                                                 inputIds1,
                                                 inputIds2,
                                                 beta,
                                                 gamma,
                                                 mIdsEmbDev0,
                                                 mIdsEmbDev1,
                                                 mIdsEmbDev2,
                                                 IdsSize0,
                                                 IdsSize1,
                                                 IdsSize2,
                                                 output);
  return cudaPeekAtLastError();
}
// Host-side launcher for embLayerNormKernel_4.
//
// Launches an (S, B) grid of 256-thread blocks on `stream`, with
// nbLookupTables * sizeof(int32_t) bytes of dynamic shared memory, and
// returns the value of cudaPeekAtLastError() converted to int32_t.
template <typename T>
int32_t embSkipLayerNorm_4(cudaStream_t stream,
                           int32_t ld,
                           int32_t B,
                           int32_t S,
                           int const* inputIds0,
                           int const* inputIds1,
                           int const* inputIds2,
                           int const* inputIds3,
                           int32_t nbLookupTables,
                           float const* beta,
                           float const* gamma,
                           T const* mIdsEmbDev0,
                           T const* mIdsEmbDev1,
                           T const* mIdsEmbDev2,
                           T const* mIdsEmbDev3,
                           int32_t IdsSize0,
                           int32_t IdsSize1,
                           int32_t IdsSize2,
                           int32_t IdsSize3,
                           T* output) {
  constexpr int32_t kThreadsPerBlock = 256;
  // One int32_t slot of shared memory per lookup table.
  size_t const sharedBytes = nbLookupTables * sizeof(int32_t);
  dim3 const threads(kThreadsPerBlock, 1, 1);
  dim3 const blocks(S, B, 1);

  embLayerNormKernel_4<T, kThreadsPerBlock>
      <<<blocks, threads, sharedBytes, stream>>>(ld,
                                                 inputIds0,
                                                 inputIds1,
                                                 inputIds2,
                                                 inputIds3,
                                                 beta,
                                                 gamma,
                                                 mIdsEmbDev0,
                                                 mIdsEmbDev1,
                                                 mIdsEmbDev2,
                                                 mIdsEmbDev3,
                                                 IdsSize0,
                                                 IdsSize1,
                                                 IdsSize2,
                                                 IdsSize3,
                                                 output);
  return cudaPeekAtLastError();
}
// Explicit instantiations of the launchers for float and half, for 2, 3 and
// 4 lookup tables. The launcher templates are defined in this file, so these
// instantiations are what make them available to other translation units.

// ---- float ----
template int32_t embSkipLayerNorm_2<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
float*);
template int32_t embSkipLayerNorm_3<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
int32_t,
float*);
template int32_t embSkipLayerNorm_4<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
int32_t,
int32_t,
float*);
// ---- half (beta/gamma remain float) ----
template int32_t embSkipLayerNorm_2<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
int32_t,
int32_t,
half*);
template int32_t embSkipLayerNorm_3<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
half const*,
int32_t,
int32_t,
int32_t,
half*);
template int32_t embSkipLayerNorm_4<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
half const*,
half const*,
int32_t,
int32_t,
int32_t,
int32_t,
half*);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -33,7 +33,6 @@ template <typename T, unsigned TPB> ...@@ -33,7 +33,6 @@ template <typename T, unsigned TPB>
__global__ void embLayerNormKernelHFace_2(int32_t ld, __global__ void embLayerNormKernelHFace_2(int32_t ld,
int32_t const* inputIds0, int32_t const* inputIds0,
int32_t const* inputIds1, int32_t const* inputIds1,
int32_t nbLookupTables,
float const* beta, float const* beta,
float const* gamma, float const* gamma,
T const* mIdsEmbDev0, T const* mIdsEmbDev0,
...@@ -93,7 +92,6 @@ __global__ void embLayerNormKernelHFace_3(int32_t ld, ...@@ -93,7 +92,6 @@ __global__ void embLayerNormKernelHFace_3(int32_t ld,
int32_t const* inputIds0, int32_t const* inputIds0,
int32_t const* inputIds1, int32_t const* inputIds1,
int32_t const* inputIds2, int32_t const* inputIds2,
int32_t nbLookupTables,
float const* beta, float const* beta,
float const* gamma, float const* gamma,
T const* mIdsEmbDev0, T const* mIdsEmbDev0,
...@@ -168,7 +166,6 @@ __global__ void embLayerNormKernelHFace_4(int32_t ld, ...@@ -168,7 +166,6 @@ __global__ void embLayerNormKernelHFace_4(int32_t ld,
int32_t const* inputIds1, int32_t const* inputIds1,
int32_t const* inputIds2, int32_t const* inputIds2,
int32_t const* inputIds3, int32_t const* inputIds3,
int32_t nbLookupTables,
float const* beta, float const* beta,
float const* gamma, float const* gamma,
T const* mIdsEmbDev0, T const* mIdsEmbDev0,
...@@ -273,7 +270,6 @@ int32_t embSkipLayerNormHFace_2(cudaStream_t stream, ...@@ -273,7 +270,6 @@ int32_t embSkipLayerNormHFace_2(cudaStream_t stream,
<<<grid, block, cache_size, stream>>>(ld, <<<grid, block, cache_size, stream>>>(ld,
inputIds0, inputIds0,
inputIds1, inputIds1,
nbLookupTables,
beta, beta,
gamma, gamma,
mIdsEmbDev0, mIdsEmbDev0,
...@@ -311,7 +307,6 @@ int32_t embSkipLayerNormHFace_3(cudaStream_t stream, ...@@ -311,7 +307,6 @@ int32_t embSkipLayerNormHFace_3(cudaStream_t stream,
inputIds0, inputIds0,
inputIds1, inputIds1,
inputIds2, inputIds2,
nbLookupTables,
beta, beta,
gamma, gamma,
mIdsEmbDev0, mIdsEmbDev0,
...@@ -355,7 +350,6 @@ int32_t embSkipLayerNormHFace_4(cudaStream_t stream, ...@@ -355,7 +350,6 @@ int32_t embSkipLayerNormHFace_4(cudaStream_t stream,
inputIds1, inputIds1,
inputIds2, inputIds2,
inputIds3, inputIds3,
nbLookupTables,
beta, beta,
gamma, gamma,
mIdsEmbDev0, mIdsEmbDev0,
......
...@@ -33,7 +33,6 @@ template <typename T, unsigned TPB> ...@@ -33,7 +33,6 @@ template <typename T, unsigned TPB>
__global__ void embLayerNormKernelMTron_2(int32_t ld, __global__ void embLayerNormKernelMTron_2(int32_t ld,
int32_t const* inputIds0, int32_t const* inputIds0,
int32_t const* inputIds1, int32_t const* inputIds1,
int32_t nbLookupTables,
float const* beta, float const* beta,
float const* gamma, float const* gamma,
T const* mIdsEmbDev0, T const* mIdsEmbDev0,
...@@ -95,7 +94,6 @@ __global__ void embLayerNormKernelMTron_3(int32_t ld, ...@@ -95,7 +94,6 @@ __global__ void embLayerNormKernelMTron_3(int32_t ld,
int32_t const* inputIds0, int32_t const* inputIds0,
int32_t const* inputIds1, int32_t const* inputIds1,
int32_t const* inputIds2, int32_t const* inputIds2,
int32_t nbLookupTables,
float const* beta, float const* beta,
float const* gamma, float const* gamma,
T const* mIdsEmbDev0, T const* mIdsEmbDev0,
...@@ -172,7 +170,6 @@ __global__ void embLayerNormKernelMTron_4(int32_t ld, ...@@ -172,7 +170,6 @@ __global__ void embLayerNormKernelMTron_4(int32_t ld,
int32_t const* inputIds1, int32_t const* inputIds1,
int32_t const* inputIds2, int32_t const* inputIds2,
int32_t const* inputIds3, int32_t const* inputIds3,
int32_t nbLookupTables,
float const* beta, float const* beta,
float const* gamma, float const* gamma,
T const* mIdsEmbDev0, T const* mIdsEmbDev0,
...@@ -280,7 +277,6 @@ int32_t embSkipLayerNormMTron_2(cudaStream_t stream, ...@@ -280,7 +277,6 @@ int32_t embSkipLayerNormMTron_2(cudaStream_t stream,
<<<grid, block, cache_size, stream>>>(ld, <<<grid, block, cache_size, stream>>>(ld,
inputIds0, inputIds0,
inputIds1, inputIds1,
nbLookupTables,
beta, beta,
gamma, gamma,
mIdsEmbDev0, mIdsEmbDev0,
...@@ -320,7 +316,6 @@ int32_t embSkipLayerNormMTron_3(cudaStream_t stream, ...@@ -320,7 +316,6 @@ int32_t embSkipLayerNormMTron_3(cudaStream_t stream,
inputIds0, inputIds0,
inputIds1, inputIds1,
inputIds2, inputIds2,
nbLookupTables,
beta, beta,
gamma, gamma,
mIdsEmbDev0, mIdsEmbDev0,
...@@ -366,7 +361,6 @@ int32_t embSkipLayerNormMTron_4(cudaStream_t stream, ...@@ -366,7 +361,6 @@ int32_t embSkipLayerNormMTron_4(cudaStream_t stream,
inputIds1, inputIds1,
inputIds2, inputIds2,
inputIds3, inputIds3,
nbLookupTables,
beta, beta,
gamma, gamma,
mIdsEmbDev0, mIdsEmbDev0,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h"
#include <cuda.h>
#include <cstring>
#include <vector>
#include "NvInfer.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Threads per CTA for the 128/256/384 sequence-length variants.
// NOTE(review): none of the constants below (threadsPerCta*, xmmasM*,
// packedMaskSize*) is referenced in the visible part of this file; confirm
// whether they are still needed.
constexpr size_t threadsPerCta128 = 2 * 2 * 32;
constexpr size_t threadsPerCta256 = 1 * 4 * 32;
constexpr size_t threadsPerCta384 = 1 * 8 * 32;
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M
// dimension: (s + 16*warps_m - 1) / (16*warps_m);
constexpr size_t xmmasM128 = 4;
constexpr size_t xmmasM256 = 16;
constexpr size_t xmmasM384 = 24;
// Packed mask size per batch. Layout is XMMAS_M * THREADS_PER_CTA.
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
// Plugin version and name strings.
// NOTE(review): presumably returned by the plugin/creator name and version
// getters defined later in this file; confirm.
char const* EMB_LAYER_NORM_VERSION{"1"};
char const* EMB_LAYER_NORM_NAME{"ManyEmbLayerNormPluginDynamic"};
// Static class fields initialization
// mFC is value-initialized (all members zero/null); mPluginAttributes starts
// empty.
nvinfer1::PluginFieldCollection EmbLayerNormPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> EmbLayerNormPluginCreator::mPluginAttributes;
// Builds the plugin from in-memory weights.
//
// `beta`/`gamma` are converted to float and copied to the device; mLd is
// taken from beta.count. Every table in `IdsEmb` is converted to `type`,
// uploaded into its own cudaMalloc'ed buffer (pointer stored in
// mIdsEmbPtrs), and its row count (count / mLd) is recorded in
// mIdsVocabSize.
EmbLayerNormPlugin::EmbLayerNormPlugin(
    std::string const& name,
    nvinfer1::DataType const type,
    nvinfer1::Weights const& beta,
    nvinfer1::Weights const& gamma,
    const std::vector<nvinfer1::Weights>& IdsEmb)
    : mLayerName(name),
      mLd(beta.count),
      mType(type),
      mIdsEmb_(IdsEmb),
      nbLookupTables_(static_cast<int>(IdsEmb.size())) {
  // Assuming Weights.count is the number of elements and not bytes
  assert(beta.count == gamma.count);

  mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT);
  mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT);
  copyToDevice(&mGamma, sizeof(float) * mGamma.count, &mGammaDev);
  copyToDevice(&mBeta, sizeof(float) * mBeta.count, &mBetaDev);

  for (auto const& table : mIdsEmb_) {
    // Each table must hold a whole number of rows of width mLd.
    assert(table.count % mLd == 0);
    mIdsVocabSize.push_back(int32_t(table.count / mLd));

    // Convert on the host first, then upload the converted bytes.
    WeightsWithOwnership hostCopy;
    hostCopy.convertAndCopy(table, mType);
    auto const tableBytes = getWeightsSize(hostCopy, mType);

    void* devTable{nullptr};
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&devTable, tableBytes));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(
        devTable, hostCopy.values, tableBytes, cudaMemcpyHostToDevice));
    mIdsEmbPtrs.push_back(devTable);
  }
}
// Builds the plugin from a serialized byte stream.
//
// Reads mType, mLd, nbLookupTables_ and one vocabulary size per table, then
// beta and gamma (mLd elements each), then the raw bytes of every embedding
// table. Each table is copied into a new[]-allocated host buffer referenced
// from mIdsEmb_. No device memory is allocated here: mGammaDev/mBetaDev stay
// null and mIdsEmbPtrs stays empty.
EmbLayerNormPlugin::EmbLayerNormPlugin(std::string const& name,
                                       void const* data,
                                       size_t length)
    : mLayerName(name),
      mGammaDev(nullptr),
      mBetaDev(nullptr),
      mIdsEmbPtrs{},
      mIdsEmb_{} {
  // Deserialize in the same order as serialization
  deserialize_value(&data, &length, &mType);
  deserialize_value(&data, &length, &mLd);
  deserialize_value(&data, &length, &nbLookupTables_);
  for (int32_t t = 0; t < nbLookupTables_; ++t) {
    int32_t vocabSize;
    deserialize_value(&data, &length, &vocabSize);
    mIdsVocabSize.push_back(vocabSize);
  }

  // From here on the payload is consumed through a byte cursor.
  char const* cursor = static_cast<char const*>(data);
  mBeta.convertAndCopy(&cursor, mLd, nvinfer1::DataType::kFLOAT);
  mGamma.convertAndCopy(&cursor, mLd, nvinfer1::DataType::kFLOAT);

  for (int32_t t = 0; t < nbLookupTables_; ++t) {
    const auto elemCount = mLd * size_t(mIdsVocabSize[t]);
    const auto byteCount = elemCount * getElementSize(mType);

    auto hostBuf = new char[byteCount];
    std::copy_n(cursor, byteCount, hostBuf);
    cursor += byteCount;

    nvinfer1::Weights table;
    table.type = mType;
    table.count = elemCount;
    table.values = hostBuf;
    mIdsEmb_.push_back(table);
  }
}
// IPluginV2DynamicExt Methods
// Returns a new plugin built from this instance's name, type and host-side
// weights via the weights constructor, carrying over the plugin namespace.
nvinfer1::IPluginV2DynamicExt* EmbLayerNormPlugin::clone() const noexcept {
  TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin clone");
  auto* replica =
      new EmbLayerNormPlugin(mLayerName, mType, mBeta, mGamma, mIdsEmb_);
  replica->setPluginNamespace(mNamespace.c_str());
  return replica;
}
// Output shape: a 3-D tensor whose first two extents are taken from the
// first input and whose last extent is the constant mLd.
nvinfer1::DimsExprs EmbLayerNormPlugin::getOutputDimensions(
    int32_t outputIndex,
    nvinfer1::DimsExprs const* inputs,
    int32_t nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) noexcept {
  assert(outputIndex == 0);
  nvinfer1::DimsExprs const& firstInput = inputs[0];

  nvinfer1::DimsExprs outDims;
  outDims.nbDims = 3;
  outDims.d[0] = firstInput.d[0];
  outDims.d[1] = firstInput.d[1];
  outDims.d[2] = exprBuilder.constant(mLd);
  return outDims;
}
// Tells TensorRT whether the tensor at position `pos` of `inOut` has an
// acceptable type/format.
//
//  - every tensor must use the linear format;
//  - input 0 must be INT32;
//  - every further input must have the same type as input 0 (its dims are
//    asserted to match input 0 in debug builds);
//  - the output (pos == nbInputs) must have type mType, be 3-D, and agree
//    with input 0 in its first two extents.
//
// Any other `pos` is rejected. Previously control fell off the end of this
// non-void function for such positions, which is undefined behaviour.
bool EmbLayerNormPlugin::supportsFormatCombination(
    int32_t pos,
    nvinfer1::PluginTensorDesc const* inOut,
    int32_t nbInputs,
    int32_t nbOutputs) noexcept {
  assert(nbOutputs == 1);
  nvinfer1::PluginTensorDesc const& prev = inOut[0];
  nvinfer1::PluginTensorDesc const& desc = inOut[pos];
  if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kINT32;
  }
  if (0 < pos && pos < nbInputs) {
    assert(desc.dims.nbDims == prev.dims.nbDims);
    for (int i = 0; i < prev.dims.nbDims; ++i) {
      assert(desc.dims.d[i] == prev.dims.d[i]);
    }
    return desc.type == prev.type;
  }
  if (pos == nbInputs) {  // output
    return desc.type == mType && desc.dims.nbDims == 3 &&
           desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
  }
  // pos lies outside [0, nbInputs]: not a tensor this plugin knows about.
  return false;
}
// Debug-only sanity checks on the configured shapes; nothing is stored.
// Verifies that the output's last extent equals mLd, that the output batch
// extent matches the first input's when the latter is positive, and that the
// output type is mType.
void EmbLayerNormPlugin::configurePlugin(
    nvinfer1::DynamicPluginTensorDesc const* inputs,
    int32_t nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* outputs,
    int32_t nbOutputs) noexcept {
  TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin configurePlugin");

  assert(static_cast<size_t>(outputs[0].desc.dims.d[2]) ==
         static_cast<size_t>(mLd));

  int32_t const B = inputs[0].desc.dims.d[0];
  // The batch extent is only compared when it is positive.
  assert(B <= 0 || outputs[0].desc.dims.d[0] == B);

  assert(outputs[0].desc.type == mType);
}
// The plugin requests no scratch workspace; all arguments are ignored.
size_t EmbLayerNormPlugin::getWorkspaceSize(
nvinfer1::PluginTensorDesc const* inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const* outputs,
int32_t nbOutputs) const noexcept {
return 0;
}
int32_t EmbLayerNormPlugin::enqueue(
nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
int32_t batchSize = inputDesc[0].dims.d[0];
int32_t const maxSeqlen = inputDesc[0].dims.d[1];
if (maxSeqlen > 512) {
PADDLE_THROW(platform::errors::InvalidArgument(
"EmbLayerNormPlugin support maxSeqlen is 512"));
}
const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get();
if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]);
if (nbLookupTables_ == 2) {
return embSkipLayerNorm_2<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
maxSeqlen,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNorm_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
maxSeqlen,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNorm_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
maxSeqlen,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]);
if (nbLookupTables_ == 2) {
return embSkipLayerNorm_2<half>(stream,
static_cast<int32_t>(mLd),
batchSize,
maxSeqlen,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNorm_3<half>(stream,
static_cast<int32_t>(mLd),
batchSize,
maxSeqlen,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNorm_4<half>(stream,
static_cast<int32_t>(mLd),
batchSize,
maxSeqlen,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]"));
}
return STATUS_SUCCESS;
}
// IPluginV2Ext Methods
// Reports the data type of the plugin's output tensor.
// The plugin has a single output (index 0) whose type is mType; mType is
// expected to be kHALF or kFLOAT. Both expectations are only checked through
// assert, i.e. in builds where assertions are enabled. inputTypes and nbInputs
// are not consulted.
nvinfer1::DataType EmbLayerNormPlugin::getOutputDataType(
int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const noexcept {
assert(index == 0);
assert(mType == nvinfer1::DataType::kHALF ||
mType == nvinfer1::DataType::kFLOAT);
return mType;
}
// IPluginV2 Methods
// Returns the plugin type string, the EMB_LAYER_NORM_NAME constant. The same
// constant is returned by EmbLayerNormPluginCreator::getPluginName().
char const* EmbLayerNormPlugin::getPluginType() const noexcept {
return EMB_LAYER_NORM_NAME;
}
// Returns the plugin version string, the EMB_LAYER_NORM_VERSION constant. The
// same constant is returned by EmbLayerNormPluginCreator::getPluginVersion().
char const* EmbLayerNormPlugin::getPluginVersion() const noexcept {
return EMB_LAYER_NORM_VERSION;
}
// The plugin always exposes exactly one output tensor.
int32_t EmbLayerNormPlugin::getNbOutputs() const noexcept { return 1; }
// Initialization hook. No resources are set up here: the function only hands
// a message to TRANSFORMER_DEBUG_MSG and returns 0.
int32_t EmbLayerNormPlugin::initialize() noexcept {
TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin initialize");
return 0;
}
// Termination hook. Nothing is released here (device buffers are released in
// destroy()); the function only hands a message to TRANSFORMER_DEBUG_MSG.
void EmbLayerNormPlugin::terminate() noexcept {
TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin terminate");
}
// Computes the number of bytes serialize() writes:
//   - the scalar header (mType, mLd, nbLookupTables_),
//   - one vocabulary size per lookup table,
//   - beta and gamma, each mLd floats,
//   - every embedding table, mLd * vocabSize elements of the size reported by
//     getElementSize(mType).
size_t EmbLayerNormPlugin::getSerializationSize() const noexcept {
  // Number of embedding rows summed over all lookup tables. The sum is kept
  // in int, matching the element type of mIdsVocabSize.
  int totalRows = 0;
  for (int32_t rows : mIdsVocabSize) {
    totalRows = totalRows + rows;
  }

  size_t const bytesPerElem = getElementSize(mType);
  size_t const layerNormBytes = 2 * sizeof(float) * mLd;  // beta + gamma
  size_t const vocabSizeBytes =
      mIdsVocabSize.size() * sizeof(mIdsVocabSize[0]);
  size_t const embeddingBytes = bytesPerElem * mLd * totalRows;  // ids emb

  return layerNormBytes + sizeof(mType) + sizeof(mLd) + vocabSizeBytes +
         embeddingBytes + sizeof(nbLookupTables_);  // numbers of lookup_table
}
void EmbLayerNormPlugin::serialize(void* buffer) const noexcept {
serialize_value(&buffer, mType);
serialize_value(&buffer, mLd);
serialize_value(&buffer, nbLookupTables_);
for (size_t i = 0; i < mIdsVocabSize.size(); ++i) {
serialize_value(&buffer, mIdsVocabSize[i]);
}
char* d = static_cast<char*>(buffer);
size_t const wordSize = getElementSize(mType);
serFromDev(&d, mBetaDev.get(), mLd);
serFromDev(&d, mGammaDev.get(), mLd);
for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
serFromDev(&d,
static_cast<char*>(mIdsEmbPtrs[i]),
mLd * mIdsVocabSize[i] * wordSize);
}
}
// Releases everything the plugin owns and then deletes the plugin itself.
// Called when the network containing the plugin is destroyed; the object must
// not be used after this call returns.
void EmbLayerNormPlugin::destroy() noexcept {
  // Drop the layer-norm parameter buffers first.
  mBetaDev.reset();
  mGammaDev.reset();
  // The embedding tables are held as raw pointers and freed one by one.
  for (void* table : mIdsEmbPtrs) {
    cudaFree(table);
  }
  delete this;
}
// Stores the namespace this plugin instance belongs to.
// libNamespace: NUL-terminated namespace string. A null pointer is treated as
// the empty namespace, because assigning a null char pointer to std::string is
// undefined behaviour and this function is noexcept.
void EmbLayerNormPlugin::setPluginNamespace(char const* libNamespace) noexcept {
  mNamespace = (libNamespace != nullptr) ? libNamespace : "";
}
// Returns the namespace previously stored by setPluginNamespace(). The
// pointer refers to mNamespace's internal buffer, so it is only valid until
// mNamespace is modified or the plugin is destroyed.
char const* EmbLayerNormPlugin::getPluginNamespace() const noexcept {
return mNamespace.c_str();
}
// Empty constructor: no plugin fields are registered here. mFC and
// mPluginAttributes are static members whose definitions are outside this
// part of the file.
EmbLayerNormPluginCreator::EmbLayerNormPluginCreator() {}
// Returns the name of the plugin this creator produces, the
// EMB_LAYER_NORM_NAME constant (the same value EmbLayerNormPlugin reports from
// getPluginType()).
char const* EmbLayerNormPluginCreator::getPluginName() const noexcept {
return EMB_LAYER_NORM_NAME;
}
// Returns the version of the plugin this creator produces, the
// EMB_LAYER_NORM_VERSION constant (the same value EmbLayerNormPlugin reports
// from getPluginVersion()).
char const* EmbLayerNormPluginCreator::getPluginVersion() const noexcept {
return EMB_LAYER_NORM_VERSION;
}
// Returns a pointer to the creator's static field collection mFC. The pointee
// is shared by all creator instances because mFC is a static member.
nvinfer1::PluginFieldCollection const*
EmbLayerNormPluginCreator::getFieldNames() noexcept {
return &mFC;
}
// Extracts the plugin parameters from a field collection.
//
// fc:     fields to scan; every entry is matched by its name.
// beta:   receives the "bert_embeddings_layernorm_beta" field, if present.
// gamma:  receives the "bert_embeddings_layernorm_gamma" field, if present.
// IdsEmb: receives one Weights entry per recognised embedding table, appended
//         in field order.
// Returns true when an "output_fp16" field is present with a non-zero value,
// false otherwise.
//
// The Weights objects only borrow the pointers stored in the fields; nothing
// is copied here. Outputs whose field is absent are left untouched.
//
// NOTE(review): an embedding table is only recognised when the field at
// position i is named "bert_embeddings_word_embeddings_<i - 3>", i.e. the
// numeric suffix is tied to the field's position in the collection.
bool initialize_fields(nvinfer1::PluginFieldCollection const* fc,
                       nvinfer1::Weights* beta,
                       nvinfer1::Weights* gamma,
                       std::vector<nvinfer1::Weights>* IdsEmb) {
  // Wraps a field's payload in a (non-owning) Weights descriptor.
  const auto asWeights = [](const nvinfer1::PluginField& src) {
    nvinfer1::Weights w;
    w.values = src.data;
    w.count = src.length;
    w.type = fieldTypeToDataType(src.type);
    return w;
  };

  bool fp16Requested = false;
  for (int32_t idx = 0; idx < fc->nbFields; ++idx) {
    const nvinfer1::PluginField& field = fc->fields[idx];
    const std::string fieldName(field.name);
    // Name an embedding table must carry to be accepted at this position.
    const std::string embName =
        "bert_embeddings_word_embeddings_" + std::to_string(idx - 3);

    // The names below are mutually exclusive, so at most one branch runs.
    if (fieldName == "bert_embeddings_layernorm_beta") {
      TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_beta...");
      *beta = asWeights(field);
    } else if (fieldName == "bert_embeddings_layernorm_gamma") {
      TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_gamma...");
      *gamma = asWeights(field);
    } else if (fieldName == "output_fp16") {
      TRANSFORMER_DEBUG_MSG("Building output_fp16...");
      assert(field.type == nvinfer1::PluginFieldType::kINT32);
      fp16Requested = static_cast<int32_t const*>(field.data)[0] != 0;
    } else if (fieldName == embName) {
      TRANSFORMER_DEBUG_MSG(embName.c_str());
      IdsEmb->push_back(asWeights(field));
    }
  }
  return fp16Requested;
}
// Builds an EmbLayerNormPlugin from a plugin field collection.
//
// name: layer name forwarded to the plugin constructor.
// fc:   field collection parsed by initialize_fields(); it supplies the
//       layer-norm beta/gamma, the embedding tables and the optional
//       "output_fp16" flag.
// Returns a newly allocated plugin (kHALF when output_fp16 is set, kFLOAT
// otherwise), or nullptr when `fc` is null or does not provide beta, gamma and
// at least one embedding table.
nvinfer1::IPluginV2* EmbLayerNormPluginCreator::createPlugin(
    char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept {
  TRANSFORMER_DEBUG_MSG("EmbLayerNormVar createPlugin");
  if (fc == nullptr) {
    return nullptr;
  }
  // Start from well-defined empty weights. initialize_fields() only writes the
  // outputs whose field is present, so without this a missing field would hand
  // indeterminate pointers and counts to the plugin constructor.
  nvinfer1::Weights beta{nvinfer1::DataType::kFLOAT, nullptr, 0};
  nvinfer1::Weights gamma{nvinfer1::DataType::kFLOAT, nullptr, 0};
  std::vector<nvinfer1::Weights> IdsEmb;
  bool output_fp16 = initialize_fields(fc, &beta, &gamma, &IdsEmb);
  // Refuse to build a plugin from an incomplete field collection.
  if (beta.values == nullptr || gamma.values == nullptr || IdsEmb.empty()) {
    return nullptr;
  }
  TRANSFORMER_DEBUG_MSG("Building the Plugin...");
  EmbLayerNormPlugin* p = new EmbLayerNormPlugin(
      name,
      output_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
      beta,
      gamma,
      IdsEmb);
  return p;
}
// Recreates a plugin from a serialized blob by forwarding name, serialData and
// serialLength to the deserializing EmbLayerNormPlugin constructor.
// The returned object is allocated with new; EmbLayerNormPlugin::destroy()
// is the function that deletes it.
nvinfer1::IPluginV2* EmbLayerNormPluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept {
return new EmbLayerNormPlugin(name, serialData, serialLength);
}
// Stores the namespace this creator is registered under.
// libNamespace: NUL-terminated namespace string. A null pointer is treated as
// the empty namespace, because assigning a null char pointer to std::string is
// undefined behaviour and this function is noexcept.
void EmbLayerNormPluginCreator::setPluginNamespace(
    char const* libNamespace) noexcept {
  mNamespace = (libNamespace != nullptr) ? libNamespace : "";
}
// Returns the namespace previously stored by setPluginNamespace(). The
// pointer refers to mNamespace's internal buffer, so it is only valid until
// mNamespace is modified or the creator is destroyed.
char const* EmbLayerNormPluginCreator::getPluginNamespace() const noexcept {
return mNamespace.c_str();
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include "NvInferPlugin.h"
#include "NvInferRuntime.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
template <typename T>
int32_t embSkipLayerNorm_2(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNorm_3(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNorm_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*);
// TensorRT dynamic-shape plugin that fuses several embedding lookups with a
// layer normalization. enqueue() dispatches to embSkipLayerNorm_2/_3/_4
// depending on the number of lookup tables (2, 3 or 4) and on mType
// (kFLOAT or kHALF).
class EmbLayerNormPlugin : public nvinfer1::IPluginV2DynamicExt {
public:
// Builds the plugin from its layer name, the output data type, the
// layer-norm beta/gamma weights and one Weights entry per embedding table.
EmbLayerNormPlugin(std::string const& name,
nvinfer1::DataType const type,
nvinfer1::Weights const& beta,
nvinfer1::Weights const& gamma,
const std::vector<nvinfer1::Weights>& ids_emb);
// Rebuilds the plugin from a serialized blob (used by
// EmbLayerNormPluginCreator::deserializePlugin).
EmbLayerNormPlugin(std::string const& name, void const* data, size_t length);
// A plugin cannot exist without its weights.
EmbLayerNormPlugin() = delete;
// IPluginV2DynamicExt methods.
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(
int32_t outputIndex,
const nvinfer1::DimsExprs* inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out,
int32_t nbOutputs) noexcept override;
// Runs the fused embedding + layer-norm routine on `stream`; reads the id
// tensors from `inputs` and writes the result to outputs[0].
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept override;
int32_t initialize() noexcept override;
void terminate() noexcept override;
char const* getPluginVersion() const noexcept override;
bool supportsFormatCombination(int32_t pos,
nvinfer1::PluginTensorDesc const* inOut,
int32_t nbInputs,
int32_t nbOutputs) noexcept override;
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const* outputs,
int32_t nbOutputs) const noexcept override;
// IPluginV2Ext methods.
nvinfer1::DataType getOutputDataType(
int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const noexcept override;
// IPluginV2 methods.
char const* getPluginType() const noexcept override;
int32_t getNbOutputs() const noexcept override;
// Number of bytes serialize() writes.
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
// Frees the owned buffers and deletes the plugin object itself.
void destroy() noexcept override;
char const* getPluginNamespace() const noexcept override;
void setPluginNamespace(char const* pluginNamespace) noexcept override;
protected:
std::string const mLayerName;
std::string mNamespace;
// Layer-norm gamma / beta buffers handed to the kernels in enqueue().
cuda_unique_ptr<float> mGammaDev;
cuda_unique_ptr<float> mBetaDev;
// One raw buffer per embedding table; released with cudaFree in destroy().
std::vector<void*> mIdsEmbPtrs;
size_t mLd; // leading dim = hidden size
// Vocabulary size of each embedding table, parallel to mIdsEmbPtrs.
std::vector<int32_t> mIdsVocabSize;
// NOTE(review): presumably host-side copies of beta/gamma - their use is not
// visible in this part of the file.
WeightsWithOwnership mBeta;
WeightsWithOwnership mGamma;
// Output data type; enqueue() supports kFLOAT and kHALF.
nvinfer1::DataType mType;
// NOTE(review): presumably the embedding weights as passed to the
// constructor - confirm against the constructor definition.
std::vector<nvinfer1::Weights> mIdsEmb_;
// Number of lookup tables; enqueue() supports 2, 3 or 4.
int32_t nbLookupTables_ = 0;
};
// Factory for EmbLayerNormPlugin: creates plugins either from a
// PluginFieldCollection (createPlugin) or from a serialized blob
// (deserializePlugin).
class EmbLayerNormPluginCreator : public nvinfer1::IPluginCreator {
public:
EmbLayerNormPluginCreator();
// Name / version identifying the plugin this creator produces.
char const* getPluginName() const noexcept override;
// Returns a pointer to the static field collection mFC.
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
void setPluginNamespace(char const* pluginNamespace) noexcept override;
char const* getPluginNamespace() const noexcept override;
// Builds a plugin from the named fields in `fc` (layer-norm beta/gamma,
// embedding tables, optional output_fp16 flag).
nvinfer1::IPluginV2* createPlugin(
char const* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override;
char const* getPluginVersion() const noexcept override;
// Rebuilds a plugin from data produced by EmbLayerNormPlugin::serialize().
nvinfer1::IPluginV2* deserializePlugin(char const* name,
void const* serialData,
size_t serialLength) noexcept override;
protected:
// Shared by all creator instances (static); defined in the implementation
// file.
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
// NOTE(review): presumably registers EmbLayerNormPluginCreator with the
// TensorRT plugin registry; the macro is defined outside this file - confirm
// in trt_plugin.h.
REGISTER_TRT_PLUGIN_V2(EmbLayerNormPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -39,7 +39,8 @@ constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256; ...@@ -39,7 +39,8 @@ constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{
"ManyEmbLayerNormVarlenPluginDynamic"};
// Static class fields initialization // Static class fields initialization
nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{}; nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
std::vector<nvinfer1::PluginField> std::vector<nvinfer1::PluginField>
...@@ -167,7 +168,6 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions( ...@@ -167,7 +168,6 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
assert(inputs[i].nbDims == inputs[1].nbDims); // same shape assert(inputs[i].nbDims == inputs[1].nbDims); // same shape
} }
assert(inputs[0].nbDims == 1); // pos_id: B+1 assert(inputs[0].nbDims == 1); // pos_id: B+1
assert(outputIndex == 0 || outputIndex == 1);
if (outputIndex == 0) { if (outputIndex == 0) {
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 4; ret.nbDims = 4;
...@@ -176,25 +176,32 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions( ...@@ -176,25 +176,32 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
ret.d[2] = exprBuilder.constant(1); ret.d[2] = exprBuilder.constant(1);
ret.d[3] = exprBuilder.constant(1); ret.d[3] = exprBuilder.constant(1);
return ret; return ret;
} } else if (outputIndex == 1) {
// This is a hack: we just report some mask size and rely the plugins to
// This is a hack: we just report some mask size and rely the plugins to play // play nicely together.
// nicely together.
// At runtime, depending on the actual maxSeqlen, the size might be // At runtime, depending on the actual maxSeqlen, the size might be
// different. // different.
int32_t maskSize_ = packedMaskSize384; int32_t maskSize_ = packedMaskSize384;
auto maskSize = exprBuilder.constant(maskSize_); auto maskSize = exprBuilder.constant(maskSize_);
auto fp16maskSize = exprBuilder.operation( auto fp16maskSize =
nvinfer1::DimensionOperation::kPROD, *maskSize, *exprBuilder.constant(2)); exprBuilder.operation(nvinfer1::DimensionOperation::kPROD,
*maskSize,
*exprBuilder.constant(2));
auto Bplus1 = inputs[0].d[0]; // pos_id auto Bplus1 = inputs[0].d[0]; // pos_id
auto one = exprBuilder.constant(1); auto one = exprBuilder.constant(1);
auto B = auto B = exprBuilder.operation(
exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *Bplus1, *one); nvinfer1::DimensionOperation::kSUB, *Bplus1, *one);
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 2; ret.nbDims = 2;
ret.d[0] = B; ret.d[0] = B;
ret.d[1] = fp16maskSize; ret.d[1] = fp16maskSize;
return ret; return ret;
} else {
nvinfer1::DimsExprs ret;
ret.nbDims = 1;
ret.d[0] = inputs[nbInputs - 1].d[1]; // mask id: max seqlen
return ret;
}
} }
nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions( nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions(
...@@ -209,14 +216,20 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions( ...@@ -209,14 +216,20 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions(
assert(inputs[i].nbDims == inputs[1].nbDims); // same shape assert(inputs[i].nbDims == inputs[1].nbDims); // same shape
} }
assert(inputs[0].nbDims == 1); // pos_id: B+1 assert(inputs[0].nbDims == 1); // pos_id: B+1
assert(outputIndex == 0 || outputIndex == 1); if (outputIndex == 0 || outputIndex == 1) {
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 4; ret.nbDims = 4;
ret.d[0] = inputs[1].d[0]; ret.d[0] = inputs[1].d[0]; // sum of seq length
ret.d[1] = exprBuilder.constant(mLd); ret.d[1] = exprBuilder.constant(mLd);
ret.d[2] = exprBuilder.constant(1); ret.d[2] = exprBuilder.constant(1);
ret.d[3] = exprBuilder.constant(1); ret.d[3] = exprBuilder.constant(1);
return ret; return ret;
} else {
nvinfer1::DimsExprs ret;
ret.nbDims = 1;
ret.d[0] = inputs[nbInputs - 1].d[1]; // mask id: max seqlen
return ret;
}
} }
bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
...@@ -224,7 +237,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( ...@@ -224,7 +237,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
nvinfer1::PluginTensorDesc const* inOut, nvinfer1::PluginTensorDesc const* inOut,
int32_t nbInputs, int32_t nbInputs,
int32_t nbOutputs) noexcept { int32_t nbOutputs) noexcept {
assert(nbOutputs == 2); assert(nbOutputs == 3);
nvinfer1::PluginTensorDesc const& desc = inOut[pos]; nvinfer1::PluginTensorDesc const& desc = inOut[pos];
if (desc.format != nvinfer1::TensorFormat::kLINEAR) { if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
return false; return false;
...@@ -241,8 +254,8 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( ...@@ -241,8 +254,8 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
return desc.type == prev.type && desc.dims.nbDims == 1 && return desc.type == prev.type && desc.dims.nbDims == 1 &&
desc.dims.d[0] == prev.dims.d[0]; desc.dims.d[0] == prev.dims.d[0];
} }
if (pos == nbInputs - 1) { // max seq length if (pos == nbInputs - 1) { // mask id
return desc.dims.nbDims == 1; return desc.type == prev.type;
} }
// embedded sequence // embedded sequence
if (pos == nbInputs) { if (pos == nbInputs) {
...@@ -250,8 +263,14 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( ...@@ -250,8 +263,14 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
desc.dims.d[0] == inOut[1].dims.d[0] && desc.dims.d[2] == 1 && desc.dims.d[0] == inOut[1].dims.d[0] && desc.dims.d[2] == 1 &&
desc.dims.d[3] == 1; desc.dims.d[3] == 1;
} }
// mask // mask(HFace) or pre_layernorm_bias(MTron)
return desc.type == nvinfer1::DataType::kHALF; if (pos == nbInputs + 1) {
return desc.type == prev.type;
}
// max seqlen
if (pos == nbInputs + 2) {
return desc.type == prev.type;
}
} }
void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs, void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs,
...@@ -259,8 +278,7 @@ void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs, ...@@ -259,8 +278,7 @@ void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs,
nvinfer1::DynamicPluginTensorDesc const* outputs, nvinfer1::DynamicPluginTensorDesc const* outputs,
int32_t nbOutputs) noexcept { int32_t nbOutputs) noexcept {
// Validate input arguments // Validate input arguments
// assert(nbInputs == 4); assert(nbOutputs == 3);
assert(nbOutputs == 2);
assert(inputs[0].desc.dims.nbDims == 1); assert(inputs[0].desc.dims.nbDims == 1);
assert(inputs[0].desc.type == nvinfer1::DataType::kINT32); assert(inputs[0].desc.type == nvinfer1::DataType::kINT32);
for (int i = 1; i < nbInputs - 1; ++i) { for (int i = 1; i < nbInputs - 1; ++i) {
...@@ -671,7 +689,7 @@ char const* EmbLayerNormVarSeqlenPluginMTron::getPluginVersion() ...@@ -671,7 +689,7 @@ char const* EmbLayerNormVarSeqlenPluginMTron::getPluginVersion()
} }
int32_t EmbLayerNormVarSeqlenPluginBase::getNbOutputs() const noexcept { int32_t EmbLayerNormVarSeqlenPluginBase::getNbOutputs() const noexcept {
return 2; return 3;
} }
int32_t EmbLayerNormVarSeqlenPluginHFace::initialize() noexcept { int32_t EmbLayerNormVarSeqlenPluginHFace::initialize() noexcept {
......
...@@ -194,7 +194,6 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt { ...@@ -194,7 +194,6 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
cuda_unique_ptr<float> mGammaDev; cuda_unique_ptr<float> mGammaDev;
cuda_unique_ptr<float> mBetaDev; cuda_unique_ptr<float> mBetaDev;
std::vector<void*> mIdsEmbPtrs; std::vector<void*> mIdsEmbPtrs;
// std::vector<void*> mIdsEmbDev;
size_t mLd; // leading dim = hidden size size_t mLd; // leading dim = hidden size
std::vector<int32_t> mIdsVocabSize; std::vector<int32_t> mIdsVocabSize;
WeightsWithOwnership mBeta; WeightsWithOwnership mBeta;
......
...@@ -28,11 +28,13 @@ limitations under the License. */ ...@@ -28,11 +28,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
#if defined _WIN32
#else
TEST(AnalysisPredictor, no_fp16) { TEST(AnalysisPredictor, no_fp16) {
std::vector<float> result = {0.597841, 0.219972, 0.182187}; std::vector<float> result = {0.597841, 0.219972, 0.182187};
trt_ernie(false, result); trt_ernie(false, result);
} }
#endif
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -38,23 +38,23 @@ static void run(const AnalysisConfig& config, std::vector<float>* out_data) { ...@@ -38,23 +38,23 @@ static void run(const AnalysisConfig& config, std::vector<float>* out_data) {
int run_batch = 1; int run_batch = 1;
const int run_seq_len = 128; const int run_seq_len = 128;
std::vector<int64_t> tmp_input; std::vector<int32_t> tmp_input;
std::vector<float> tmp_four_input; std::vector<float> tmp_four_input;
tmp_input.reserve(run_batch * run_seq_len); tmp_input.reserve(run_batch * run_seq_len);
tmp_four_input.reserve(run_batch * run_seq_len); tmp_four_input.reserve(run_batch * run_seq_len);
int64_t i0[run_seq_len] = { int32_t i0[run_seq_len] = {
1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321,
4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2,
75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2};
int64_t i1[run_seq_len] = { int32_t i1[run_seq_len] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int64_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, int32_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; 30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
...@@ -136,11 +136,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) { ...@@ -136,11 +136,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
precision = AnalysisConfig::Precision::kHalf; precision = AnalysisConfig::Precision::kHalf;
} }
#if defined _WIN32
#else
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false); config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
#endif
config.SetTRTDynamicShapeInfo( config.SetTRTDynamicShapeInfo(
min_input_shape, max_input_shape, opt_input_shape); min_input_shape, max_input_shape, opt_input_shape);
AnalysisConfig* config_deser = new AnalysisConfig(config); AnalysisConfig* config_deser = new AnalysisConfig(config);
......
...@@ -33,18 +33,18 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) { ...@@ -33,18 +33,18 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) {
const int run_seq_len = 128; const int run_seq_len = 128;
size_t len = run_batch * run_seq_len; size_t len = run_batch * run_seq_len;
int64_t i0_bs1[run_seq_len] = { int32_t i0_bs1[run_seq_len] = {
1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321,
4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2,
75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2};
int64_t i1_bs1[run_seq_len] = { int32_t i1_bs1[run_seq_len] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int64_t i2_bs1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, int32_t i2_bs1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; 30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
...@@ -52,7 +52,7 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) { ...@@ -52,7 +52,7 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) {
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
std::vector<int64_t> i0_data(len), i1_data(len), i2_data(len); std::vector<int32_t> i0_data(len), i1_data(len), i2_data(len);
std::vector<float> i3_data(len); std::vector<float> i3_data(len);
for (size_t i = 0; i < len; i++) { for (size_t i = 0; i < len; i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册