Unverified commit 6512e087, authored by W Wangzheee, committed by GitHub

[Paddle Inference] fix embedding fused (#46789)

* fix embedding fused
Parent ae6b4713
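This commit reworks the `ManyEmbLayerNormPluginDynamic` plugin so that kernels no longer take device-resident arrays of pointers (`int32_t** inputIds`, `T** mIdsEmbDev`); instead, dedicated kernels for 2, 3, and 4 lookup tables take each input-id pointer, embedding-table pointer, and vocabulary size as a plain argument, and the host dispatches on the table count. A hypothetical sketch of that dispatch (illustrative names, not the exact Paddle call sites):

```cpp
// One specialized launch per supported table count, replacing a single
// kernel that dereferenced a device array of pointers.
int DispatchFusedEmbedding(int num_tables) {
  switch (num_tables) {
    case 2: /* embSkipLayerNormHFace_2<T>(stream, ...) */ return 0;
    case 3: /* embSkipLayerNormHFace_3<T>(stream, ...) */ return 0;
    case 4: /* embSkipLayerNormHFace_4<T>(stream, ...) */ return 0;
    default: return -1;  // only 2/3/4 fused lookup tables are supported
  }
}
```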
@@ -210,14 +210,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "2");
"ManyEmbLayerNormPluginDynamic", "1");
auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V2(Output: " +
plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
free(plugin_ptr);
@@ -248,7 +248,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer,
"ManyEmbLayerNormPluginDynamic_V2",
"ManyEmbLayerNormPluginDynamic_V1",
{output_name, std::string("qkv_plugin_mask")},
test_mode);
}
......
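The version renumbering above ("2" becomes "1" for the HFace variant, and "3" becomes "2" for MTron below) only works because the converter and the plugin registration change in lockstep: `getPluginCreator` matches on the exact (name, version) string pair. A sketch against the standard TensorRT registry API, assuming Paddle's `GetPluginRegistry()` wraps the same registry:

```cpp
#include <NvInferRuntime.h>

nvinfer1::IPluginCreator* FindEmbCreator() {
  // Must request the version the plugin now registers ("1" after this
  // commit); any name/version mismatch returns nullptr.
  return getPluginRegistry()->getPluginCreator(
      "ManyEmbLayerNormPluginDynamic", "1");
}
```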
@@ -194,7 +194,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
"max_seqlen_tensor")); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"ManyEmbLayerNormPluginDynamic", "3");
"ManyEmbLayerNormPluginDynamic", "2");
auto plugin_obj =
creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr);
......
@@ -30,20 +30,22 @@ namespace tensorrt {
namespace plugin {
template <typename T, unsigned TPB>
__global__ void embLayerNormKernelHFace(int32_t ld,
int32_t** inputIds,
int32_t const nbLookupTables,
__global__ void embLayerNormKernelHFace_2(int32_t ld,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T** mIdsEmbDev,
int32_t* IdsSize,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
int32_t IdsSize0,
int32_t IdsSize1,
T* output) {
cub::Sum pairSum;
int32_t const s = blockIdx.x;
int32_t const b = blockIdx.y;
int32_t* cuSeqlens = inputIds[0];
int32_t const sumS = cuSeqlens[b];
int32_t const s_b = cuSeqlens[b + 1] - sumS;
int32_t const sumS = inputIds0[b];
int32_t const s_b = inputIds0[b + 1] - sumS;
if (s >= s_b) {
return; // This CTA has nothing to do
}
@@ -52,17 +54,87 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
extern __shared__ int32_t word_id[];
if (threadIdx.x == 0) {
for (int i = 1; i < nbLookupTables; ++i) {
if (static_cast<int32_t const*>(inputIds[i])[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds[i])[seqPos] >= IdsSize[i]) {
if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[i - 1] = static_cast<int32_t const*>(inputIds[i])[seqPos];
word_id[0] = static_cast<int32_t const*>(inputIds1)[seqPos];
}
}
__syncthreads();
// 2. load pos/tok/word embeddings and add them together
// offset into embeddings is given by wordId * hidden_size
int32_t const poffset = blockIdx.x * ld;
int32_t const outOffset = seqPos * ld;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp<T> threadData(0, 0);
for (int32_t it = threadIdx.x; it < ld; it += TPB) {
T p(mIdsEmbDev0[poffset + it]); // pos id
T val = p;
int32_t const offset = word_id[0] * ld;
val += mIdsEmbDev1[offset + it];
output[outOffset + it] = val;
T const rldval = rld * val;
threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
}
// 3. layer norm on the sum
layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
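These kernels assume the variable-sequence-length packing convention in which the first input is a cumulative-sequence-length array: `inputIds0[b]` is the prefix sum of the batch's sequence lengths, so batch `b` owns token positions `[inputIds0[b], inputIds0[b+1])`, and a CTA with `s >= s_b` exits early. A host-side sketch of that convention, with hypothetical data:

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> seq_lens = {3, 5, 2};  // per-batch sequence lengths
  std::vector<int> cu_seqlens = {0};      // prefix sums, a.k.a. cuSeqlens
  for (int len : seq_lens) cu_seqlens.push_back(cu_seqlens.back() + len);
  // cu_seqlens == {0, 3, 8, 10}; a block with blockIdx.y == 1 reads
  // sumS = 3 and s_b = 8 - 3 = 5, and returns early once s >= 5.
  for (size_t b = 0; b + 1 < cu_seqlens.size(); ++b)
    printf("batch %zu: tokens [%d, %d)\n", b, cu_seqlens[b],
           cu_seqlens[b + 1]);
  return 0;
}
```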
template <typename T, unsigned TPB>
__global__ void embLayerNormKernelHFace_3(int32_t ld,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t const* inputIds2,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
T* output) {
cub::Sum pairSum;
int32_t const s = blockIdx.x;
int32_t const b = blockIdx.y;
int32_t const sumS = inputIds0[b];
int32_t const s_b = inputIds0[b + 1] - sumS;
if (s >= s_b) {
return; // This CTA has nothing to do
}
T const rld = T(1.f) / T(ld);
int32_t const seqPos = sumS + s;
extern __shared__ int32_t word_id[];
if (threadIdx.x == 0) {
if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[0] = static_cast<int32_t const*>(inputIds1)[seqPos];
}
if (static_cast<int32_t const*>(inputIds2)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds2)[seqPos] >= IdsSize2) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[1] = static_cast<int32_t const*>(inputIds2)[seqPos];
}
}
__syncthreads();
@@ -74,12 +146,101 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
kvp<T> threadData(0, 0);
for (int32_t it = threadIdx.x; it < ld; it += TPB) {
T p(mIdsEmbDev[0][poffset + it]); // pos id
T p(mIdsEmbDev0[poffset + it]); // pos id
T val = p;
for (int i = 1; i < nbLookupTables; ++i) {
int32_t const offset = word_id[i - 1] * ld;
val += mIdsEmbDev[i][offset + it];
int32_t const offset0 = word_id[0] * ld;
val += mIdsEmbDev1[offset0 + it];
int32_t const offset1 = word_id[1] * ld;
val += mIdsEmbDev2[offset1 + it];
output[outOffset + it] = val;
T const rldval = rld * val;
threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
}
// 3. layer norm on the sum
layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
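Since `rld = 1/ld`, the running pair `threadData` accumulates exactly (Σx/ld, Σx²/ld) = (E[x], E[x²]) across the hidden dimension, which is all the `layerNorm` helper needs. Assuming the standard formulation used by such helpers:

```latex
\mu = \frac{1}{ld}\sum_{i=1}^{ld} x_i, \qquad
\sigma^2 = \frac{1}{ld}\sum_{i=1}^{ld} x_i^2 - \mu^2, \qquad
y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i
```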
template <typename T, unsigned TPB>
__global__ void embLayerNormKernelHFace_4(int32_t ld,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t const* inputIds2,
int32_t const* inputIds3,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
T const* mIdsEmbDev3,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
int32_t IdsSize3,
T* output) {
cub::Sum pairSum;
int32_t const s = blockIdx.x;
int32_t const b = blockIdx.y;
int32_t const sumS = inputIds0[b];
int32_t const s_b = inputIds0[b + 1] - sumS;
if (s >= s_b) {
return; // This CTA has nothing to do
}
T const rld = T(1.f) / T(ld);
int32_t const seqPos = sumS + s;
extern __shared__ int32_t word_id[];
if (threadIdx.x == 0) {
if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[0] = static_cast<int32_t const*>(inputIds1)[seqPos];
}
if (static_cast<int32_t const*>(inputIds2)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds2)[seqPos] >= IdsSize2) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[1] = static_cast<int32_t const*>(inputIds2)[seqPos];
}
if (static_cast<int32_t const*>(inputIds3)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds3)[seqPos] >= IdsSize3) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[2] = static_cast<int32_t const*>(inputIds3)[seqPos];
}
}
__syncthreads();
// 2. load pos/tok/word embeddings and add them together
// offset into embeddings is given by wordId * hidden_size
int32_t const poffset = blockIdx.x * ld;
int32_t const outOffset = seqPos * ld;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp<T> threadData(0, 0);
for (int32_t it = threadIdx.x; it < ld; it += TPB) {
T p(mIdsEmbDev0[poffset + it]); // pos id
T val = p;
int32_t const offset0 = word_id[0] * ld;
val += mIdsEmbDev1[offset0 + it];
int32_t const offset1 = word_id[1] * ld;
val += mIdsEmbDev2[offset1 + it];
int32_t const offset2 = word_id[2] * ld;
val += mIdsEmbDev3[offset2 + it];
output[outOffset + it] = val;
T const rldval = rld * val;
@@ -89,52 +250,233 @@ __global__ void embLayerNormKernelHFace(int32_t ld,
// 3. layer norm on the sum
layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
template <typename T>
int32_t embSkipLayerNormHFace_2(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int const* inputIds0,
int const* inputIds1,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
int32_t IdsSize0,
int32_t IdsSize1,
T* output) {
constexpr int32_t tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1);
embLayerNormKernelHFace_2<T, tpb>
<<<grid, block, cache_size, stream>>>(ld,
inputIds0,
inputIds1,
nbLookupTables,
beta,
gamma,
mIdsEmbDev0,
mIdsEmbDev1,
IdsSize0,
IdsSize1,
output);
return cudaPeekAtLastError();
}
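Every launcher in this file shares one geometry: `grid(S, B)` gives each (token position, batch) pair its own CTA, `tpb = 256` threads stride across the hidden size `ld`, and dynamic shared memory holds one word id per lookup table beyond the cuSeqlens input, hence `sizeof(int32_t) * (nbLookupTables - 1)`. A self-contained sketch of the same pattern with an illustrative kernel:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

template <unsigned TPB>
__global__ void fill_rows(float* out, int ld) {
  extern __shared__ int word_id[];  // sized at launch: (num_tables - 1)
  // one block per (s, b); TPB threads cooperatively cover ld elements
  int row = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
  for (int it = threadIdx.x; it < ld; it += TPB) out[row + it] = 1.f;
}

int main() {
  const int S = 4, B = 2, ld = 768;
  float* out;
  cudaMalloc(&out, sizeof(float) * S * B * ld);
  dim3 grid(S, B, 1), block(256, 1, 1);
  size_t smem = sizeof(int) * 1;    // e.g. 2 lookup tables -> 1 word id
  fill_rows<256><<<grid, block, smem>>>(out, ld);
  printf("launch status: %d\n", static_cast<int>(cudaPeekAtLastError()));
  cudaFree(out);
  return 0;
}
```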
template <typename T>
int32_t embSkipLayerNormHFace_3(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int const* inputIds0,
int const* inputIds1,
int const* inputIds2,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
T* output) {
constexpr int32_t tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1);
embLayerNormKernelHFace_3<T, tpb>
<<<grid, block, cache_size, stream>>>(ld,
inputIds0,
inputIds1,
inputIds2,
nbLookupTables,
beta,
gamma,
mIdsEmbDev0,
mIdsEmbDev1,
mIdsEmbDev2,
IdsSize0,
IdsSize1,
IdsSize2,
output);
return cudaPeekAtLastError();
}
template <typename T>
int32_t embSkipLayerNormHFace(cudaStream_t stream,
int32_t embSkipLayerNormHFace_4(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t** inputIds,
int32_t const nbLookupTables,
int const* inputIds0,
int const* inputIds1,
int const* inputIds2,
int const* inputIds3,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T** mIdsEmbDev,
int32_t* IdsSize,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
T const* mIdsEmbDev3,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
int32_t IdsSize3,
T* output) {
constexpr int32_t tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1);
embLayerNormKernelHFace<T, tpb><<<grid, block, cache_size, stream>>>(
ld, inputIds, nbLookupTables, beta, gamma, mIdsEmbDev, IdsSize, output);
embLayerNormKernelHFace_4<T, tpb>
<<<grid, block, cache_size, stream>>>(ld,
inputIds0,
inputIds1,
inputIds2,
inputIds3,
nbLookupTables,
beta,
gamma,
mIdsEmbDev0,
mIdsEmbDev1,
mIdsEmbDev2,
mIdsEmbDev3,
IdsSize0,
IdsSize1,
IdsSize2,
IdsSize3,
output);
return cudaPeekAtLastError();
}
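The explicit `template int32_t embSkipLayerNormHFace_*<float>/<half>(...)` instantiations that follow are needed because the template bodies live in this `.cu` file while callers see only declarations; without them the linker would report the `float` and `half` symbols as unresolved. A minimal sketch of the mechanism:

```cpp
// Translation unit A defines the template and forces code generation:
template <typename T> T Twice(T v) { return v + v; }
template float Twice<float>(float);  // symbol emitted for float here
// Translation unit B only declares `template <typename T> T Twice(T);`
// and links against the instantiation emitted in A.
```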
template int32_t embSkipLayerNormHFace<float>(cudaStream_t,
template int32_t embSkipLayerNormHFace_2<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
float*);
template int32_t embSkipLayerNormHFace_3<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t**,
int32_t const,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
int32_t,
float*);
template int32_t embSkipLayerNormHFace_4<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
float const*,
float**,
int32_t*,
float const*,
int32_t,
int32_t,
int32_t,
int32_t,
float*);
template int32_t embSkipLayerNormHFace<half>(cudaStream_t,
template int32_t embSkipLayerNormHFace_2<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
int32_t,
int32_t,
half*);
template int32_t embSkipLayerNormHFace_3<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t**,
int32_t const,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half**,
int32_t*,
half const*,
half const*,
half const*,
int32_t,
int32_t,
int32_t,
half*);
template int32_t embSkipLayerNormHFace_4<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
half const*,
half const*,
int32_t,
int32_t,
int32_t,
int32_t,
half*);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
@@ -30,121 +30,469 @@ namespace tensorrt {
namespace plugin {
template <typename T, unsigned TPB>
__global__ void embLayerNormKernelMTron(int32_t ld,
int32_t** inputIds,
int32_t const nbLookupTables,
__global__ void embLayerNormKernelMTron_2(int32_t ld,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T** mIdsEmbDev,
int32_t* IdsSize,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
int32_t IdsSize0,
int32_t IdsSize1,
T* output,
T* skip) {
cub::Sum pairSum;
int32_t const s = blockIdx.x;
int32_t const b = blockIdx.y;
int32_t* cuSeqlens = inputIds[0];
int32_t const sumS = cuSeqlens[b];
int32_t const s_b = cuSeqlens[b + 1] - sumS;
int32_t const sumS = inputIds0[b];
int32_t const s_b = inputIds0[b + 1] - sumS;
if (s >= s_b) {
return; // This CTA has nothing to do
}
T const rld = T(1.f) / T(ld);
int32_t const seqPos = sumS + s;
const int32_t seqPos = sumS + s;
extern __shared__ int32_t word_id[];
if (threadIdx.x == 0) {
for (int i = 1; i < nbLookupTables; ++i) {
if (static_cast<int32_t const*>(inputIds[i])[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds[i])[seqPos] >= IdsSize[i]) {
if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) {
printf(
"Error !!!!!!!!!!!!!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot "
"be lookup table: ID < 0 or ID > max ");
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[i - 1] = static_cast<int32_t const*>(inputIds[i])[seqPos];
word_id[0] = static_cast<int32_t const*>(inputIds1)[seqPos];
}
}
__syncthreads();
// 2. load pos/tok/word embeddings and add them together
// offset into embeddings is given by wordId * hidden_size
const int32_t poffset = blockIdx.x * ld;
const int32_t outOffset = seqPos * ld;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp<T> threadData(0, 0);
for (int32_t it = threadIdx.x; it < ld; it += TPB) {
T p(mIdsEmbDev0[poffset + it]); // pos id
T val = p;
const int32_t offset = word_id[0] * ld;
val += mIdsEmbDev1[offset + it];
output[outOffset + it] = val;
skip[outOffset + it] = val;
const T rldval = rld * val;
threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
}
// 3. layer norm on the sum
layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
template <typename T, unsigned TPB>
__global__ void embLayerNormKernelMTron_3(int32_t ld,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t const* inputIds2,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
T* output,
T* skip) {
cub::Sum pairSum;
const int32_t s = blockIdx.x;
const int32_t b = blockIdx.y;
const int32_t sumS = inputIds0[b];
const int32_t s_b = inputIds0[b + 1] - sumS;
if (s >= s_b) {
return; // This CTA has nothing to do
}
const T rld = T(1.f) / T(ld);
const int32_t seqPos = sumS + s;
extern __shared__ int32_t word_id[];
if (threadIdx.x == 0) {
if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[0] = static_cast<int32_t const*>(inputIds1)[seqPos];
}
if (static_cast<int32_t const*>(inputIds2)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds2)[seqPos] >= IdsSize2) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[1] = static_cast<int32_t const*>(inputIds2)[seqPos];
}
}
__syncthreads();
// 2. load pos/tok/word embeddings and add them together
// offset into embeddings is given by wordId * hidden_size
int32_t const poffset = blockIdx.x * ld;
int32_t const outOffset = seqPos * ld;
const int32_t poffset = blockIdx.x * ld;
const int32_t outOffset = seqPos * ld;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp<T> threadData(0, 0);
for (int32_t it = threadIdx.x; it < ld; it += TPB) {
T p(mIdsEmbDev[0][poffset + it]); // pos id
T p(mIdsEmbDev0[poffset + it]); // pos id
T val = p;
for (int i = 1; i < nbLookupTables; ++i) {
int32_t const offset = word_id[i - 1] * ld;
val += mIdsEmbDev[i][offset + it];
const int32_t offset0 = word_id[0] * ld;
val += mIdsEmbDev1[offset0 + it];
const int32_t offset1 = word_id[1] * ld;
val += mIdsEmbDev2[offset1 + it];
output[outOffset + it] = val;
skip[outOffset + it] = val;
const T rldval = rld * val;
threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
}
// 3. layer norm on the sum
layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
template <typename T, unsigned TPB>
__global__ void embLayerNormKernelMTron_4(int32_t ld,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t const* inputIds2,
int32_t const* inputIds3,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
T const* mIdsEmbDev3,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
int32_t IdsSize3,
T* output,
T* skip) {
cub::Sum pairSum;
const int32_t s = blockIdx.x;
const int32_t b = blockIdx.y;
const int32_t sumS = inputIds0[b];
const int32_t s_b = inputIds0[b + 1] - sumS;
if (s >= s_b) {
return; // This CTA has nothing to do
}
const T rld = T(1.f) / T(ld);
const int32_t seqPos = sumS + s;
extern __shared__ int32_t word_id[];
if (threadIdx.x == 0) {
if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[0] = static_cast<int32_t const*>(inputIds1)[seqPos];
}
if (static_cast<int32_t const*>(inputIds2)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds2)[seqPos] >= IdsSize2) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[1] = static_cast<int32_t const*>(inputIds2)[seqPos];
}
if (static_cast<int32_t const*>(inputIds3)[seqPos] < 0 ||
static_cast<int32_t const*>(inputIds3)[seqPos] >= IdsSize3) {
printf(
"Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup "
"table: ID < 0 or ID > max ");
return;
} else {
word_id[2] = static_cast<int32_t const*>(inputIds3)[seqPos];
}
}
__syncthreads();
// 2. load pos/tok/word embeddings and add them together
// offset into embeddings is given by wordId * hidden_size
const int32_t poffset = blockIdx.x * ld;
const int32_t outOffset = seqPos * ld;
// the output offset is given by b * (S*hidden_size) + s * hidden_size
kvp<T> threadData(0, 0);
for (int32_t it = threadIdx.x; it < ld; it += TPB) {
T p(mIdsEmbDev0[poffset + it]); // pos id
T val = p;
const int32_t offset0 = word_id[0] * ld;
val += mIdsEmbDev1[offset0 + it];
const int32_t offset1 = word_id[1] * ld;
val += mIdsEmbDev2[offset1 + it];
const int32_t offset2 = word_id[2] * ld;
val += mIdsEmbDev3[offset2 + it];
output[outOffset + it] = val;
skip[outOffset + it] = val;
T const rldval = rld * val;
const T rldval = rld * val;
threadData = pairSum(threadData, kvp<T>(rldval, rldval * val));
}
// 3. layer norm on the sum
layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output);
}
template <typename T>
int32_t embSkipLayerNormMTron_2(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
int32_t IdsSize0,
int32_t IdsSize1,
T* output,
T* skip) {
constexpr int32_t tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1);
embLayerNormKernelMTron_2<T, tpb>
<<<grid, block, cache_size, stream>>>(ld,
inputIds0,
inputIds1,
nbLookupTables,
beta,
gamma,
mIdsEmbDev0,
mIdsEmbDev1,
IdsSize0,
IdsSize1,
output,
skip);
return cudaPeekAtLastError();
}
template <typename T>
int32_t embSkipLayerNormMTron_3(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t const* inputIds2,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
T* output,
T* skip) {
constexpr int32_t tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1);
embLayerNormKernelMTron_3<T, tpb>
<<<grid, block, cache_size, stream>>>(ld,
inputIds0,
inputIds1,
inputIds2,
nbLookupTables,
beta,
gamma,
mIdsEmbDev0,
mIdsEmbDev1,
mIdsEmbDev2,
IdsSize0,
IdsSize1,
IdsSize2,
output,
skip);
return cudaPeekAtLastError();
}
template <typename T>
int32_t embSkipLayerNormMTron(cudaStream_t stream,
int32_t embSkipLayerNormMTron_4(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t** inputIds,
int32_t const nbLookupTables,
int32_t const* inputIds0,
int32_t const* inputIds1,
int32_t const* inputIds2,
int32_t const* inputIds3,
int32_t nbLookupTables,
float const* beta,
float const* gamma,
T** mIdsEmbDev,
int32_t* IdsSize,
T const* mIdsEmbDev0,
T const* mIdsEmbDev1,
T const* mIdsEmbDev2,
T const* mIdsEmbDev3,
int32_t IdsSize0,
int32_t IdsSize1,
int32_t IdsSize2,
int32_t IdsSize3,
T* output,
T* skip) {
constexpr int32_t tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1);
embLayerNormKernelMTron<T, tpb>
embLayerNormKernelMTron_4<T, tpb>
<<<grid, block, cache_size, stream>>>(ld,
inputIds,
inputIds0,
inputIds1,
inputIds2,
inputIds3,
nbLookupTables,
beta,
gamma,
mIdsEmbDev,
IdsSize,
mIdsEmbDev0,
mIdsEmbDev1,
mIdsEmbDev2,
mIdsEmbDev3,
IdsSize0,
IdsSize1,
IdsSize2,
IdsSize3,
output,
skip);
return cudaPeekAtLastError();
}
template int32_t embSkipLayerNormMTron<float>(cudaStream_t,
template int32_t embSkipLayerNormMTron_2<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t**,
int32_t const,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
float**,
int32_t*,
int32_t,
int32_t,
float*,
float*);
template int32_t embSkipLayerNormMTron<half>(cudaStream_t,
template int32_t embSkipLayerNormMTron_3<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
int32_t**,
int32_t const,
float const*,
float const*,
half**,
int32_t*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
int32_t,
float*,
float*);
template int32_t embSkipLayerNormMTron_4<float>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
float const*,
float const*,
float const*,
float const*,
int32_t,
int32_t,
int32_t,
int32_t,
float*,
float*);
template int32_t embSkipLayerNormMTron_2<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
int32_t,
int32_t,
half*,
half*);
template int32_t embSkipLayerNormMTron_3<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
half const*,
int32_t,
int32_t,
int32_t,
half*,
half*);
template int32_t embSkipLayerNormMTron_4<half>(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
half const*,
half const*,
half const*,
half const*,
int32_t,
int32_t,
int32_t,
int32_t,
half*,
half*);
......
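The MTron kernels track the HFace ones line for line except for one difference threaded through this whole file: each variant also writes the raw embedding sum to a second `skip` output before layer normalization, preserving the un-normalized residual for a later layer. A reduced, illustrative sketch of that dual write:

```cuda
// Illustrative only: keep the pre-norm sum alongside the value that the
// subsequent layerNorm step normalizes in place.
__global__ void emb_sum_with_skip(const float* emb, float* output,
                                  float* skip, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  float val = emb[i];  // stand-in for the pos + word embedding sum
  output[i] = val;     // later normalized by layerNorm
  skip[i] = val;       // un-normalized copy for the residual path
}
```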
......@@ -37,8 +37,8 @@ constexpr size_t xmmasM384 = 24;
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"2"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"3"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"};
// Static class fields initialization
nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
@@ -74,7 +74,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
tem_weight.values,
getWeightsSize(tem_weight, mType),
cudaMemcpyHostToDevice));
mIdsEmbDev.push_back(cudaMem);
mIdsEmbPtrs.push_back(cudaMem);
}
}
@@ -83,7 +83,7 @@ EmbLayerNormVarSeqlenPluginBase::EmbLayerNormVarSeqlenPluginBase(
: mLayerName(name),
mGammaDev(nullptr),
mBetaDev(nullptr),
mIdsEmbDev{},
mIdsEmbPtrs{},
mIdsEmb_{} {
// Deserialize in the same order as serialization
deserialize_value(&data, &length, &mType);
@@ -141,8 +141,8 @@ EmbLayerNormVarSeqlenPluginMTron::EmbLayerNormVarSeqlenPluginMTron(
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* EmbLayerNormVarSeqlenPluginHFace::clone()
const noexcept {
TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginMTron clone");
auto p = new EmbLayerNormVarSeqlenPluginMTron(
TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenPluginHFace clone");
auto p = new EmbLayerNormVarSeqlenPluginHFace(
mLayerName, mType, mBeta, mGamma, mIdsEmb_);
p->setPluginNamespace(mNamespace.c_str());
return p;
@@ -333,7 +333,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
int32_t const batchSize = inputDesc[0].dims.d[0] - 1;
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t S = 384;
@@ -346,60 +346,132 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
}
const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get();
int32_t** tem_inputs_ptr_dev;
cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(tem_inputs_ptr_dev,
inputs,
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
int32_t* mIdsVocabSize_dev;
cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
sizeof(int32_t) * mIdsVocabSize.size());
cudaMemcpy(mIdsVocabSize_dev,
&(mIdsVocabSize[0]),
sizeof(int32_t) * mIdsVocabSize.size(),
cudaMemcpyHostToDevice);
if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]);
float** mIdsEmbDev_float;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_float,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormHFace<float>(stream,
if (nbLookupTables_ == 2) {
return embSkipLayerNormHFace_2<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormHFace_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_float,
mIdsVocabSize_dev,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormHFace_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]);
half** mIdsEmbDev_half;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_half,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormHFace<half>(stream,
if (nbLookupTables_ == 2) {
return embSkipLayerNormHFace_2<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_half,
mIdsVocabSize_dev,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormHFace_3<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormHFace_4<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]"));
@@ -414,7 +486,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
int32_t const batchSize = inputDesc[0].dims.d[0] - 1;
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t S = 384;
@@ -427,64 +499,141 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
}
const float* beta = mBetaDev.get();
const float* gamma = mGammaDev.get();
int32_t** tem_inputs_ptr_dev;
cudaMalloc(reinterpret_cast<void**>(&tem_inputs_ptr_dev),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(tem_inputs_ptr_dev,
inputs,
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
int32_t* mIdsVocabSize_dev;
cudaMalloc(reinterpret_cast<void**>(&mIdsVocabSize_dev),
sizeof(int32_t) * mIdsVocabSize.size());
cudaMemcpy(mIdsVocabSize_dev,
&(mIdsVocabSize[0]),
sizeof(int32_t) * mIdsVocabSize.size(),
cudaMemcpyHostToDevice);
if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]);
auto skip = static_cast<float*>(outputs[1]);
float** mIdsEmbDev_float;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_float),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_float,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormMTron<float>(stream,
if (nbLookupTables_ == 2) {
return embSkipLayerNormMTron_2<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output,
skip);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormMTron_3<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_float,
mIdsVocabSize_dev,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output,
skip);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormMTron_4<float>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<float const*>(mIdsEmbPtrs[0]),
static_cast<float const*>(mIdsEmbPtrs[1]),
static_cast<float const*>(mIdsEmbPtrs[2]),
static_cast<float const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output,
skip);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]);
auto skip = static_cast<half*>(outputs[1]);
half** mIdsEmbDev_half;
cudaMalloc(reinterpret_cast<void**>(&mIdsEmbDev_half),
sizeof(void*) * nbLookupTables_);
cudaMemcpy(mIdsEmbDev_half,
&(mIdsEmbDev[0]),
sizeof(void*) * nbLookupTables_,
cudaMemcpyHostToDevice);
return embSkipLayerNormMTron<half>(stream,
if (nbLookupTables_ == 2) {
return embSkipLayerNormMTron_2<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
mIdsVocabSize[0],
mIdsVocabSize[1],
output,
skip);
} else if (nbLookupTables_ == 3) {
return embSkipLayerNormMTron_3<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
tem_inputs_ptr_dev,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
nbLookupTables_,
beta,
gamma,
mIdsEmbDev_half,
mIdsVocabSize_dev,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
output,
skip);
} else if (nbLookupTables_ == 4) {
return embSkipLayerNormMTron_4<half>(
stream,
static_cast<int32_t>(mLd),
batchSize,
S,
static_cast<int32_t const*>(inputs[0]),
static_cast<int32_t const*>(inputs[1]),
static_cast<int32_t const*>(inputs[2]),
static_cast<int32_t const*>(inputs[3]),
nbLookupTables_,
beta,
gamma,
static_cast<half const*>(mIdsEmbPtrs[0]),
static_cast<half const*>(mIdsEmbPtrs[1]),
static_cast<half const*>(mIdsEmbPtrs[2]),
static_cast<half const*>(mIdsEmbPtrs[3]),
mIdsVocabSize[0],
mIdsVocabSize[1],
mIdsVocabSize[2],
mIdsVocabSize[3],
output,
skip);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support 2,3,4 lookup_tables fused "));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported type error, expected [kHALF,kFLOAT]"));
@@ -566,9 +715,9 @@ void EmbLayerNormVarSeqlenPluginBase::serialize(void* buffer) const noexcept {
size_t const wordSize = getElementSize(mType);
serFromDev(&d, mBetaDev.get(), mLd);
serFromDev(&d, mGammaDev.get(), mLd);
for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
serFromDev(&d,
static_cast<char*>(mIdsEmbDev[i]),
static_cast<char*>(mIdsEmbPtrs[i]),
mLd * mIdsVocabSize[i] * wordSize);
}
}
@@ -577,8 +726,8 @@ void EmbLayerNormVarSeqlenPluginBase::destroy() noexcept {
// This gets called when the network containing plugin is destroyed
mBetaDev.reset(nullptr);
mGammaDev.reset(nullptr);
for (size_t i = 0; i < mIdsEmbDev.size(); ++i) {
cudaFree(mIdsEmbDev[i]);
for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
cudaFree(mIdsEmbPtrs[i]);
}
delete this;
}
@@ -680,7 +829,6 @@ nvinfer1::IPluginV2* EmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin(
beta,
gamma,
IdsEmb);
return p;
}
......
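The core of the fix is visible in both `enqueue` implementations above: the old code called `cudaMalloc` and a synchronous `cudaMemcpy` on every inference call to stage pointer arrays (`tem_inputs_ptr_dev`, `mIdsEmbDev_float`/`mIdsEmbDev_half`, `mIdsVocabSize_dev`) and, at least in the code shown, never freed them; the new code passes every pointer and size as a plain kernel argument. A condensed sketch of the anti-pattern and its replacement, with hypothetical names:

```cpp
#include <cuda_runtime.h>

void enqueue_old(const void* const* inputs, int n, cudaStream_t stream) {
  int32_t** ids_dev = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&ids_dev), sizeof(void*) * n);
  cudaMemcpy(ids_dev, inputs, sizeof(void*) * n, cudaMemcpyHostToDevice);
  // kernel<<<grid, block, 0, stream>>>(ids_dev, ...);
  // Problem: ids_dev is never cudaFree'd, so each call leaks device
  // memory, and the synchronous copy stalls the stream.
}

void enqueue_new(const void* const* inputs, cudaStream_t stream) {
  const int32_t* ids0 = static_cast<const int32_t*>(inputs[0]);
  const int32_t* ids1 = static_cast<const int32_t*>(inputs[1]);
  // kernel_2<<<grid, block, 0, stream>>>(ids0, ids1, ...);
  // Nothing is allocated per call; pointers ride in the launch arguments.
  (void)ids0; (void)ids1; (void)stream;
}
```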
@@ -31,32 +31,121 @@ namespace tensorrt {
namespace plugin {
template <typename T>
int32_t embSkipLayerNormHFace(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t** inputIds,
int32_t const nbLookupTables,
float const* beta,
float const* gamma,
T** idsEmb,
int32_t*,
T* output);
int32_t embSkipLayerNormHFace_2(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormMTron(cudaStream_t stream,
int32_t ld,
int32_t B,
int32_t S,
int32_t** inputIds,
int32_t const nbLookupTables,
float const* beta,
float const* gamma,
T** idsEmb,
int32_t*,
T* output,
T* skip);
int32_t embSkipLayerNormHFace_3(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormHFace_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_2(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
int32_t,
int32_t,
T*,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_3(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
T*,
T*);
template <typename T>
int32_t embSkipLayerNormMTron_4(cudaStream_t,
int32_t,
int32_t,
int32_t,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t const*,
int32_t,
float const*,
float const*,
T const*,
T const*,
T const*,
T const*,
int32_t,
int32_t,
int32_t,
int32_t,
T*,
T*);
class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
public:
EmbLayerNormVarSeqlenPluginBase(
@@ -104,7 +193,8 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
std::string mNamespace;
cuda_unique_ptr<float> mGammaDev;
cuda_unique_ptr<float> mBetaDev;
std::vector<void*> mIdsEmbDev;
std::vector<void*> mIdsEmbPtrs;
// std::vector<void*> mIdsEmbDev;
size_t mLd; // leading dim = hidden size
std::vector<int32_t> mIdsVocabSize;
WeightsWithOwnership mBeta;
......
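The header half of the change renames `mIdsEmbDev` to `mIdsEmbPtrs`, clarifying that the member is a host-side vector of device pointers, one per embedding table, filled at construction and released in `destroy()`. A hedged sketch of that ownership pattern:

```cpp
#include <cuda_runtime.h>
#include <vector>

struct EmbTables {
  std::vector<void*> mIdsEmbPtrs;  // host vector of device allocations

  void AddTable(const void* host_data, size_t bytes) {
    void* dev = nullptr;
    cudaMalloc(&dev, bytes);
    cudaMemcpy(dev, host_data, bytes, cudaMemcpyHostToDevice);
    mIdsEmbPtrs.push_back(dev);    // only the pointer lives on the host
  }

  ~EmbTables() {                   // mirrors the plugin's destroy()
    for (void* p : mIdsEmbPtrs) cudaFree(p);
  }
};
```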