未验证 提交 40115c7e 编写于 作者: W Wangzheee 提交者: GitHub

Fix sequence length S selection in the embedding layer-norm plugins (#51937)

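Fix the embedding layer-norm plugins (HFace and MTron variants): the maximum sequence length used to choose S is now read from dims.d[1] of the mask_id input instead of dims.d[0], i.e. S = mask_id.d[1].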
上级 e0b58212
......@@ -353,7 +353,7 @@ int32_t EmbLayerNormVarSeqlenPluginHFace::enqueue(
cudaStream_t stream) noexcept {
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1];
int32_t S = 384;
if (maxSeqlen <= 128) {
S = 128;
......@@ -506,7 +506,7 @@ int32_t EmbLayerNormVarSeqlenPluginMTron::enqueue(
cudaStream_t stream) noexcept {
int32_t batchSize = inputDesc[0].dims.d[0] - 1;
// read out the maximum sequence length from the dummy input
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[0];
int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1];
int32_t S = 384;
if (maxSeqlen <= 128) {
S = 128;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册