From 6512e087651dca148640bbe8c3738e3c862e589a Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Mon, 10 Oct 2022 14:34:50 +0800 Subject: [PATCH] [Paddle Inference]fix embedding fused (#46789) * fix embedding fused --- .../tensorrt/convert/emb_eltwise_layernorm.cc | 6 +- .../convert/preln_emb_eltwise_layernorm.cc | 2 +- ...any_emb_Layernorm_varseqlen_kernelHFace.cu | 464 ++++++++++++++--- ...any_emb_Layernorm_varseqlen_kernelMTron.cu | 490 +++++++++++++++--- .../many_emb_layernorm_varseqlen_plugin.cu | 378 ++++++++++---- .../many_emb_layernorm_varseqlen_plugin.h | 138 ++++- 6 files changed, 1203 insertions(+), 275 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index 24dbd8a0e1..18b2ae7546 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 auto creator = GetPluginRegistry()->getPluginCreator( - "ManyEmbLayerNormPluginDynamic", "2"); + "ManyEmbLayerNormPluginDynamic", "1"); auto plugin_obj = creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); - plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " + + plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " + op_desc.Output("Out")[0] + ")") .c_str()); free(plugin_ptr); @@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { layer = plugin_layer; auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, - "ManyEmbLayerNormPluginDynamic_V2", + "ManyEmbLayerNormPluginDynamic_V1", {output_name, std::string("qkv_plugin_mask")}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 47992536c4..1c227c0cf7 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 auto creator = GetPluginRegistry()->getPluginCreator( - "ManyEmbLayerNormPluginDynamic", "3"); + "ManyEmbLayerNormPluginDynamic", "2"); auto plugin_obj = creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu index fd7ab67d0a..1a23755000 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu @@ -30,20 +30,22 @@ namespace tensorrt { namespace plugin { template -__global__ void embLayerNormKernelHFace(int32_t ld, - int32_t** inputIds, - int32_t const nbLookupTables, - float const* beta, - float const* gamma, - T** mIdsEmbDev, - int32_t* IdsSize, - T* output) { +__global__ void embLayerNormKernelHFace_2(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + int32_t IdsSize0, + int32_t IdsSize1, + T* output) { cub::Sum pairSum; int32_t const s = blockIdx.x; int32_t const b = blockIdx.y; - int32_t* cuSeqlens = inputIds[0]; - int32_t const sumS = cuSeqlens[b]; - int32_t const s_b = cuSeqlens[b + 1] - sumS; + int32_t const sumS = inputIds0[b]; + int32_t const s_b = inputIds0[b + 1] - sumS; if (s >= s_b) { return; // This CTA has nothing to do } @@ -52,16 +54,14 @@ __global__ void embLayerNormKernelHFace(int32_t ld, extern __shared__ int32_t word_id[]; if (threadIdx.x == 0) { - for (int i = 1; i < nbLookupTables; ++i) { - if (static_cast(inputIds[i])[seqPos] < 0 || - static_cast(inputIds[i])[seqPos] >= IdsSize[i]) { - printf( - "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " - "table: ID < 0 or ID > max "); - return; - } else { - word_id[i - 1] = static_cast(inputIds[i])[seqPos]; - } + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; } } __syncthreads(); @@ -74,12 +74,173 @@ __global__ void embLayerNormKernelHFace(int32_t ld, kvp threadData(0, 0); for (int32_t it = threadIdx.x; it < ld; it += TPB) { - T p(mIdsEmbDev[0][poffset + it]); // pos id + T p(mIdsEmbDev0[poffset + it]); // pos id T val = p; - for (int i = 1; i < nbLookupTables; ++i) { - int32_t const offset = word_id[i - 1] * ld; - val += mIdsEmbDev[i][offset + it]; + int32_t const offset = word_id[0] * ld; + val += mIdsEmbDev1[offset + it]; + output[outOffset + it] = val; + + T const rldval = rld * val; + threadData = pairSum(threadData, kvp(rldval, rldval * val)); + } + + // 3. layer norm on the sum + layerNorm(threadData, ld, outOffset, beta, gamma, output); +} + +template +__global__ void embLayerNormKernelHFace_3(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output) { + cub::Sum pairSum; + int32_t const s = blockIdx.x; + int32_t const b = blockIdx.y; + int32_t const sumS = inputIds0[b]; + int32_t const s_b = inputIds0[b + 1] - sumS; + if (s >= s_b) { + return; // This CTA has nothing to do + } + T const rld = T(1.f) / T(ld); + int32_t const seqPos = sumS + s; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; + } + + if (static_cast(inputIds2)[seqPos] < 0 || + static_cast(inputIds2)[seqPos] >= IdsSize2) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[1] = static_cast(inputIds2)[seqPos]; + } + } + __syncthreads(); + + // 2. load pos/tok/word embeddings and add them toghether + // offset into embeddings is given by wordId * hidden_size + int32_t const poffset = blockIdx.x * ld; + int32_t const outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp threadData(0, 0); + + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + T p(mIdsEmbDev0[poffset + it]); // pos id + T val = p; + int32_t const offset0 = word_id[0] * ld; + val += mIdsEmbDev1[offset0 + it]; + int32_t const offset1 = word_id[1] * ld; + val += mIdsEmbDev2[offset1 + it]; + output[outOffset + it] = val; + + T const rldval = rld * val; + threadData = pairSum(threadData, kvp(rldval, rldval * val)); + } + + // 3. layer norm on the sum + layerNorm(threadData, ld, outOffset, beta, gamma, output); +} + +template +__global__ void embLayerNormKernelHFace_4(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t const* inputIds3, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + T const* mIdsEmbDev3, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + int32_t IdsSize3, + T* output) { + cub::Sum pairSum; + int32_t const s = blockIdx.x; + int32_t const b = blockIdx.y; + int32_t const sumS = inputIds0[b]; + int32_t const s_b = inputIds0[b + 1] - sumS; + if (s >= s_b) { + return; // This CTA has nothing to do + } + T const rld = T(1.f) / T(ld); + int32_t const seqPos = sumS + s; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; } + + if (static_cast(inputIds2)[seqPos] < 0 || + static_cast(inputIds2)[seqPos] >= IdsSize2) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[1] = static_cast(inputIds2)[seqPos]; + } + + if (static_cast(inputIds3)[seqPos] < 0 || + static_cast(inputIds3)[seqPos] >= IdsSize3) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[2] = static_cast(inputIds3)[seqPos]; + } + } + __syncthreads(); + + // 2. load pos/tok/word embeddings and add them toghether + // offset into embeddings is given by wordId * hidden_size + int32_t const poffset = blockIdx.x * ld; + int32_t const outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp threadData(0, 0); + + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + T p(mIdsEmbDev0[poffset + it]); // pos id + T val = p; + int32_t const offset0 = word_id[0] * ld; + val += mIdsEmbDev1[offset0 + it]; + int32_t const offset1 = word_id[1] * ld; + val += mIdsEmbDev2[offset1 + it]; + int32_t const offset2 = word_id[2] * ld; + val += mIdsEmbDev3[offset2 + it]; output[outOffset + it] = val; T const rldval = rld * val; @@ -89,52 +250,233 @@ __global__ void embLayerNormKernelHFace(int32_t ld, // 3. layer norm on the sum layerNorm(threadData, ld, outOffset, beta, gamma, output); } +template +int32_t embSkipLayerNormHFace_2(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + int32_t IdsSize0, + int32_t IdsSize1, + T* output) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); + embLayerNormKernelHFace_2 + <<>>(ld, + inputIds0, + inputIds1, + nbLookupTables, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + IdsSize0, + IdsSize1, + output); + return cudaPeekAtLastError(); +} template -int32_t embSkipLayerNormHFace(cudaStream_t stream, - int32_t ld, - int32_t B, - int32_t S, - int32_t** inputIds, - int32_t const nbLookupTables, - float const* beta, - float const* gamma, - T** mIdsEmbDev, - int32_t* IdsSize, - T* output) { +int32_t embSkipLayerNormHFace_3(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int const* inputIds2, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output) { constexpr int32_t tpb = 256; dim3 const grid(S, B, 1); dim3 const block(tpb, 1, 1); size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); - embLayerNormKernelHFace<<>>( - ld, inputIds, nbLookupTables, beta, gamma, mIdsEmbDev, IdsSize, output); + embLayerNormKernelHFace_3 + <<>>(ld, + inputIds0, + inputIds1, + inputIds2, + nbLookupTables, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + IdsSize0, + IdsSize1, + IdsSize2, + output); return cudaPeekAtLastError(); } -template int32_t embSkipLayerNormHFace(cudaStream_t, - int32_t, - int32_t, - int32_t, - int32_t**, - int32_t const, - float const*, - float const*, - float**, - int32_t*, - float*); - -template int32_t embSkipLayerNormHFace(cudaStream_t, - int32_t, - int32_t, - int32_t, - int32_t**, - int32_t const, - float const*, - float const*, - half**, - int32_t*, - half*); +template +int32_t embSkipLayerNormHFace_4(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int const* inputIds2, + int const* inputIds3, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + T const* mIdsEmbDev3, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + int32_t IdsSize3, + T* output) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); + embLayerNormKernelHFace_4 + <<>>(ld, + inputIds0, + inputIds1, + inputIds2, + inputIds3, + nbLookupTables, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + mIdsEmbDev3, + IdsSize0, + IdsSize1, + IdsSize2, + IdsSize3, + output); + return cudaPeekAtLastError(); +} + +template int32_t embSkipLayerNormHFace_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + float*); + +template int32_t embSkipLayerNormHFace_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + int32_t, + float*); + +template int32_t embSkipLayerNormHFace_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + int32_t, + int32_t, + float*); + +template int32_t embSkipLayerNormHFace_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + int32_t, + int32_t, + half*); + +template int32_t embSkipLayerNormHFace_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + half*); +template int32_t embSkipLayerNormHFace_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + int32_t, + half*); } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu index cd69a1ba37..acdf9cc5a2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu @@ -30,61 +30,135 @@ namespace tensorrt { namespace plugin { template -__global__ void embLayerNormKernelMTron(int32_t ld, - int32_t** inputIds, - int32_t const nbLookupTables, - float const* beta, - float const* gamma, - T** mIdsEmbDev, - int32_t* IdsSize, - T* output, - T* skip) { +__global__ void embLayerNormKernelMTron_2(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + int32_t IdsSize0, + int32_t IdsSize1, + T* output, + T* skip) { cub::Sum pairSum; int32_t const s = blockIdx.x; int32_t const b = blockIdx.y; - int32_t* cuSeqlens = inputIds[0]; - int32_t const sumS = cuSeqlens[b]; - int32_t const s_b = cuSeqlens[b + 1] - sumS; + int32_t const sumS = inputIds0[b]; + int32_t const s_b = inputIds0[b + 1] - sumS; if (s >= s_b) { return; // This CTA has nothing to do } T const rld = T(1.f) / T(ld); - int32_t const seqPos = sumS + s; + const int32_t seqPos = sumS + s; extern __shared__ int32_t word_id[]; if (threadIdx.x == 0) { - for (int i = 1; i < nbLookupTables; ++i) { - if (static_cast(inputIds[i])[seqPos] < 0 || - static_cast(inputIds[i])[seqPos] >= IdsSize[i]) { - printf( - "Error !!!!!!!!!!!!!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot " - "be lookup table: ID < 0 or ID > max "); - return; - } else { - word_id[i - 1] = static_cast(inputIds[i])[seqPos]; - } + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; } } __syncthreads(); // 2. load pos/tok/word embeddings and add them toghether // offset into embeddings is given by wordId * hidden_size - int32_t const poffset = blockIdx.x * ld; - int32_t const outOffset = seqPos * ld; + const int32_t poffset = blockIdx.x * ld; + const int32_t outOffset = seqPos * ld; // the output offset is given by b * (S*hidden_size) + s * hidden_size kvp threadData(0, 0); for (int32_t it = threadIdx.x; it < ld; it += TPB) { - T p(mIdsEmbDev[0][poffset + it]); // pos id + T p(mIdsEmbDev0[poffset + it]); // pos id T val = p; - for (int i = 1; i < nbLookupTables; ++i) { - int32_t const offset = word_id[i - 1] * ld; - val += mIdsEmbDev[i][offset + it]; + const int32_t offset = word_id[0] * ld; + val += mIdsEmbDev1[offset + it]; + output[outOffset + it] = val; + skip[outOffset + it] = val; + + const T rldval = rld * val; + threadData = pairSum(threadData, kvp(rldval, rldval * val)); + } + + // 3. layer norm on the sum + layerNorm(threadData, ld, outOffset, beta, gamma, output); +} + +template +__global__ void embLayerNormKernelMTron_3(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output, + T* skip) { + cub::Sum pairSum; + const int32_t s = blockIdx.x; + const int32_t b = blockIdx.y; + const int32_t sumS = inputIds0[b]; + const int32_t s_b = inputIds0[b + 1] - sumS; + if (s >= s_b) { + return; // This CTA has nothing to do + } + const T rld = T(1.f) / T(ld); + const int32_t seqPos = sumS + s; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; } + + if (static_cast(inputIds2)[seqPos] < 0 || + static_cast(inputIds2)[seqPos] >= IdsSize2) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[1] = static_cast(inputIds2)[seqPos]; + } + } + __syncthreads(); + + // 2. load pos/tok/word embeddings and add them toghether + // offset into embeddings is given by wordId * hidden_size + const int32_t poffset = blockIdx.x * ld; + const int32_t outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp threadData(0, 0); + + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + T p(mIdsEmbDev0[poffset + it]); // pos id + T val = p; + const int32_t offset0 = word_id[0] * ld; + val += mIdsEmbDev1[offset0 + it]; + const int32_t offset1 = word_id[1] * ld; + val += mIdsEmbDev2[offset1 + it]; output[outOffset + it] = val; skip[outOffset + it] = val; - T const rldval = rld * val; + const T rldval = rld * val; threadData = pairSum(threadData, kvp(rldval, rldval * val)); } @@ -92,61 +166,335 @@ __global__ void embLayerNormKernelMTron(int32_t ld, layerNorm(threadData, ld, outOffset, beta, gamma, output); } +template +__global__ void embLayerNormKernelMTron_4(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t const* inputIds3, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + T const* mIdsEmbDev3, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + int32_t IdsSize3, + T* output, + T* skip) { + cub::Sum pairSum; + const int32_t s = blockIdx.x; + const int32_t b = blockIdx.y; + const int32_t sumS = inputIds0[b]; + const int32_t s_b = inputIds0[b + 1] - sumS; + if (s >= s_b) { + return; // This CTA has nothing to do + } + const T rld = T(1.f) / T(ld); + const int32_t seqPos = sumS + s; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; + } + + if (static_cast(inputIds2)[seqPos] < 0 || + static_cast(inputIds2)[seqPos] >= IdsSize2) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[1] = static_cast(inputIds2)[seqPos]; + } + + if (static_cast(inputIds3)[seqPos] < 0 || + static_cast(inputIds3)[seqPos] >= IdsSize3) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[2] = static_cast(inputIds3)[seqPos]; + } + } + __syncthreads(); + + // 2. load pos/tok/word embeddings and add them toghether + // offset into embeddings is given by wordId * hidden_size + const int32_t poffset = blockIdx.x * ld; + const int32_t outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp threadData(0, 0); + + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + T p(mIdsEmbDev0[poffset + it]); // pos id + T val = p; + const int32_t offset0 = word_id[0] * ld; + val += mIdsEmbDev1[offset0 + it]; + const int32_t offset1 = word_id[1] * ld; + val += mIdsEmbDev2[offset1 + it]; + const int32_t offset2 = word_id[2] * ld; + val += mIdsEmbDev3[offset2 + it]; + output[outOffset + it] = val; + skip[outOffset + it] = val; + + const T rldval = rld * val; + threadData = pairSum(threadData, kvp(rldval, rldval * val)); + } + + // 3. layer norm on the sum + layerNorm(threadData, ld, outOffset, beta, gamma, output); +} template -int32_t embSkipLayerNormMTron(cudaStream_t stream, - int32_t ld, - int32_t B, - int32_t S, - int32_t** inputIds, - int32_t const nbLookupTables, - float const* beta, - float const* gamma, - T** mIdsEmbDev, - int32_t* IdsSize, - T* output, - T* skip) { +int32_t embSkipLayerNormMTron_2(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + int32_t IdsSize0, + int32_t IdsSize1, + T* output, + T* skip) { constexpr int32_t tpb = 256; dim3 const grid(S, B, 1); dim3 const block(tpb, 1, 1); size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); - embLayerNormKernelMTron + embLayerNormKernelMTron_2 <<>>(ld, - inputIds, + inputIds0, + inputIds1, nbLookupTables, beta, gamma, - mIdsEmbDev, - IdsSize, + mIdsEmbDev0, + mIdsEmbDev1, + IdsSize0, + IdsSize1, output, skip); return cudaPeekAtLastError(); } -template int32_t embSkipLayerNormMTron(cudaStream_t, - int32_t, - int32_t, - int32_t, - int32_t**, - int32_t const, - float const*, - float const*, - float**, - int32_t*, - float*, - float*); - -template int32_t embSkipLayerNormMTron(cudaStream_t, - int32_t, - int32_t, - int32_t, - int32_t**, - int32_t const, - float const*, - float const*, - half**, - int32_t*, - half*, - half*); +template +int32_t embSkipLayerNormMTron_3(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output, + T* skip) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); + embLayerNormKernelMTron_3 + <<>>(ld, + inputIds0, + inputIds1, + inputIds2, + nbLookupTables, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + IdsSize0, + IdsSize1, + IdsSize2, + output, + skip); + return cudaPeekAtLastError(); +} + +template +int32_t embSkipLayerNormMTron_4(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t const* inputIds3, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + T const* mIdsEmbDev3, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + int32_t IdsSize3, + T* output, + T* skip) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); + embLayerNormKernelMTron_4 + <<>>(ld, + inputIds0, + inputIds1, + inputIds2, + inputIds3, + nbLookupTables, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + mIdsEmbDev3, + IdsSize0, + IdsSize1, + IdsSize2, + IdsSize3, + output, + skip); + return cudaPeekAtLastError(); +} + +template int32_t embSkipLayerNormMTron_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + float*, + float*); + +template int32_t embSkipLayerNormMTron_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + int32_t, + float*, + float*); + +template int32_t embSkipLayerNormMTron_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + int32_t, + int32_t, + float*, + float*); + +template int32_t embSkipLayerNormMTron_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + int32_t, + int32_t, + half*, + half*); + +template int32_t embSkipLayerNormMTron_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + half*, + half*); + +template int32_t embSkipLayerNormMTron_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + int32_t, + half*, + half*); } // namespace plugin } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu index 9601f97f7d..6a8b39d113 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu @@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24; constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256; constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; -char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"}; -char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"}; +char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"}; +char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"}; char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"}; // Static class fields initialization nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{}; @@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase( tem_weight.values, getWeightsSize(tem_weight, mType), cudaMemcpyHostToDevice)); - mIdsEmbDev.push_back(cudaMem); + mIdsEmbPtrs.push_back(cudaMem); } } @@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase( : mLayerName(name), mGammaDev(nullptr), mBetaDev(nullptr), - mIdsEmbDev{}, + mIdsEmbPtrs{}, mIdsEmb_{} { // Deserialize in the same order as serialization deserialize_value(&data, &length, &mType); @@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron( // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone() const noexcept { - TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron clone"); - auto p = new EmbLayerNormVarSeqlenPluginMTron( + TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace clone"); + auto p = new EmbLayerNormVarSeqlenPluginHFace( mLayerName, mType, mBeta, mGamma, mIdsEmb_); p->setPluginNamespace(mNamespace.c_str()); return p; @@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue( void* const* outputs, void* workspace, cudaStream_t stream) noexcept { - int32_t const batchSize = inputDesc[0].dims.d[0] - 1; + int32_t batchSize = inputDesc[0].dims.d[0] - 1; // read out the maximum sequence length from the dummy input int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0]; int32_t S = 384; @@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue( } const float* beta = mBetaDev.get(); const float* gamma = mGammaDev.get(); - int32_t** tem_inputs_ptr_dev; - cudaMalloc(reinterpret_cast(&tem_inputs_ptr_dev), - sizeof(void*) * nbLookupTables_); - cudaMemcpy(tem_inputs_ptr_dev, - inputs, - sizeof(void*) * nbLookupTables_, - cudaMemcpyHostToDevice); - int32_t* mIdsVocabSize_dev; - cudaMalloc(reinterpret_cast(&mIdsVocabSize_dev), - sizeof(int32_t) * mIdsVocabSize.size()); - cudaMemcpy(mIdsVocabSize_dev, - &(mIdsVocabSize[0]), - sizeof(int32_t) * mIdsVocabSize.size(), - cudaMemcpyHostToDevice); if (mType == nvinfer1::DataType::kFLOAT) { auto output = static_cast(outputs[0]); - float** mIdsEmbDev_float; - cudaMalloc(reinterpret_cast(&mIdsEmbDev_float), - sizeof(void*) * nbLookupTables_); - cudaMemcpy(mIdsEmbDev_float, - &(mIdsEmbDev[0]), - sizeof(void*) * nbLookupTables_, - cudaMemcpyHostToDevice); - return embSkipLayerNormHFace(stream, - static_cast(mLd), - batchSize, - S, - tem_inputs_ptr_dev, - nbLookupTables_, - beta, - gamma, - mIdsEmbDev_float, - mIdsVocabSize_dev, - output); + if (nbLookupTables_ == 2) { + return embSkipLayerNormHFace_2( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + mIdsVocabSize[0], + mIdsVocabSize[1], + output); + } else if (nbLookupTables_ == 3) { + return embSkipLayerNormHFace_3( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output); + } else if (nbLookupTables_ == 4) { + return embSkipLayerNormHFace_4( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + static_cast(inputs[3]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + static_cast(mIdsEmbPtrs[3]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + mIdsVocabSize[3], + output); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support 2,3,4 lookup_tables fused ")); + } } else if (mType == nvinfer1::DataType::kHALF) { auto output = static_cast(outputs[0]); - half** mIdsEmbDev_half; - cudaMalloc(reinterpret_cast(&mIdsEmbDev_half), - sizeof(void*) * nbLookupTables_); - cudaMemcpy(mIdsEmbDev_half, - &(mIdsEmbDev[0]), - sizeof(void*) * nbLookupTables_, - cudaMemcpyHostToDevice); - return embSkipLayerNormHFace(stream, - static_cast(mLd), - batchSize, - S, - tem_inputs_ptr_dev, - nbLookupTables_, - beta, - gamma, - mIdsEmbDev_half, - mIdsVocabSize_dev, - output); + if (nbLookupTables_ == 2) { + return embSkipLayerNormHFace_2( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + mIdsVocabSize[0], + mIdsVocabSize[1], + output); + } else if (nbLookupTables_ == 3) { + return embSkipLayerNormHFace_3( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output); + } else if (nbLookupTables_ == 4) { + return embSkipLayerNormHFace_4( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + static_cast(inputs[3]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + static_cast(mIdsEmbPtrs[3]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + mIdsVocabSize[3], + output); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support 2,3,4 lookup_tables fused ")); + } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Unsupported type error, expected [kHALF,kFLOAT]")); @@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue( void* const* outputs, void* workspace, cudaStream_t stream) noexcept { - int32_t const batchSize = inputDesc[0].dims.d[0] - 1; + int32_t batchSize = inputDesc[0].dims.d[0] - 1; // read out the maximum sequence length from the dummy input int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0]; int32_t S = 384; @@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue( } const float* beta = mBetaDev.get(); const float* gamma = mGammaDev.get(); - int32_t** tem_inputs_ptr_dev; - cudaMalloc(reinterpret_cast(&tem_inputs_ptr_dev), - sizeof(void*) * nbLookupTables_); - cudaMemcpy(tem_inputs_ptr_dev, - inputs, - sizeof(void*) * nbLookupTables_, - cudaMemcpyHostToDevice); - int32_t* mIdsVocabSize_dev; - cudaMalloc(reinterpret_cast(&mIdsVocabSize_dev), - sizeof(int32_t) * mIdsVocabSize.size()); - cudaMemcpy(mIdsVocabSize_dev, - &(mIdsVocabSize[0]), - sizeof(int32_t) * mIdsVocabSize.size(), - cudaMemcpyHostToDevice); + if (mType == nvinfer1::DataType::kFLOAT) { auto output = static_cast(outputs[0]); auto skip = static_cast(outputs[1]); - float** mIdsEmbDev_float; - cudaMalloc(reinterpret_cast(&mIdsEmbDev_float), - sizeof(void*) * nbLookupTables_); - cudaMemcpy(mIdsEmbDev_float, - &(mIdsEmbDev[0]), - sizeof(void*) * nbLookupTables_, - cudaMemcpyHostToDevice); - return embSkipLayerNormMTron(stream, - static_cast(mLd), - batchSize, - S, - tem_inputs_ptr_dev, - nbLookupTables_, - beta, - gamma, - mIdsEmbDev_float, - mIdsVocabSize_dev, - output, - skip); + if (nbLookupTables_ == 2) { + return embSkipLayerNormMTron_2( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + mIdsVocabSize[0], + mIdsVocabSize[1], + output, + skip); + } else if (nbLookupTables_ == 3) { + return embSkipLayerNormMTron_3( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output, + skip); + } else if (nbLookupTables_ == 4) { + return embSkipLayerNormMTron_4( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + static_cast(inputs[3]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + static_cast(mIdsEmbPtrs[3]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + mIdsVocabSize[3], + output, + skip); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support 2,3,4 lookup_tables fused ")); + } } else if (mType == nvinfer1::DataType::kHALF) { auto output = static_cast(outputs[0]); auto skip = static_cast(outputs[1]); - half** mIdsEmbDev_half; - cudaMalloc(reinterpret_cast(&mIdsEmbDev_half), - sizeof(void*) * nbLookupTables_); - cudaMemcpy(mIdsEmbDev_half, - &(mIdsEmbDev[0]), - sizeof(void*) * nbLookupTables_, - cudaMemcpyHostToDevice); - return embSkipLayerNormMTron(stream, - static_cast(mLd), - batchSize, - S, - tem_inputs_ptr_dev, - nbLookupTables_, - beta, - gamma, - mIdsEmbDev_half, - mIdsVocabSize_dev, - output, - skip); + if (nbLookupTables_ == 2) { + return embSkipLayerNormMTron_2( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + mIdsVocabSize[0], + mIdsVocabSize[1], + output, + skip); + } else if (nbLookupTables_ == 3) { + return embSkipLayerNormMTron_3( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output, + skip); + } else if (nbLookupTables_ == 4) { + return embSkipLayerNormMTron_4( + stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + static_cast(inputs[3]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + static_cast(mIdsEmbPtrs[3]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + mIdsVocabSize[3], + output, + skip); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support 2,3,4 lookup_tables fused ")); + } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Unsupported type error, expected [kHALF,kFLOAT]")); @@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept { size_t const wordSize = getElementSize(mType); serFromDev(&d, mBetaDev.get(), mLd); serFromDev(&d, mGammaDev.get(), mLd); - for (size_t i = 0; i < mIdsEmbDev.size(); ++i) { + for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) { serFromDev(&d, - static_cast(mIdsEmbDev[i]), + static_cast(mIdsEmbPtrs[i]), mLd * mIdsVocabSize[i] * wordSize); } } @@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept { // This gets called when the network containing plugin is destroyed mBetaDev.reset(nullptr); mGammaDev.reset(nullptr); - for (size_t i = 0; i < mIdsEmbDev.size(); ++i) { - cudaFree(mIdsEmbDev[i]); + for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) { + cudaFree(mIdsEmbPtrs[i]); } delete this; } @@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin( beta, gamma, IdsEmb); - return p; } diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h index 2886f800a6..944f3abb9d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h @@ -31,32 +31,121 @@ namespace tensorrt { namespace plugin { template -int32_t embSkipLayerNormHFace(cudaStream_t stream, - int32_t ld, - int32_t B, - int32_t S, - int32_t** inputIds, - int32_t const nbLookupTables, - float const* beta, - float const* gamma, - T** idsEmb, - int32_t*, - T* output); +int32_t embSkipLayerNormHFace_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + int32_t, + int32_t, + T*); template -int32_t embSkipLayerNormMTron(cudaStream_t stream, - int32_t ld, - int32_t B, - int32_t S, - int32_t** inputIds, - int32_t const nbLookupTables, - float const* beta, - float const* gamma, - T** idsEmb, - int32_t*, - T* output, - T* skip); +int32_t embSkipLayerNormHFace_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + T const*, + int32_t, + int32_t, + int32_t, + T*); +template +int32_t embSkipLayerNormHFace_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + T const*, + T const*, + int32_t, + int32_t, + int32_t, + int32_t, + T*); + +template +int32_t embSkipLayerNormMTron_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + int32_t, + int32_t, + T*, + T*); + +template +int32_t embSkipLayerNormMTron_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + T const*, + int32_t, + int32_t, + int32_t, + T*, + T*); + +template +int32_t embSkipLayerNormMTron_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + T const*, + T const*, + int32_t, + int32_t, + int32_t, + int32_t, + T*, + T*); class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt { public: EmbLayerNormVarSeqlenPluginBase( @@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt { std::string mNamespace; cuda_unique_ptr mGammaDev; cuda_unique_ptr mBetaDev; - std::vector mIdsEmbDev; + std::vector mIdsEmbPtrs; + // std::vector mIdsEmbDev; size_t mLd; // leading dim = hidden size std::vector mIdsVocabSize; WeightsWithOwnership mBeta; -- GitLab