Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f2f305cd
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f2f305cd
编写于
10月 22, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add recog interface
上级
bb75735f
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
484 addition
and
6 deletion
+484
-6
deepspeech/decoders/recog.py
deepspeech/decoders/recog.py
+154
-0
deepspeech/decoders/utils.py
deepspeech/decoders/utils.py
+75
-0
deepspeech/models/asr_interface.py
deepspeech/models/asr_interface.py
+148
-0
deepspeech/models/u2/u2.py
deepspeech/models/u2/u2.py
+25
-3
deepspeech/modules/decoder.py
deepspeech/modules/decoder.py
+67
-2
deepspeech/modules/mask.py
deepspeech/modules/mask.py
+15
-1
未找到文件。
deepspeech/decoders/recog.py
0 → 100644
浏览文件 @
f2f305cd
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
import
json
import
paddle
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.nets.lm_interface import dynamic_import_lm
# from espnet.nets.asr_interface import ASRInterface
from
.utils
import
add_results_to_json
# from .batch_beam_search import BatchBeamSearch
from
.beam_search
import
BeamSearch
from
.scorer_interface
import
BatchScorerInterface
from
.scorers.length_bonus
import
LengthBonus
from
deepspeech.io.reader
import
LoadInputsAndTargets
from
deepspeech.utils.log
import
Log
# Module-level logger obtained from the project's Log helper, keyed by this
# module's name.
logger = Log(__name__).getlog()
def recog_v2(args):
    """Decode utterances one at a time with a beam search over pluggable scorers.

    The function loads a trained model, assembles the scorers (model scorers,
    optional LM, optional n-gram, length bonus), runs beam search on every
    utterance listed in ``args.recog_json`` and writes the N-best results to
    ``args.result_label`` as UTF-8 JSON under the top-level key ``"utts"``.

    Args:
        args (namespace): The program arguments. Attributes read here:
            ``batchsize``, ``streaming_mode``, ``word_rnnlm``, ``model``,
            ``preprocess_conf``, ``rnnlm``, ``rnnlm_conf``, ``ngram_model``,
            ``ngram_scorer``, ``ctc_weight``, ``lm_weight``, ``ngram_weight``,
            ``penalty``, ``beam_size``, ``ngpu``, ``dtype``, ``recog_json``,
            ``result_label``, ``maxlenratio``, ``minlenratio``, ``nbest``.
            See py:func:`bin.asr_recog.get_parser` for details.

    Raises:
        NotImplementedError: If ``batchsize > 1``, ``streaming_mode`` is set,
            ``word_rnnlm`` is set, or ``ngpu > 1``.
    """
    logger.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")

    # set_deterministic(args)
    # NOTE(review): `load_trained_model` is not imported in this module (its
    # import is commented out above), so this call raises NameError until a
    # project-local loader is wired in.
    model, train_args = load_trained_model(args.model)
    # assert isinstance(model, ASRInterface)
    model.eval()

    # A preprocess config given on the command line overrides the one stored
    # with the training arguments.
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={"train": False}, )

    if args.rnnlm:
        # NOTE(review): `get_model_conf`, `dynamic_import_lm` and `torch_load`
        # are likewise not imported here (commented-out imports above).
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for a compatibility with less than 0.5.0 version models
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        lm.eval()
    else:
        lm = None

    if args.ngram_model:
        # Imported lazily so the n-gram scorers are only required when used.
        from .scorers.ngram import NgramFullScorer
        from .scorers.ngram import NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    scorers = model.scorers()
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty, )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", )

    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if len(non_batch) == 0:
            # NOTE(review): `BatchBeamSearch` is not imported in this module
            # (its import is commented out above); this branch raises
            # NameError when every full scorer is a BatchScorerInterface.
            beam_search.__class__ = BatchBeamSearch
            logger.info("BatchBeamSearch implementation is selected.")
        else:
            logger.warning(f"As non-batch scorers {non_batch} are found, "
                           f"fall back to non-batch implementation.")

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "gpu:0"
    else:
        device = "cpu"
    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype)
    model.eval()
    beam_search.to(device=device, dtype=dtype)
    beam_search.eval()

    # read json data
    # NOTE(review): json.load parses a single JSON document, so the file is
    # presumably a JSON array of items; a true JSON-lines file would fail
    # here -- confirm the expected input format.
    with open(args.recog_json, "rb") as f:
        js = json.load(f)
    # list of items to dict, keyed by 'utt'
    js = {item['utt']: item for item in js}
    new_js = {}
    total = len(js)
    with paddle.no_grad():
        for idx, name in enumerate(js, 1):
            # The utterance id is interpolated as data, never as part of a
            # format string, so ids containing '%' cannot break the log call.
            logger.info(f"({idx}/{total}) decoding {name}")
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(
                paddle.to_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc,
                maxlenratio=args.maxlenratio,
                minlenratio=args.minlenratio)
            # Keep at most `args.nbest` hypotheses, converted to plain dicts.
            nbest_hyps = [
                h.asdict()
                for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
            ]
            new_js[name] = add_results_to_json(js[name], nbest_hyps,
                                               train_args.char_list)

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {
                    "utts": new_js
                }, indent=4, ensure_ascii=False, sort_keys=True).encode(
                    "utf_8"))
deepspeech/decoders/utils.py
浏览文件 @
f2f305cd
...
@@ -47,3 +47,78 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
...
@@ -47,3 +47,78 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
return
True
return
True
else
:
else
:
return
False
return
False
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
    """Turn one recognition hypothesis into printable strings.

    Args:
        hyp (dict[str, Any]): Hypothesis with a ``"yseq"`` sequence of token
            ids (the first id is skipped) and a ``"score"`` value.
        char_list (list[str]): Token table indexed by token id.

    Returns:
        tuple(str, str, str, float): ``(text, token, tokenid, score)`` where
        ``text`` is the tokens concatenated with ``"<space>"`` replaced by a
        blank, ``token`` is the tokens joined by blanks and ``tokenid`` is the
        ids joined by blanks.
    """
    # skip the first entry (sos) and coerce every id to int
    ids = [int(raw_id) for raw_id in hyp["yseq"][1:]]
    symbols = [char_list[i] for i in ids]
    score = float(hyp["score"])

    # render the three string views of the hypothesis
    tokenid = " ".join(map(str, ids))
    token = " ".join(symbols)
    text = "".join(symbols).replace("<space>", " ")

    return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
    """Build an utterance dict that carries the N-best recognition results.

    Args:
        js (dict[str, Any]): Groundtruth utterance dict; ``"utt2spk"`` and
            ``"output"`` are read from it.
        nbest_hyps (list[dict[str, Any]]): Hypotheses, best first.
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: New dict with ``"utt2spk"`` copied from ``js`` and an
        ``"output"`` list holding one entry per hypothesis.
    """
    # start from the speaker info of the original entry
    new_js = {"utt2spk": js["utt2spk"], "output": []}

    for rank, hyp in enumerate(nbest_hyps, 1):
        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
                                                                   char_list)

        # Seed the entry with a copy of the first reference output when one
        # exists; otherwise (e.g., speech translation) start from a blank name.
        entry = (dict(js["output"][0])
                 if len(js["output"]) > 0 else {"name": ""})

        # tag the name with the 1-based rank of the hypothesis
        entry["name"] += "[%d]" % rank

        # attach the recognition results
        entry.update(
            rec_text=rec_text,
            rec_token=rec_token,
            rec_tokenid=rec_tokenid,
            score=score)

        new_js["output"].append(entry)

        # only the best hypothesis is logged
        if rank != 1:
            continue
        if "text" in entry:
            logging.info("groundtruth: %s" % entry["text"])
        logging.info("prediction : %s" % entry["rec_text"])

    return new_js
\ No newline at end of file
deepspeech/models/asr_interface.py
0 → 100644
浏览文件 @
f2f305cd
"""ASR Interface module."""
import
argparse
from
deepspeech.utils.dynamic_import
import
dynamic_import
class ASRInterface:
    """ASR Interface for ESPnet model implementation.

    Every method except `add_arguments` and `build` is a placeholder that
    raises NotImplementedError; concrete models override what they support.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to parser.

        The base implementation adds nothing and returns ``parser`` unchanged.
        """
        return parser

    @classmethod
    def build(cls, idim: int, odim: int, **kwargs):
        """Initialize this class with python-level args.

        The keyword arguments are packed into an ``argparse.Namespace`` and
        passed as the third positional argument of the constructor.

        Args:
            idim (int): The number of an input feature dim.
            odim (int): The number of output vocab.

        Returns:
            ASRInterface: A new instance of ASRInterface.

        """
        args = argparse.Namespace(**kwargs)
        return cls(idim, odim, args)

    def forward(self, xs, ilens, ys, olens):
        """Compute loss for training.

        :param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim)
        :param ilens: batch of lengths of source sequences (B), paddle.Tensor
        :param ys: batch of padded target sequences paddle.Tensor (B, Lmax)
        :param olens: batch of lengths of target sequences (B), paddle.Tensor
        :return: loss value
        :rtype: paddle.Tensor
        """
        raise NotImplementedError("forward method is not implemented")

    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("recognize method is not implemented")

    def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")

    def calculate_all_attentions(self, xs, ilens, ys):
        """Calculate attention.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: attention weights (B, Lmax, Tmax)
        :rtype: float ndarray
        """
        raise NotImplementedError(
            "calculate_all_attentions method is not implemented")

    def calculate_all_ctc_probs(self, xs, ilens, ys):
        """Calculate CTC probability.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: CTC probabilities (B, Tmax, vocab)
        :rtype: float ndarray
        """
        raise NotImplementedError(
            "calculate_all_ctc_probs method is not implemented")

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        # Imported on access only.
        # NOTE(review): this pulls the class from the `espnet` package, so
        # ESPnet must be installed for this property to work -- confirm that
        # dependency is intended.
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def ctc_plot_class(self):
        """Get CTC plot class."""
        # Imported on access only; same `espnet` dependency as
        # `attention_plot_class`.
        from espnet.asr.asr_utils import PlotCTCReport

        return PlotCTCReport

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        raise NotImplementedError(
            "get_total_subsampling_factor method is not implemented")

    def encode(self, feat):
        """Encode feature in `beam_search` (optional).

        Args:
            feat (numpy.ndarray): input feature (T, D)
        Returns:
            paddle.Tensor: encoded feature (T, D)
        """
        raise NotImplementedError("encode method is not implemented")

    def scorers(self):
        """Get scorers for `beam_search` (optional).

        Returns:
            dict[str, ScorerInterface]: dict of `ScorerInterface` objects
        """
        # The message names this method; it previously said "decoders".
        raise NotImplementedError("scorers method is not implemented")
# Alias table consulted by `dynamic_import_asr`: maps an ASR name to a
# "module_name:class_name" import path. Both aliases currently point at the
# same path.
# NOTE(review): confirm that `deepspeech.models.u2` actually exposes a class
# named `E2E`.
predefined_asr = {
    "transformer": "deepspeech.models.u2:E2E",
    "conformer": "deepspeech.models.u2:E2E",
}
def dynamic_import_asr(module, name):
    """Resolve an ASR model class at runtime.

    Args:
        module (str): module_name:class_name or alias in `predefined_asr`
        name (str): asr name. e.g., transformer, conformer

    Returns:
        type: ASR class (checked to be a subclass of ASRInterface)

    """
    # NOTE(review): the second argument handed to `dynamic_import` is the
    # single path string registered for `name` (or "" when unknown), not a
    # mapping of aliases -- confirm this matches what `dynamic_import` expects.
    registered_path = predefined_asr.get(name, "")
    model_class = dynamic_import(module, registered_path)
    assert issubclass(model_class,
                      ASRInterface), f"{module} does not implement ASRInterface"
    return model_class
deepspeech/models/u2/u2.py
浏览文件 @
f2f305cd
...
@@ -49,13 +49,15 @@ from deepspeech.utils.tensor_utils import pad_sequence
...
@@ -49,13 +49,15 @@ from deepspeech.utils.tensor_utils import pad_sequence
from
deepspeech.utils.tensor_utils
import
th_accuracy
from
deepspeech.utils.tensor_utils
import
th_accuracy
from
deepspeech.utils.utility
import
log_add
from
deepspeech.utils.utility
import
log_add
from
deepspeech.utils.utility
import
UpdateConfig
from
deepspeech.utils.utility
import
UpdateConfig
from
deepspeech.models.asr_interface
import
ASRInterface
from
deepspeech.decoders.scorers.ctc_prefix_score
import
CTCPrefixScorer
__all__
=
[
"U2Model"
,
"U2InferModel"
]
__all__
=
[
"U2Model"
,
"U2InferModel"
]
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
class
U2BaseModel
(
nn
.
Layer
):
class
U2BaseModel
(
ASRInterface
,
nn
.
Layer
):
"""CTC-Attention hybrid Encoder-Decoder model"""
"""CTC-Attention hybrid Encoder-Decoder model"""
@
classmethod
@
classmethod
...
@@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer):
...
@@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer):
**
kwargs
):
**
kwargs
):
assert
0.0
<=
ctc_weight
<=
1.0
,
ctc_weight
assert
0.0
<=
ctc_weight
<=
1.0
,
ctc_weight
super
().
__init__
(
)
nn
.
Layer
.
__init__
(
self
)
# note that eos is the same as sos (equivalent ID)
# note that eos is the same as sos (equivalent ID)
self
.
sos
=
vocab_size
-
1
self
.
sos
=
vocab_size
-
1
self
.
eos
=
vocab_size
-
1
self
.
eos
=
vocab_size
-
1
...
@@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer):
...
@@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer):
return
res
,
res_tokenids
return
res
,
res_tokenids
class
U2Model
(
U2BaseModel
):
class U2DecodeModel(U2BaseModel):
    """U2 base model extended with the hooks used by scorer-based beam search."""

    def scorers(self):
        """Return the scorers this model contributes to beam search.

        :return: dict with the attention decoder under ``"decoder"`` and a
            CTCPrefixScorer built from ``self.ctc`` and ``self.eos`` under
            ``"ctc"``
        :rtype: dict
        """
        decoder = self.decoder
        ctc_scorer = CTCPrefixScorer(self.ctc, self.eos)
        return {"decoder": decoder, "ctc": ctc_scorer}

    def encode(self, x):
        """Encode acoustic features of a single utterance.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs with the leading batch axis removed
        :rtype: paddle.Tensor
        """
        self.eval()
        # add a batch axis in front so the encoder sees one utterance
        batched = paddle.to_tensor(x).unsqueeze(0)
        # NOTE(review): the length is handed over as whatever `size(1)`
        # returns -- confirm `_forward_encoder` accepts that form.
        n_frames = batched.size(1)
        encoded, _ = self._forward_encoder(batched, n_frames)
        return encoded.squeeze(0)
class
U2Model
(
U2DecodeModel
):
def
__init__
(
self
,
configs
:
dict
):
def
__init__
(
self
,
configs
:
dict
):
vocab_size
,
encoder
,
decoder
,
ctc
=
U2Model
.
_init_from_config
(
configs
)
vocab_size
,
encoder
,
decoder
,
ctc
=
U2Model
.
_init_from_config
(
configs
)
...
...
deepspeech/modules/decoder.py
浏览文件 @
f2f305cd
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
from
typing
import
List
from
typing
import
List
from
typing
import
Optional
from
typing
import
Optional
from
typing
import
Tuple
from
typing
import
Tuple
from
typing
import
Any
import
paddle
import
paddle
from
paddle
import
nn
from
paddle
import
nn
...
@@ -25,7 +26,9 @@ from deepspeech.modules.decoder_layer import DecoderLayer
...
@@ -25,7 +26,9 @@ from deepspeech.modules.decoder_layer import DecoderLayer
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.mask
import
make_non_pad_mask
from
deepspeech.modules.mask
import
make_non_pad_mask
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
make_xs_mask
from
deepspeech.modules.positionwise_feed_forward
import
PositionwiseFeedForward
from
deepspeech.modules.positionwise_feed_forward
import
PositionwiseFeedForward
from
deepspeech.decoders.scorers.score_interface
import
BatchScorerInterface
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
@@ -33,7 +36,7 @@ logger = Log(__name__).getlog()
...
@@ -33,7 +36,7 @@ logger = Log(__name__).getlog()
__all__
=
[
"TransformerDecoder"
]
__all__
=
[
"TransformerDecoder"
]
class
TransformerDecoder
(
nn
.
Layer
):
class
TransformerDecoder
(
BatchScorerInterface
,
nn
.
Layer
):
"""Base class of Transfomer decoder module.
"""Base class of Transfomer decoder module.
Args:
Args:
vocab_size: output dim
vocab_size: output dim
...
@@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer):
...
@@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer):
concat_after
:
bool
=
False
,
):
concat_after
:
bool
=
False
,
):
assert
check_argument_types
()
assert
check_argument_types
()
super
().
__init__
()
nn
.
Layer
.
__init__
(
self
)
self
.
selfattention_layer_type
=
'selfattn'
attention_dim
=
encoder_output_size
attention_dim
=
encoder_output_size
if
input_layer
==
"embed"
:
if
input_layer
==
"embed"
:
...
@@ -180,3 +184,64 @@ class TransformerDecoder(nn.Layer):
...
@@ -180,3 +184,64 @@ class TransformerDecoder(nn.Layer):
if
self
.
use_output_layer
:
if
self
.
use_output_layer
:
y
=
paddle
.
log_softmax
(
self
.
output_layer
(
y
),
axis
=-
1
)
y
=
paddle
.
log_softmax
(
self
.
output_layer
(
y
),
axis
=-
1
)
return
y
,
new_cache
return
y
,
new_cache
# beam search API (see ScorerInterface)
def score(self, ys, state, x):
    """Score the next token for a single hypothesis.

    A batch axis of size one is added to ``ys`` and ``x`` before calling
    ``forward_one_step`` and removed from the returned log-probabilities.

    Args:
        ys: prefix tokens, (ylen,)
        state: decoder cache for ``ys``; it is discarded when the
            self-attention layer type is not ``"selfattn"``.
        x: encoder feature that generates ``ys``, (xlen, n_feat)

    Returns:
        tuple: next-token scores with the batch axis squeezed out, and the
        new state returned by ``forward_one_step``.
    """
    ys_mask = subsequent_mask(len(ys)).unsqueeze(0)
    x_mask = make_xs_mask(x.unsqueeze(0))
    if self.selfattention_layer_type != "selfattn":
        # TODO(karita): implement cache
        # Reported through the module-level `logger`, like the rest of this
        # module, rather than the stdlib `logging` root logger.
        logger.warning(
            f"{self.selfattention_layer_type} does not support cached decoding."
        )
        state = None
    logp, state = self.forward_one_step(
        x.unsqueeze(0), x_mask, ys.unsqueeze(0), ys_mask, cache=state)
    return logp.squeeze(0), state
# batch beam search API (see BatchScorerInterface)
def batch_score(self,
                ys: paddle.Tensor,
                states: List[Any],
                xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
    """Score the next token for a batch of hypotheses (required).

    Args:
        ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
        states (List[Any]): Scorer states for prefix tokens.
        xs (paddle.Tensor):
            The encoder feature that generates ys (n_batch, xlen, n_feat).

    Returns:
        tuple[paddle.Tensor, List[Any]]: Tuple of
            batchfied scores for next token with shape of `(n_batch, n_vocab)`
            and next state list for ys.

    """
    num_hyps = len(ys)
    num_layers = len(self.decoders)

    # Regroup the per-hypothesis states ([batch][layer]) into one stacked
    # tensor per layer ([layer] -> batch); a leading None means no cache yet.
    cache = None
    if states[0] is not None:
        cache = [
            paddle.stack([states[hyp][layer] for hyp in range(num_hyps)])
            for layer in range(num_layers)
        ]

    # one decoding step for the whole batch
    tgt_mask = subsequent_mask(ys.size(-1)).unsqueeze(0)
    memory_mask = make_xs_mask(xs)
    logp, new_cache = self.forward_one_step(
        xs, memory_mask, ys, tgt_mask, cache=cache)

    # Split the per-layer caches back into one list per hypothesis
    # ([layer][batch] -> [batch][layer]).
    per_hyp_states = [[new_cache[layer][hyp] for layer in range(num_layers)]
                      for hyp in range(num_hyps)]
    return logp, per_hyp_states
deepspeech/modules/mask.py
浏览文件 @
f2f305cd
...
@@ -18,12 +18,24 @@ from deepspeech.utils.log import Log
...
@@ -18,12 +18,24 @@ from deepspeech.utils.log import Log
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
__all__
=
[
"make_pad_mask"
,
"make_non_pad_mask"
,
"subsequent_mask"
,
"make_
xs_mask"
,
"make_
pad_mask"
,
"make_non_pad_mask"
,
"subsequent_mask"
,
"subsequent_chunk_mask"
,
"add_optional_chunk_mask"
,
"mask_finished_scores"
,
"subsequent_chunk_mask"
,
"add_optional_chunk_mask"
,
"mask_finished_scores"
,
"mask_finished_preds"
"mask_finished_preds"
]
]
def make_xs_mask(xs: paddle.Tensor) -> paddle.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The input is compared element-wise against an all-zero frame of the same
    feature width and dtype.

    Args:
        xs (paddle.Tensor): (B, T, D), zeros for pad.
    Returns:
        paddle.Tensor: Mask Tensor indices of non-padded part. (B, T, D)
    """
    feat_dim = xs.shape[-1]
    zero_frame = paddle.zeros([1, 1, feat_dim], dtype=xs.dtype)
    return xs != zero_frame
def
make_pad_mask
(
lengths
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
def
make_pad_mask
(
lengths
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""Make mask tensor containing indices of padded part.
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
See description of make_non_pad_mask.
...
@@ -31,6 +43,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
...
@@ -31,6 +43,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,).
lengths (paddle.Tensor): Batch of lengths (B,).
Returns:
Returns:
paddle.Tensor: Mask tensor containing indices of padded part.
paddle.Tensor: Mask tensor containing indices of padded part.
(B, T)
Examples:
Examples:
>>> lengths = [5, 3, 2]
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
>>> make_pad_mask(lengths)
...
@@ -62,6 +75,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
...
@@ -62,6 +75,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
lengths (paddle.Tensor): Batch of lengths (B,).
lengths (paddle.Tensor): Batch of lengths (B,).
Returns:
Returns:
paddle.Tensor: mask tensor containing indices of padded part.
paddle.Tensor: mask tensor containing indices of padded part.
(B, T)
Examples:
Examples:
>>> lengths = [5, 3, 2]
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
>>> make_non_pad_mask(lengths)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录