diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu index 42d92f2bb4806955bcf64072df8364df772334ce..e9ad763b43c8bc5e0fcffe10d51f245a37c6faa0 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu @@ -353,7 +353,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue( cudaStream_t stream) noexcept { int32_t batchSize = inputDesc[0].dims.d[0] - 1; // read out the maximum sequence length from the dummy input - int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0]; + int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1]; int32_t S = 384; if (maxSeqlen <= 128) { S = 128; @@ -506,7 +506,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue( cudaStream_t stream) noexcept { int32_t batchSize = inputDesc[0].dims.d[0] - 1; // read out the maximum sequence length from the dummy input - int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0]; + int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1]; int32_t S = 384; if (maxSeqlen <= 128) { S = 128;