未验证 提交 6512e087 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference]fix embedding fused (#46789)

* fix embedding fused
上级 ae6b4713
...@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "2"); "ManyEmbLayerNormPluginDynamic", "1");
auto plugin_obj = auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2( auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " + plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " +
op_desc.Output("Out")[0] + ")") op_desc.Output("Out")[0] + ")")
.c_str()); .c_str());
free(plugin_ptr); free(plugin_ptr);
...@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
layer = plugin_layer; layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, RreplenishLayerAndOutput(layer,
"ManyEmbLayerNormPluginDynamic_V2", "ManyEmbLayerNormPluginDynamic_V1",
{output_name, std::string("qkv_plugin_mask")}, {output_name, std::string("qkv_plugin_mask")},
test_mode); test_mode);
} }
......
...@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "3"); "ManyEmbLayerNormPluginDynamic", "2");
auto plugin_obj = auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
......
...@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24; ...@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24;
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256; constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"};
// Static class fields initialization // Static class fields initialization
nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{}; nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
...@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase( ...@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
tem_weight.values, tem_weight.values,
getWeightsSize(tem_weight, mType), getWeightsSize(tem_weight, mType),
cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));
mIdsEmbDev.push_back(cudaMem); mIdsEmbPtrs.push_back(cudaMem);
} }
} }
...@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase( ...@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
: mLayerName(name), : mLayerName(name),
mGammaDev(nullptr), mGammaDev(nullptr),
mBetaDev(nullptr), mBetaDev(nullptr),
mIdsEmbDev{}, mIdsEmbPtrs{},
mIdsEmb_{} { mIdsEmb_{} {
// Deserialize in the same order as serialization // Deserialize in the same order as serialization
deserialize_value(&data, &length, &mType); deserialize_value(&data, &length, &mType);
...@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron( ...@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
// IPluginV2DynamicExt Methods // IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone() nvinfer1::IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone()
const noexcept { const noexcept {
TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron clone"); TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace clone");
auto p = new EmbLayerNormVarSeqlenPluginMTron( auto p = new EmbLayerNormVarSeqlenPluginHFace(
mLayerName, mType, mBeta, mGamma, mIdsEmb_); mLayerName, mType, mBeta, mGamma, mIdsEmb_);
p->setPluginNamespace(mNamespace.c_str()); p->setPluginNamespace(mNamespace.c_str());
return p; return p;
...@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue( ...@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
void* const* outputs, void* const* outputs,
void* workspace, void* workspace,
cudaStream_t stream) noexcept { cudaStream_t stream) noexcept {
int32_t const batchSize = inputDesc[0].dims.d[0] - 1; int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input // read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0]; int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t S = 384; int32_t S = 384;
...@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue( ...@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
} }
const float* beta = mBetaDev.get(); const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get(); const float* gamma = mGammaDev.get();
int32_t** tem_inputs_ptr_dev;
cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(tem_inputs_ptr_dev,
inputs,
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
int32_t* mIdsVocabSize_dev;
cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
sizeof(int32_t) * mIdsVocabSize.size());
cudaMemcpy(mIdsVocabSize_dev,
&(mIdsVocabSize[0]),
sizeof(int32_t) * mIdsVocabSize.size(),
cudaMemcpyHostToDevice);
if (mType == nvinfer1::DataType::kFLOAT) { if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]); auto output = static_cast<float*>(outputs[0]);
float** mIdsEmbDev_float; if (nbLookupTables_ == 2) {
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float), return embSkipLayerNormHFace_2<float>(
sizeof(void*) * nbLookupTables_); stream,
cudaMemcpy(mIdsEmbDev_float, static_cast<int32_t>(mLd),
&(mIdsEmbDev[0]), batchSize,
sizeof(void*) * nbLookupTables_, S,
cudaMemcpyHostToDevice); static_cast<int32_t const*>(inputs[0]),
return embSkipLayerNormHFace<float>(stream, static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t>(mLd), nbLookupTables_,
batchSize, beta,
S, gamma,
tem_inputs_ptr_dev, static_cast<float const*>(mIdsEmbPtrs[0]),
nbLookupTables_, static_cast<float const*>(mIdsEmbPtrs[1]),
beta, mIdsVocabSize[0],
gamma, mIdsVocabSize[1],
mIdsEmbDev_float, output);
mIdsVocabSize_dev, } else if (nbLookupTables_ == 3) {
output); return embSkipLayerNormHFace_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormHFace_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) { } else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]); auto output = static_cast<half*>(outputs[0]);
half** mIdsEmbDev_half; if (nbLookupTables_ == 2) {
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half), return embSkipLayerNormHFace_2<half>(
sizeof(void*) * nbLookupTables_); stream,
cudaMemcpy(mIdsEmbDev_half, static_cast<int32_t>(mLd),
&(mIdsEmbDev[0]), batchSize,
sizeof(void*) * nbLookupTables_, S,
cudaMemcpyHostToDevice); static_cast<int32_t const*>(inputs[0]),
return embSkipLayerNormHFace<half>(stream, static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t>(mLd), nbLookupTables_,
batchSize, beta,
S, gamma,
tem_inputs_ptr_dev, static_cast<half const*>(mIdsEmbPtrs[0]),
nbLookupTables_, static_cast<half const*>(mIdsEmbPtrs[1]),
beta, mIdsVocabSize[0],
gamma, mIdsVocabSize[1],
mIdsEmbDev_half, output);
mIdsVocabSize_dev, } else if (nbLookupTables_ == 3) {
output); return embSkipLayerNormHFace_3<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormHFace_4<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]")); "Unsupported type error, expected [kHALF,kFLOAT]"));
...@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue( ...@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
void* const* outputs, void* const* outputs,
void* workspace, void* workspace,
cudaStream_t stream) noexcept { cudaStream_t stream) noexcept {
int32_t const batchSize = inputDesc[0].dims.d[0] - 1; int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input // read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0]; int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t S = 384; int32_t S = 384;
...@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue( ...@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
} }
const float* beta = mBetaDev.get(); const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get(); const float* gamma = mGammaDev.get();
int32_t** tem_inputs_ptr_dev;
cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(tem_inputs_ptr_dev,
inputs,
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
int32_t* mIdsVocabSize_dev;
cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
sizeof(int32_t) * mIdsVocabSize.size());
cudaMemcpy(mIdsVocabSize_dev,
&(mIdsVocabSize[0]),
sizeof(int32_t) * mIdsVocabSize.size(),
cudaMemcpyHostToDevice);
if (mType == nvinfer1::DataType::kFLOAT) { if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]); auto output = static_cast<float*>(outputs[0]);
auto skip = static_cast<float*>(outputs[1]); auto skip = static_cast<float*>(outputs[1]);
float** mIdsEmbDev_float; if (nbLookupTables_ == 2) {
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float), return embSkipLayerNormMTron_2<float>(
sizeof(void*) * nbLookupTables_); stream,
cudaMemcpy(mIdsEmbDev_float, static_cast<int32_t>(mLd),
&(mIdsEmbDev[0]), batchSize,
sizeof(void*) * nbLookupTables_, S,
cudaMemcpyHostToDevice); static_cast<int32_t const*>(inputs[0]),
return embSkipLayerNormMTron<float>(stream, static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t>(mLd), nbLookupTables_,
batchSize, beta,
S, gamma,
tem_inputs_ptr_dev, static_cast<float const*>(mIdsEmbPtrs[0]),
nbLookupTables_, static_cast<float const*>(mIdsEmbPtrs[1]),
beta, mIdsVocabSize[0],
gamma, mIdsVocabSize[1],
mIdsEmbDev_float, output,
mIdsVocabSize_dev, skip);
output, } else if (nbLookupTables_ == 3) {
skip); return embSkipLayerNormMTron_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output,
skip);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormMTron_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output,
skip);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) { } else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]); auto output = static_cast<half*>(outputs[0]);
auto skip = static_cast<half*>(outputs[1]); auto skip = static_cast<half*>(outputs[1]);
half** mIdsEmbDev_half; if (nbLookupTables_ == 2) {
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half), return embSkipLayerNormMTron_2<half>(
sizeof(void*) * nbLookupTables_); stream,
cudaMemcpy(mIdsEmbDev_half, static_cast<int32_t>(mLd),
&(mIdsEmbDev[0]), batchSize,
sizeof(void*) * nbLookupTables_, S,
cudaMemcpyHostToDevice); static_cast<int32_t const*>(inputs[0]),
return embSkipLayerNormMTron<half>(stream, static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t>(mLd), nbLookupTables_,
batchSize, beta,
S, gamma,
tem_inputs_ptr_dev, static_cast<half const*>(mIdsEmbPtrs[0]),
nbLookupTables_, static_cast<half const*>(mIdsEmbPtrs[1]),
beta, mIdsVocabSize[0],
gamma, mIdsVocabSize[1],
mIdsEmbDev_half, output,
mIdsVocabSize_dev, skip);
output, } else if (nbLookupTables_ == 3) {
skip); return embSkipLayerNormMTron_3<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output,
skip);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormMTron_4<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output,
skip);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]")); "Unsupported type error, expected [kHALF,kFLOAT]"));
...@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept { ...@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
size_t const wordSize = getElementSize(mType); size_t const wordSize = getElementSize(mType);
serFromDev(&d, mBetaDev.get(), mLd); serFromDev(&d, mBetaDev.get(), mLd);
serFromDev(&d, mGammaDev.get(), mLd); serFromDev(&d, mGammaDev.get(), mLd);
for (size_t i = 0; i < mIdsEmbDev.size(); ++i) { for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
serFromDev(&d, serFromDev(&d,
static_cast<char*>(mIdsEmbDev[i]), static_cast<char*>(mIdsEmbPtrs[i]),
mLd * mIdsVocabSize[i] * wordSize); mLd * mIdsVocabSize[i] * wordSize);
} }
} }
...@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept { ...@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept {
// This gets called when the network containing plugin is destroyed // This gets called when the network containing plugin is destroyed
mBetaDev.reset(nullptr); mBetaDev.reset(nullptr);
mGammaDev.reset(nullptr); mGammaDev.reset(nullptr);
for (size_t i = 0; i < mIdsEmbDev.size(); ++i) { for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
cudaFree(mIdsEmbDev[i]); cudaFree(mIdsEmbPtrs[i]);
} }
delete this; delete this;
} }
...@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin( ...@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
beta, beta,
gamma, gamma,
IdsEmb); IdsEmb);
return p; return p;
} }
......
...@@ -31,32 +31,121 @@ namespace tensorrt { ...@@ -31,32 +31,121 @@ namespace tensorrt {
namespace plugin { namespace plugin {
template <typename T> template <typename T>
int32_t embSkipLayerNormHFace(cudaStream_t stream, int32_t embSkipLayerNormHFace_2(cudaStream_t,
int32_t ld, int32_t,
int32_t B, int32_t,
int32_t S, int32_t,
int32_t** inputIds, int32_t const*,
int32_t const nbLookupTables, int32_t const*,
float const* beta, int32_t,
float const* gamma, float const*,
T** idsEmb, float const*,
int32_t*, T const*,
T* output); T const*,
int32_t,
int32_t,
T*);
template <typename T> template <typename T>
int32_t embSkipLayerNormMTron(cudaStream_t stream, int32_t embSkipLayerNormHFace_3(cudaStream_t,
int32_t ld, int32_t,
int32_t B, int32_t,
int32_t S, int32_t,
int32_t** inputIds, int32_t const*,
int32_t const nbLookupTables, int32_t const*,
float const* beta, int32_t const*,
float const* gamma, int32_t,
T** idsEmb, float const*,
int32_t*, float const*,
T* output, T const*,
T* skip); T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormHFace_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_2(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
int32_t,
int32_t,
T*,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_3(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*,
T*);
class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt { class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
public: public:
EmbLayerNormVarSeqlenPluginBase( EmbLayerNormVarSeqlenPluginBase(
...@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt { ...@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
std::string mNamespace; std::string mNamespace;
cuda_unique_ptr<float> mGammaDev; cuda_unique_ptr<float> mGammaDev;
cuda_unique_ptr<float> mBetaDev; cuda_unique_ptr<float> mBetaDev;
std::vector<void*> mIdsEmbDev; std::vector<void*> mIdsEmbPtrs;
// std::vector<void*> mIdsEmbDev;
size_t mLd; // leading dim = hidden size size_t mLd; // leading dim = hidden size
std::vector<int32_t> mIdsVocabSize; std::vector<int32_t> mIdsVocabSize;
WeightsWithOwnership mBeta; WeightsWithOwnership mBeta;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册