From 40115c7ed6147f58c3913b7738bea616e4259af6 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Wed, 22 Mar 2023 09:59:33 +0800
Subject: [PATCH] fix embd for S (#51937)

fix embd plugin: S = mask_id.d[1]
---
 .../tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
index 42d92f2bb48..e9ad763b43c 100644
--- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
@@ -353,7 +353,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
     cudaStream_t stream) noexcept {
   int32_t batchSize = inputDesc[0].dims.d[0] - 1;
   // read out the maximum sequence length from the dummy input
-  int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
+  int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1];
   int32_t S = 384;
   if (maxSeqlen <= 128) {
     S = 128;
@@ -506,7 +506,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
     cudaStream_t stream) noexcept {
   int32_t batchSize = inputDesc[0].dims.d[0] - 1;
   // read out the maximum sequence length from the dummy input
-  int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
+  int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1];
   int32_t S = 384;
   if (maxSeqlen <= 128) {
     S = 128;
--
GitLab
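
Note: the sketch below is not part of the patch; it only illustrates the logic the two changed lines feed into. It assumes the mask_id input descriptor has shape [batch, maxSeqlen] (which is why dims.d[1], not dims.d[0], holds the maximum sequence length), and the intermediate bucket thresholds 192 and 256 as well as the helper name pickSeqLenBucket are illustrative assumptions; only the 128 bucket and the 384 default are visible in the hunk context.

// Hypothetical standalone sketch, not the plugin's actual code: how the
// patched lines pick the fixed sequence-length bucket S for the fused kernel.
#include <cstdint>
#include <cstdio>

static int32_t pickSeqLenBucket(int32_t maxSeqlen) {
  int32_t S = 384;                 // default bucket, as in the hunk context
  if (maxSeqlen <= 128) {
    S = 128;
  } else if (maxSeqlen <= 192) {   // assumed intermediate buckets
    S = 192;
  } else if (maxSeqlen <= 256) {
    S = 256;
  }
  return S;
}

int main() {
  // Example: a mask_id tensor of shape [8, 200] should select the 256 bucket.
  // Reading dims.d[0] instead would bucket on the batch size (8), which is
  // the bug this patch fixes.
  std::printf("S = %d\n", pickSeqLenBucket(200));
  return 0;
}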