Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f2f305cd
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f2f305cd
编写于
10月 22, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add recog interface
上级
bb75735f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
484 addition
and
6 deletion
+484
-6
deepspeech/decoders/recog.py
deepspeech/decoders/recog.py
+154
-0
deepspeech/decoders/utils.py
deepspeech/decoders/utils.py
+75
-0
deepspeech/models/asr_interface.py
deepspeech/models/asr_interface.py
+148
-0
deepspeech/models/u2/u2.py
deepspeech/models/u2/u2.py
+25
-3
deepspeech/modules/decoder.py
deepspeech/modules/decoder.py
+67
-2
deepspeech/modules/mask.py
deepspeech/modules/mask.py
+15
-1
未找到文件。
deepspeech/decoders/recog.py
0 → 100644
浏览文件 @
f2f305cd
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
import
json
import
paddle
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.nets.lm_interface import dynamic_import_lm
# from espnet.nets.asr_interface import ASRInterface
from
.utils
import
add_results_to_json
# from .batch_beam_search import BatchBeamSearch
from
.beam_search
import
BeamSearch
from
.scorer_interface
import
BatchScorerInterface
from
.scorers.length_bonus
import
LengthBonus
from
deepspeech.io.reader
import
LoadInputsAndTargets
from
deepspeech.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
def
recog_v2
(
args
):
"""Decode with custom models that implements ScorerInterface.
Args:
args (namespace): The program arguments.
See py:func:`bin.asr_recog.get_parser` for details
"""
logger
.
warning
(
"experimental API for custom LMs is selected by --api v2"
)
if
args
.
batchsize
>
1
:
raise
NotImplementedError
(
"multi-utt batch decoding is not implemented"
)
if
args
.
streaming_mode
is
not
None
:
raise
NotImplementedError
(
"streaming mode is not implemented"
)
if
args
.
word_rnnlm
:
raise
NotImplementedError
(
"word LM is not implemented"
)
# set_deterministic(args)
model
,
train_args
=
load_trained_model
(
args
.
model
)
# assert isinstance(model, ASRInterface)
model
.
eval
()
load_inputs_and_targets
=
LoadInputsAndTargets
(
mode
=
"asr"
,
load_output
=
False
,
sort_in_input_length
=
False
,
preprocess_conf
=
train_args
.
preprocess_conf
if
args
.
preprocess_conf
is
None
else
args
.
preprocess_conf
,
preprocess_args
=
{
"train"
:
False
},
)
if
args
.
rnnlm
:
lm_args
=
get_model_conf
(
args
.
rnnlm
,
args
.
rnnlm_conf
)
# NOTE: for a compatibility with less than 0.5.0 version models
lm_model_module
=
getattr
(
lm_args
,
"model_module"
,
"default"
)
lm_class
=
dynamic_import_lm
(
lm_model_module
,
lm_args
.
backend
)
lm
=
lm_class
(
len
(
train_args
.
char_list
),
lm_args
)
torch_load
(
args
.
rnnlm
,
lm
)
lm
.
eval
()
else
:
lm
=
None
if
args
.
ngram_model
:
from
.scorers.ngram
import
NgramFullScorer
from
.scorers.ngram
import
NgramPartScorer
if
args
.
ngram_scorer
==
"full"
:
ngram
=
NgramFullScorer
(
args
.
ngram_model
,
train_args
.
char_list
)
else
:
ngram
=
NgramPartScorer
(
args
.
ngram_model
,
train_args
.
char_list
)
else
:
ngram
=
None
scorers
=
model
.
scorers
()
scorers
[
"lm"
]
=
lm
scorers
[
"ngram"
]
=
ngram
scorers
[
"length_bonus"
]
=
LengthBonus
(
len
(
train_args
.
char_list
))
weights
=
dict
(
decoder
=
1.0
-
args
.
ctc_weight
,
ctc
=
args
.
ctc_weight
,
lm
=
args
.
lm_weight
,
ngram
=
args
.
ngram_weight
,
length_bonus
=
args
.
penalty
,
)
beam_search
=
BeamSearch
(
beam_size
=
args
.
beam_size
,
vocab_size
=
len
(
train_args
.
char_list
),
weights
=
weights
,
scorers
=
scorers
,
sos
=
model
.
sos
,
eos
=
model
.
eos
,
token_list
=
train_args
.
char_list
,
pre_beam_score_key
=
None
if
args
.
ctc_weight
==
1.0
else
"full"
,
)
# TODO(karita): make all scorers batchfied
if
args
.
batchsize
==
1
:
non_batch
=
[
k
for
k
,
v
in
beam_search
.
full_scorers
.
items
()
if
not
isinstance
(
v
,
BatchScorerInterface
)
]
if
len
(
non_batch
)
==
0
:
beam_search
.
__class__
=
BatchBeamSearch
logger
.
info
(
"BatchBeamSearch implementation is selected."
)
else
:
logger
.
warning
(
f
"As non-batch scorers
{
non_batch
}
are found, "
f
"fall back to non-batch implementation."
)
if
args
.
ngpu
>
1
:
raise
NotImplementedError
(
"only single GPU decoding is supported"
)
if
args
.
ngpu
==
1
:
device
=
"gpu:0"
else
:
device
=
"cpu"
dtype
=
getattr
(
paddle
,
args
.
dtype
)
logger
.
info
(
f
"Decoding device=
{
device
}
, dtype=
{
dtype
}
"
)
model
.
to
(
device
=
device
,
dtype
=
dtype
)
model
.
eval
()
beam_search
.
to
(
device
=
device
,
dtype
=
dtype
)
beam_search
.
eval
()
# read json data
with
open
(
args
.
recog_json
,
"rb"
)
as
f
:
js
=
json
.
load
(
f
)
# josnlines to dict, key by 'utt'
js
=
{
item
[
'utt'
]:
item
for
item
in
js
}
new_js
=
{}
with
paddle
.
no_grad
():
for
idx
,
name
in
enumerate
(
js
.
keys
(),
1
):
logger
.
info
(
"(%d/%d) decoding "
+
name
,
idx
,
len
(
js
.
keys
()))
batch
=
[(
name
,
js
[
name
])]
feat
=
load_inputs_and_targets
(
batch
)[
0
][
0
]
enc
=
model
.
encode
(
paddle
.
to_tensor
(
feat
).
to
(
device
=
device
,
dtype
=
dtype
))
nbest_hyps
=
beam_search
(
x
=
enc
,
maxlenratio
=
args
.
maxlenratio
,
minlenratio
=
args
.
minlenratio
)
nbest_hyps
=
[
h
.
asdict
()
for
h
in
nbest_hyps
[:
min
(
len
(
nbest_hyps
),
args
.
nbest
)]
]
new_js
[
name
]
=
add_results_to_json
(
js
[
name
],
nbest_hyps
,
train_args
.
char_list
)
with
open
(
args
.
result_label
,
"wb"
)
as
f
:
f
.
write
(
json
.
dumps
(
{
"utts"
:
new_js
},
indent
=
4
,
ensure_ascii
=
False
,
sort_keys
=
True
).
encode
(
"utf_8"
)
)
deepspeech/decoders/utils.py
浏览文件 @
f2f305cd
...
...
@@ -47,3 +47,78 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
return
True
else
:
return
False
# * ------------------ recognition related ------------------ *
def
parse_hypothesis
(
hyp
,
char_list
):
"""Parse hypothesis.
Args:
hyp (list[dict[str, Any]]): Recognition hypothesis.
char_list (list[str]): List of characters.
Returns:
tuple(str, str, str, float)
"""
# remove sos and get results
tokenid_as_list
=
list
(
map
(
int
,
hyp
[
"yseq"
][
1
:]))
token_as_list
=
[
char_list
[
idx
]
for
idx
in
tokenid_as_list
]
score
=
float
(
hyp
[
"score"
])
# convert to string
tokenid
=
" "
.
join
([
str
(
idx
)
for
idx
in
tokenid_as_list
])
token
=
" "
.
join
(
token_as_list
)
text
=
""
.
join
(
token_as_list
).
replace
(
"<space>"
,
" "
)
return
text
,
token
,
tokenid
,
score
def
add_results_to_json
(
js
,
nbest_hyps
,
char_list
):
"""Add N-best results to json.
Args:
js (dict[str, Any]): Groundtruth utterance dict.
nbest_hyps_sd (list[dict[str, Any]]):
List of hypothesis for multi_speakers: nutts x nspkrs.
char_list (list[str]): List of characters.
Returns:
dict[str, Any]: N-best results added utterance dict.
"""
# copy old json info
new_js
=
dict
()
new_js
[
"utt2spk"
]
=
js
[
"utt2spk"
]
new_js
[
"output"
]
=
[]
for
n
,
hyp
in
enumerate
(
nbest_hyps
,
1
):
# parse hypothesis
rec_text
,
rec_token
,
rec_tokenid
,
score
=
parse_hypothesis
(
hyp
,
char_list
)
# copy ground-truth
if
len
(
js
[
"output"
])
>
0
:
out_dic
=
dict
(
js
[
"output"
][
0
].
items
())
else
:
# for no reference case (e.g., speech translation)
out_dic
=
{
"name"
:
""
}
# update name
out_dic
[
"name"
]
+=
"[%d]"
%
n
# add recognition results
out_dic
[
"rec_text"
]
=
rec_text
out_dic
[
"rec_token"
]
=
rec_token
out_dic
[
"rec_tokenid"
]
=
rec_tokenid
out_dic
[
"score"
]
=
score
# add to list of N-best result dicts
new_js
[
"output"
].
append
(
out_dic
)
# show 1-best result
if
n
==
1
:
if
"text"
in
out_dic
.
keys
():
logging
.
info
(
"groundtruth: %s"
%
out_dic
[
"text"
])
logging
.
info
(
"prediction : %s"
%
out_dic
[
"rec_text"
])
return
new_js
\ No newline at end of file
deepspeech/models/asr_interface.py
0 → 100644
浏览文件 @
f2f305cd
"""ASR Interface module."""
import
argparse
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
ASRInterface
:
"""ASR Interface for ESPnet model implementation."""
@
staticmethod
def
add_arguments
(
parser
):
"""Add arguments to parser."""
return
parser
@
classmethod
def
build
(
cls
,
idim
:
int
,
odim
:
int
,
**
kwargs
):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRinterface: A new instance of ASRInterface.
"""
args
=
argparse
.
Namespace
(
**
kwargs
)
return
cls
(
idim
,
odim
,
args
)
def
forward
(
self
,
xs
,
ilens
,
ys
,
olens
):
"""Compute loss for training.
:param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim)
:param ilens: batch of lengths of source sequences (B), paddle.Tensor
:param ys: batch of padded target sequences paddle.Tensor (B, Lmax)
:param olens: batch of lengths of target sequences (B), paddle.Tensor
:return: loss value
:rtype: paddle.Tensor
"""
raise
NotImplementedError
(
"forward method is not implemented"
)
def
recognize
(
self
,
x
,
recog_args
,
char_list
=
None
,
rnnlm
=
None
):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace recog_args: argment namespace contraining options
:param list char_list: list of characters
:param paddle.nn.Layer rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise
NotImplementedError
(
"recognize method is not implemented"
)
def
recognize_batch
(
self
,
x
,
recog_args
,
char_list
=
None
,
rnnlm
=
None
):
"""Beam search implementation for batch.
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param paddle.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise
NotImplementedError
(
"Batch decoding is not supported yet."
)
def
calculate_all_attentions
(
self
,
xs
,
ilens
,
ys
):
"""Calculate attention.
:param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise
NotImplementedError
(
"calculate_all_attentions method is not implemented"
)
def
calculate_all_ctc_probs
(
self
,
xs
,
ilens
,
ys
):
"""Calculate CTC probability.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: CTC probabilities (B, Tmax, vocab)
:rtype: float ndarray
"""
raise
NotImplementedError
(
"calculate_all_ctc_probs method is not implemented"
)
@
property
def
attention_plot_class
(
self
):
"""Get attention plot class."""
from
espnet.asr.asr_utils
import
PlotAttentionReport
return
PlotAttentionReport
@
property
def
ctc_plot_class
(
self
):
"""Get CTC plot class."""
from
espnet.asr.asr_utils
import
PlotCTCReport
return
PlotCTCReport
def
get_total_subsampling_factor
(
self
):
"""Get total subsampling factor."""
raise
NotImplementedError
(
"get_total_subsampling_factor method is not implemented"
)
def
encode
(
self
,
feat
):
"""Encode feature in `beam_search` (optional).
Args:
x (numpy.ndarray): input feature (T, D)
Returns:
paddle.Tensor: encoded feature (T, D)
"""
raise
NotImplementedError
(
"encode method is not implemented"
)
def
scorers
(
self
):
"""Get scorers for `beam_search` (optional).
Returns:
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
"""
raise
NotImplementedError
(
"decoders method is not implemented"
)
predefined_asr
=
{
"transformer"
:
"deepspeech.models.u2:E2E"
,
"conformer"
:
"deepspeech.models.u2:E2E"
,
}
def
dynamic_import_asr
(
module
,
name
):
"""Import ASR models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_asr`
name (str): asr name. e.g., transformer, conformer
Returns:
type: ASR class
"""
model_class
=
dynamic_import
(
module
,
predefined_asr
.
get
(
name
,
""
))
assert
issubclass
(
model_class
,
ASRInterface
),
f
"
{
module
}
does not implement ASRInterface"
return
model_class
deepspeech/models/u2/u2.py
浏览文件 @
f2f305cd
...
...
@@ -49,13 +49,15 @@ from deepspeech.utils.tensor_utils import pad_sequence
from
deepspeech.utils.tensor_utils
import
th_accuracy
from
deepspeech.utils.utility
import
log_add
from
deepspeech.utils.utility
import
UpdateConfig
from
deepspeech.models.asr_interface
import
ASRInterface
from
deepspeech.decoders.scorers.ctc_prefix_score
import
CTCPrefixScorer
__all__
=
[
"U2Model"
,
"U2InferModel"
]
logger
=
Log
(
__name__
).
getlog
()
class
U2BaseModel
(
nn
.
Layer
):
class
U2BaseModel
(
ASRInterface
,
nn
.
Layer
):
"""CTC-Attention hybrid Encoder-Decoder model"""
@
classmethod
...
...
@@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer):
**
kwargs
):
assert
0.0
<=
ctc_weight
<=
1.0
,
ctc_weight
super
().
__init__
(
)
nn
.
Layer
.
__init__
(
self
)
# note that eos is the same as sos (equivalent ID)
self
.
sos
=
vocab_size
-
1
self
.
eos
=
vocab_size
-
1
...
...
@@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer):
return
res
,
res_tokenids
class
U2Model
(
U2BaseModel
):
class
U2DecodeModel
(
U2BaseModel
):
def
scorers
(
self
):
"""Scorers."""
return
dict
(
decoder
=
self
.
decoder
,
ctc
=
CTCPrefixScorer
(
self
.
ctc
,
self
.
eos
))
def
encode
(
self
,
x
):
"""Encode acoustic features.
:param ndarray x: source acoustic feature (T, D)
:return: encoder outputs
:rtype: paddle.Tensor
"""
self
.
eval
()
x
=
paddle
.
to_tensor
(
x
).
unsqueeze
(
0
)
ilen
=
x
.
size
(
1
)
enc_output
,
_
=
self
.
_forward_encoder
(
x
,
ilen
)
return
enc_output
.
squeeze
(
0
)
class
U2Model
(
U2DecodeModel
):
def
__init__
(
self
,
configs
:
dict
):
vocab_size
,
encoder
,
decoder
,
ctc
=
U2Model
.
_init_from_config
(
configs
)
...
...
deepspeech/modules/decoder.py
浏览文件 @
f2f305cd
...
...
@@ -15,6 +15,7 @@
from
typing
import
List
from
typing
import
Optional
from
typing
import
Tuple
from
typing
import
Any
import
paddle
from
paddle
import
nn
...
...
@@ -25,7 +26,9 @@ from deepspeech.modules.decoder_layer import DecoderLayer
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.mask
import
make_non_pad_mask
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
make_xs_mask
from
deepspeech.modules.positionwise_feed_forward
import
PositionwiseFeedForward
from
deepspeech.decoders.scorers.score_interface
import
BatchScorerInterface
from
deepspeech.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
...
...
@@ -33,7 +36,7 @@ logger = Log(__name__).getlog()
__all__
=
[
"TransformerDecoder"
]
class
TransformerDecoder
(
nn
.
Layer
):
class
TransformerDecoder
(
BatchScorerInterface
,
nn
.
Layer
):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
...
...
@@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer):
concat_after
:
bool
=
False
,
):
assert
check_argument_types
()
super
().
__init__
()
nn
.
Layer
.
__init__
(
self
)
self
.
selfattention_layer_type
=
'selfattn'
attention_dim
=
encoder_output_size
if
input_layer
==
"embed"
:
...
...
@@ -180,3 +184,64 @@ class TransformerDecoder(nn.Layer):
if
self
.
use_output_layer
:
y
=
paddle
.
log_softmax
(
self
.
output_layer
(
y
),
axis
=-
1
)
return
y
,
new_cache
# beam search API (see ScorerInterface)
def
score
(
self
,
ys
,
state
,
x
):
"""Score.
ys: (ylen,)
x: (xlen, n_feat)
"""
ys_mask
=
subsequent_mask
(
len
(
ys
)).
unsqueeze
(
0
)
x_mask
=
make_xs_mask
(
x
.
unsqueeze
(
0
))
if
self
.
selfattention_layer_type
!=
"selfattn"
:
# TODO(karita): implement cache
logging
.
warning
(
f
"
{
self
.
selfattention_layer_type
}
does not support cached decoding."
)
state
=
None
logp
,
state
=
self
.
forward_one_step
(
x
.
unsqueeze
(
0
),
x_mask
,
ys
.
unsqueeze
(
0
),
ys_mask
,
cache
=
state
)
return
logp
.
squeeze
(
0
),
state
# batch beam search API (see BatchScorerInterface)
def
batch_score
(
self
,
ys
:
paddle
.
Tensor
,
states
:
List
[
Any
],
xs
:
paddle
.
Tensor
)
->
Tuple
[
paddle
.
Tensor
,
List
[
Any
]]:
"""Score new token batch (required).
Args:
ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (paddle.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[paddle.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch
=
len
(
ys
)
n_layers
=
len
(
self
.
decoders
)
if
states
[
0
]
is
None
:
batch_state
=
None
else
:
# transpose state of [batch, layer] into [layer, batch]
batch_state
=
[
paddle
.
stack
([
states
[
b
][
i
]
for
b
in
range
(
n_batch
)])
for
i
in
range
(
n_layers
)
]
# batch decoding
ys_mask
=
subsequent_mask
(
ys
.
size
(
-
1
)).
unsqueeze
(
0
)
xs_mask
=
make_xs_mask
(
xs
)
logp
,
states
=
self
.
forward_one_step
(
xs
,
xs_mask
,
ys
,
ys_mask
,
cache
=
batch_state
)
# transpose state of [layer, batch] into [batch, layer]
state_list
=
[[
states
[
i
][
b
]
for
i
in
range
(
n_layers
)]
for
b
in
range
(
n_batch
)]
return
logp
,
state_list
deepspeech/modules/mask.py
浏览文件 @
f2f305cd
...
...
@@ -18,12 +18,24 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
"make_pad_mask"
,
"make_non_pad_mask"
,
"subsequent_mask"
,
"make_
xs_mask"
,
"make_
pad_mask"
,
"make_non_pad_mask"
,
"subsequent_mask"
,
"subsequent_chunk_mask"
,
"add_optional_chunk_mask"
,
"mask_finished_scores"
,
"mask_finished_preds"
]
def
make_xs_mask
(
xs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""Maks mask tensor containing indices of non-padded part.
Args:
xs (paddle.Tensor): (B, T, D), zeros for pad.
Returns:
paddle.Tensor: Mask Tensor indices of non-padded part. (B, T, D)
"""
pad_frame
=
paddle
.
zeros
([
1
,
1
,
xs
.
shape
[
-
1
]],
dtype
=
xs
.
dtype
)
mask
=
xs
!=
pad_frame
return
mask
def
make_pad_mask
(
lengths
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
...
...
@@ -31,6 +43,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,).
Returns:
paddle.Tensor: Mask tensor containing indices of padded part.
(B, T)
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
...
...
@@ -62,6 +75,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,).
Returns:
paddle.Tensor: mask tensor containing indices of padded part.
(B, T)
Examples:
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录