Unverified commit 6512e087, authored by W Wangzheee, committed by GitHub

[Paddle Inference]fix embedding fused (#46789)

* fix embedding fused
Parent: ae6b4713
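
In brief, the diff below (i) renumbers the ManyEmbLayerNormPluginDynamic creator versions used by the converters and the plugin constants (HFace "2" -> "1", MTron "3" -> "2") and renames the corresponding layers from ..._V2 to ..._V1; (ii) renames the member mIdsEmbDev to mIdsEmbPtrs; (iii) fixes EmbLayerNormVarSeqlenPluginHFace::clone(), which previously constructed an EmbLayerNormVarSeqlenPluginMTron; and (iv) reworks both enqueue() methods so that, instead of cudaMalloc/cudaMemcpy-ing pointer and vocab-size arrays to the device on every call, they dispatch to per-arity kernels embSkipLayerNormHFace_2/3/4 and embSkipLayerNormMTron_2/3/4 that take each ids pointer, embedding-table pointer, and vocab size as a separate argument (any other table count raises InvalidArgument). The following is a minimal, self-contained sketch of that dispatch pattern; the names (dispatchEmbLayerNorm, embLayerNorm2/3) and the plain-C++ stubs are hypothetical stand-ins, not the Paddle kernels.

```cpp
// Sketch of the per-arity dispatch used by the new enqueue(): switch on the number
// of fused lookup tables and pass each ids pointer, embedding table, and vocab size
// as an individual argument, instead of building device-side pointer arrays.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Stand-ins for embSkipLayerNormHFace_2 / _3; the real versions launch CUDA kernels.
template <typename T>
int32_t embLayerNorm2(const int32_t*, const int32_t*, const T*, const T*,
                      int32_t, int32_t, T*) {
  return 0;  // pretend success
}
template <typename T>
int32_t embLayerNorm3(const int32_t*, const int32_t*, const int32_t*,
                      const T*, const T*, const T*,
                      int32_t, int32_t, int32_t, T*) {
  return 0;  // pretend success
}

template <typename T>
int32_t dispatchEmbLayerNorm(const std::vector<const int32_t*>& ids,
                             const std::vector<const T*>& tables,
                             const std::vector<int32_t>& vocabSizes, T* out) {
  switch (ids.size()) {
    case 2:
      return embLayerNorm2(ids[0], ids[1], tables[0], tables[1],
                           vocabSizes[0], vocabSizes[1], out);
    case 3:
      return embLayerNorm3(ids[0], ids[1], ids[2],
                           tables[0], tables[1], tables[2],
                           vocabSizes[0], vocabSizes[1], vocabSizes[2], out);
    default:
      // Mirrors the PADDLE_THROW branch: only a fixed set of arities is fused.
      throw std::invalid_argument("only 2 or 3 lookup tables supported in this sketch");
  }
}

int main() {
  std::vector<int32_t> ids0{1, 2}, ids1{3, 4};
  std::vector<float> emb0(16, 0.f), emb1(16, 0.f), out(8, 0.f);
  std::vector<const int32_t*> ids{ids0.data(), ids1.data()};
  std::vector<const float*> tables{emb0.data(), emb1.data()};
  std::vector<int32_t> vocab{8, 8};
  std::printf("rc = %d\n", dispatchEmbLayerNorm(ids, tables, vocab, out.data()));
  return 0;
}
```

Unpacking the pointers into the kernel signature removes the per-enqueue device allocations and the pointer-array indirection the old T** interface required, at the cost of one specialization per supported table count.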
......@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "2");
"ManyEmbLayerNormPluginDynamic", "1");
auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " +
plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
free(plugin_ptr);
......@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer,
"ManyEmbLayerNormPluginDynamic_V2",
"ManyEmbLayerNormPluginDynamic_V1",
{output_name, std::string("qkv_plugin_mask")},
test_mode);
}
......
......@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "3");
"ManyEmbLayerNormPluginDynamic", "2");
auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
......
......@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24;
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"};
// Static class fields initialization
nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
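
Note on the version renumbering above: TensorRT's plugin registry looks creators up by the exact (name, version) pair, so the version string passed to getPluginCreator() in the converters has to match the one the creator reports via getPluginVersion() (here EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE / _MTRON); on a mismatch the lookup returns nullptr. Below is a hedged sketch of such a lookup with an explicit check, using the standard TensorRT API rather than Paddle's converter code; the wrapper name is made up.

```cpp
// Sketch: fetch a plugin creator by name + version and fail loudly on a mismatch.
// getPluginRegistry()/getPluginCreator() are standard TensorRT calls; the wrapper
// name findEmbLayerNormCreator is hypothetical.
#include <NvInfer.h>

#include <stdexcept>
#include <string>

nvinfer1::IPluginCreator* findEmbLayerNormCreator(const char* version) {
  nvinfer1::IPluginCreator* creator = getPluginRegistry()->getPluginCreator(
      "ManyEmbLayerNormPluginDynamic", version);
  if (creator == nullptr) {
    // Happens when the converter's version string drifts from the creator's
    // getPluginVersion(), which is what this commit re-synchronizes.
    throw std::runtime_error(
        std::string("no creator registered for ManyEmbLayerNormPluginDynamic "
                    "version ") + version);
  }
  return creator;
}
```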
......@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
tem_weight.values,
getWeightsSize(tem_weight, mType),
cudaMemcpyHostToDevice));
mIdsEmbDev.push_back(cudaMem);
mIdsEmbPtrs.push_back(cudaMem);
}
}
......@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
: mLayerName(name),
mGammaDev(nullptr),
mBetaDev(nullptr),
mIdsEmbDev{},
mIdsEmbPtrs{},
mIdsEmb_{} {
// Deserialize in the same order as serialization
deserialize_value(&data, &length, &mType);
......@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone()
const noexcept {
TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron clone");
auto p = new EmbLayerNormVarSeqlenPluginMTron(
TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace clone");
auto p = new EmbLayerNormVarSeqlenPluginHFace(
mLayerName, mType, mBeta, mGamma, mIdsEmb_);
p->setPluginNamespace(mNamespace.c_str());
return p;
......@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
int32_t const batchSize = inputDesc[0].dims.d[0] - 1;
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t S = 384;
......@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
}
const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get();
int32_t** tem_inputs_ptr_dev;
cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(tem_inputs_ptr_dev,
inputs,
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
int32_t* mIdsVocabSize_dev;
cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
sizeof(int32_t) * mIdsVocabSize.size());
cudaMemcpy(mIdsVocabSize_dev,
&(mIdsVocabSize[0]),
sizeof(int32_t) * mIdsVocabSize.size(),
cudaMemcpyHostToDevice);
if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]);
float** mIdsEmbDev_float;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_float,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormHFace<float>(stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_float,
mIdsVocabSize_dev,
output);
if (nbLookupTables_ == 2) {
return embSkipLayerNormHFace_2<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormHFace_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormHFace_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]);
half** mIdsEmbDev_half;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_half,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormHFace<half>(stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_half,
mIdsVocabSize_dev,
output);
if (nbLookupTables_ == 2) {
return embSkipLayerNormHFace_2<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormHFace_3<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormHFace_4<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]"));
......@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
int32_t const batchSize = inputDesc[0].dims.d[0] - 1;
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t S = 384;
......@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
}
const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get();
int32_t** tem_inputs_ptr_dev;
cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(tem_inputs_ptr_dev,
inputs,
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
int32_t* mIdsVocabSize_dev;
cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
sizeof(int32_t) * mIdsVocabSize.size());
cudaMemcpy(mIdsVocabSize_dev,
&(mIdsVocabSize[0]),
sizeof(int32_t) * mIdsVocabSize.size(),
cudaMemcpyHostToDevice);
if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]);
auto skip = static_cast<float*>(outputs[1]);
float** mIdsEmbDev_float;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_float,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormMTron<float>(stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_float,
mIdsVocabSize_dev,
output,
skip);
if (nbLookupTables_ == 2) {
return embSkipLayerNormMTron_2<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output,
skip);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormMTron_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output,
skip);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormMTron_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output,
skip);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]);
auto skip = static_cast<half*>(outputs[1]);
half** mIdsEmbDev_half;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_half,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormMTron<half>(stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_half,
mIdsVocabSize_dev,
output,
skip);
if (nbLookupTables_ == 2) {
return embSkipLayerNormMTron_2<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output,
skip);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormMTron_3<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output,
skip);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormMTron_4<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output,
skip);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]"));
......@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
size_t const wordSize = getElementSize(mType);
serFromDev(&d, mBetaDev.get(), mLd);
serFromDev(&d, mGammaDev.get(), mLd);
for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
serFromDev(&d,
static_cast<char*>(mIdsEmbDev[i]),
static_cast<char*>(mIdsEmbPtrs[i]),
mLd * mIdsVocabSize[i] * wordSize);
}
}
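
The serialize() loop above copies each device-resident embedding table back into the host-side serialization buffer, mLd * mIdsVocabSize[i] * wordSize bytes per table, via serFromDev. The helper below is a hedged sketch of what such a device-to-host serialization step does; the real serFromDev lives in the plugin's common helpers, and the signature shown here is an assumption.

```cpp
// Sketch of a serFromDev-style helper (assumed interface, not the actual helper):
// copy nbBytes from device memory into the serialization buffer and advance the
// write cursor so the next field lands right after it.
#include <cuda_runtime.h>

#include <cstddef>

inline void serFromDevSketch(char** buffer, const void* devicePtr, size_t nbBytes) {
  cudaMemcpy(*buffer, devicePtr, nbBytes, cudaMemcpyDeviceToHost);
  *buffer += nbBytes;
}
```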
......@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept {
// This gets called when the network containing plugin is destroyed
mBetaDev.reset(nullptr);
mGammaDev.reset(nullptr);
for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
cudaFree(mIdsEmbDev[i]);
for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
cudaFree(mIdsEmbPtrs[i]);
}
delete this;
}
......@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
beta,
gamma,
IdsEmb);
return p;
}
......
......@@ -31,32 +31,121 @@ namespace tensorrt {
namespace plugin {
template <typename T>
int32_t embSkipLayerNormHFace(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t** inputIds,
int32_t const nbLookupTables,
float const* beta,
float const* gamma,
T** idsEmb,
int32_t*,
T* output);
int32_t embSkipLayerNormHFace_2(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormMTron(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t** inputIds,
int32_t const nbLookupTables,
float const* beta,
float const* gamma,
T** idsEmb,
int32_t*,
T* output,
T* skip);
int32_t embSkipLayerNormHFace_3(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormHFace_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_2(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
int32_t,
int32_t,
T*,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_3(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*,
T*);
class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
public:
EmbLayerNormVarSeqlenPluginBase(
......@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
std::string mNamespace;
cuda_unique_ptr<float> mGammaDev;
cuda_unique_ptr<float> mBetaDev;
std::vector<void*> mIdsEmbDev;
std::vector<void*> mIdsEmbPtrs;
// std::vector<void*> mIdsEmbDev;
size_t mLd; // leading dim = hidden size
std::vector<int32_t> mIdsVocabSize;
WeightsWithOwnership mBeta;
......