PaddlePaddle / DeepSpeech

Commit 25978176
Author: Hui Zhang
Date:   Jul 07, 2021
Parent: 96c64237

    comment u2 model for easy understand

Showing 2 changed files with 31 additions and 15 deletions:

    deepspeech/models/u2.py        +27  -14
    deepspeech/modules/encoder.py   +4   -1

deepspeech/models/u2.py

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ASR Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import sys

@@ -83,7 +83,7 @@ class U2BaseModel(nn.Module):
                # cnn_module_kernel=15,
                # activation_type='swish',
                # pos_enc_layer_type='rel_pos',
                # selfattention_layer_type='rel_selfattn',
            ))
        # decoder related
        default.decoder = 'transformer'

@@ -244,8 +244,8 @@ class U2BaseModel(nn.Module):
            simulate_streaming (bool, optional): streaming or not. Defaults to False.
        Returns:
            Tuple[paddle.Tensor, paddle.Tensor]:
                encoder hiddens (B, Tmax, D),
                encoder hiddens mask (B, 1, Tmax).
        """
        # Let's assume B = batch_size

@@ -399,6 +399,7 @@ class U2BaseModel(nn.Module):
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # Let's assume B = batch_size
        # encoder_out: (B, maxlen, encoder_dim)
        # encoder_mask: (B, 1, Tmax)

@@ -410,10 +411,12 @@ class U2BaseModel(nn.Module):
        # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
        ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        pad_mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        topk_index = topk_index.masked_fill_(pad_mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps

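`make_pad_mask` and `remove_duplicates_and_blank` are imported from elsewhere in the repo and are not shown in this diff. A minimal sketch of what they do, assuming the usual CTC conventions (blank id 0); the repo's own definitions may differ in detail:

import paddle

def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Bool mask that is True at padded positions: (B,) lengths -> (B, Tmax)."""
    batch_size = int(lengths.shape[0])
    max_len = int(lengths.max())
    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)  # (Tmax,)
    seq_range = seq_range.unsqueeze(0).expand([batch_size, max_len])
    return seq_range >= lengths.unsqueeze(-1)  # True where t >= length

def remove_duplicates_and_blank(hyp, blank_id=0):
    """Collapse runs of repeated labels, then drop blanks (CTC decoding rule)."""
    new_hyp, cur = [], 0
    while cur < len(hyp):
        if hyp[cur] != blank_id:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:  # skip the whole run
            cur += 1
    return new_hyp

In the greedy path above, the pad mask forces frames beyond each utterance's length to `eos` before the per-frame argmax sequence is collapsed.
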
@@ -449,6 +452,7 @@ class U2BaseModel(nn.Module):
        batch_size = speech.shape[0]
        # For CTC prefix beam search, we only support batch_size=1
        assert batch_size == 1
        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder forward and get CTC score
        encoder_out, encoder_mask = self._forward_encoder(

@@ -458,7 +462,9 @@ class U2BaseModel(nn.Module):
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # blank_ending_score and none_blank_ending_score in ln domain
        cur_hyps = [(tuple(), (0.0, -float('inf')))]
        # 2. CTC beam search step by step
        for t in range(0, maxlen):

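The diff collapses the body of this loop. What it implements is the standard CTC prefix beam search recurrence (this model follows the WeNet-style implementation); a minimal sketch of one time step, with `logp` (per-frame ln-probabilities), `cur_hyps`, `beam_size`, `log_add` and `blank_id` taken as the names assumed from the surrounding code:

from collections import defaultdict

def ctc_prefix_beam_step(logp, cur_hyps, beam_size, log_add, blank_id=0):
    """One frame of CTC prefix beam search.

    cur_hyps: list of (prefix_tuple, (pb, pnb)); pb/pnb are ln-domain scores
    of the prefix ending in blank / non-blank.
    """
    next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
    top_k_logp, top_k_index = logp.topk(beam_size)  # first prune: top-k labels
    for s in top_k_index.tolist():
        ps = float(logp[s])
        for prefix, (pb, pnb) in cur_hyps:
            last = prefix[-1] if prefix else None
            if s == blank_id:
                # *a + - -> *a : prefix unchanged, now ends in blank
                n_pb, n_pnb = next_hyps[prefix]
                next_hyps[prefix] = (log_add([n_pb, pb + ps, pnb + ps]), n_pnb)
            elif s == last:
                # *a + a -> *a : repeated label merges into the same prefix
                n_pb, n_pnb = next_hyps[prefix]
                next_hyps[prefix] = (n_pb, log_add([n_pnb, pnb + ps]))
                # *a- + a -> *aa : a blank in between starts a new label
                n_prefix = prefix + (s, )
                n_pb, n_pnb = next_hyps[n_prefix]
                next_hyps[n_prefix] = (n_pb, log_add([n_pnb, pb + ps]))
            else:
                # *a + b -> *ab
                n_prefix = prefix + (s, )
                n_pb, n_pnb = next_hyps[n_prefix]
                next_hyps[n_prefix] = (n_pb, log_add([n_pnb, pb + ps, pnb + ps]))
    # second prune: keep the best beam_size prefixes; the sorted(...) in the
    # next hunk does exactly this
    return sorted(next_hyps.items(), key=lambda x: log_add(list(x[1])),
                  reverse=True)[:beam_size]
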
@@ -498,6 +504,7 @@ class U2BaseModel(nn.Module):
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            cur_hyps = next_hyps[:beam_size]
        hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
        return hyps, encoder_out

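`log_add` is also defined outside this diff; a minimal version consistent with how it is used here, i.e. a numerically stable ln(sum(exp(a))) over ln-domain scores:

import math
from typing import List

def log_add(args: List[float]) -> float:
    """Stable log-sum-exp, used to merge ln-domain path scores."""
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))
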
@@ -561,12 +568,13 @@ class U2BaseModel(nn.Module):
        batch_size = speech.shape[0]
        # For attention rescoring we only support batch_size=1
        assert batch_size == 1
-       # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
+       # len(hyps) = beam_size, encoder_out: (1, maxlen, encoder_dim)
        hyps, encoder_out = self._ctc_prefix_beam_search(
            speech, speech_lengths, beam_size, decoding_chunk_size,
            num_decoding_left_chunks, simulate_streaming)
        assert len(hyps) == beam_size
        hyps_pad = pad_sequence([
            paddle.to_tensor(hyp[0], place=device, dtype=paddle.long)
            for hyp in hyps

@@ -576,23 +584,28 @@ class U2BaseModel(nn.Module):
            dtype=paddle.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = paddle.ones(
            (beam_size, 1, encoder_out.size(1)), dtype=paddle.bool)
        decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad,
            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
        # ctc score in ln domain
        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
        decoder_out = decoder_out.numpy()
        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            # last decoder output token is `eos`, for the last decoder input token
            score += decoder_out[i][len(hyp[0])][self.eos]
-           # add ctc score
+           # add ctc score (which is in ln domain)
            score += hyp[1] * ctc_weight
            if score > best_score:
                best_score = score

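The loop above picks the hypothesis maximizing decoder_score + ctc_weight * ctc_prefix_score, all in ln domain. A toy example with made-up numbers to show the combination rule:

ctc_weight = 0.5
hyps = [([25, 130, 7], -3.2), ([25, 130], -2.9)]  # (token ids, ctc prefix score)
decoder_scores = [-4.1, -5.0]  # sum of per-token ln-probs plus the eos term

best_index, best_score = 0, -float('inf')
for i, (tokens, ctc_score) in enumerate(hyps):
    score = decoder_scores[i] + ctc_weight * ctc_score
    if score > best_score:
        best_index, best_score = i, score
print(best_index)  # 0: -4.1 + 0.5 * -3.2 = -5.7 beats -5.0 + 0.5 * -2.9 = -6.45
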
@@ -719,8 +732,8 @@ class U2BaseModel(nn.Module):
            feats (Tensor): audio features, (B, T, D)
            feats_lengths (Tensor): (B)
            text_feature (TextFeaturizer): text feature object.
            decoding_method (str): decoding mode, e.g.
                'attention', 'ctc_greedy_search',
                'ctc_prefix_beam_search', 'attention_rescoring'
            lang_model_path (str): lm path.
            beam_alpha (float): lm weight.

@@ -728,19 +741,19 @@ class U2BaseModel(nn.Module):
            beam_size (int): beam size for search.
            cutoff_prob (float): for pruning.
            cutoff_top_n (int): for pruning.
            num_processes (int):
            ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
            decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here.
            num_decoding_left_chunks (int, optional):
                number of left chunks for decoding. Defaults to -1.
            simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
        Raises:
            ValueError: when decoding_method is not supported.
        Returns:
            List[List[int]]: transcripts.
        """

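A hypothetical call to the enclosing method (named `decode` here for illustration; the name and any further parameters fall outside the diff window, and only arguments visible in this docstring are shown):

hyps = model.decode(
    feats, feats_lengths, text_feature,
    decoding_method='attention_rescoring',
    lang_model_path='', beam_alpha=0.0,
    beam_size=10, cutoff_prob=1.0, cutoff_top_n=0, num_processes=1,
    ctc_weight=0.5,
    decoding_chunk_size=-1,       # <0: decode with the full chunk
    num_decoding_left_chunks=-1,  # -1: use all left chunks
    simulate_streaming=False)     # returns List[List[int]] token transcripts
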
@@ -821,7 +834,7 @@ class U2Model(U2BaseModel):
            ValueError: raised when using an unsupported encoder type.
        Returns:
            int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
        """
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],

deepspeech/modules/encoder.py

@@ -219,11 +219,14 @@ class BaseEncoder(nn.Layer):
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+       # xs=(B, T, D), pos_emb=(B=1, T, D)
        if subsampling_cache is not None:
            cache_size = subsampling_cache.size(1)  # T
            xs = paddle.cat((subsampling_cache, xs), dim=1)
        else:
            cache_size = 0
+       # only used when using `RelPositionMultiHeadedAttention`
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_size, size=xs.size(1))

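A worked example (hypothetical numbers) of the offset bookkeeping above: the position encoding must cover the cached frames plus the new chunk, which is why it starts at offset - cache_size.

offset = 16      # subsampled frames already consumed before this chunk
cache_size = 4   # frames carried in subsampling_cache
chunk = 8        # new subsampled frames in xs before concatenation
# After paddle.cat, xs holds cache_size + chunk = 12 frames corresponding to
# absolute positions [offset - cache_size, offset + chunk) = [12, 24), hence
# position_encoding(offset=offset - cache_size, size=xs.size(1)).
positions = list(range(offset - cache_size, offset + chunk))
assert len(positions) == cache_size + chunk
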
@@ -237,7 +240,7 @@ class BaseEncoder(nn.Layer):
        # Real mask for transformer/conformer layers
        masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
-       masks = masks.unsqueeze(1)  #[B=1, C=1, T]
+       masks = masks.unsqueeze(1)  #[B=1, L'=1, T]
        r_elayers_output_cache = []
        r_conformer_cnn_cache = []
        for i, layer in enumerate(self.encoders):

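A quick shape check of the mask construction, assuming the (B, 1, T) attention-mask convention the transformer/conformer layers expect:

import paddle

T = 8
masks = paddle.ones([1, T], dtype=paddle.bool)  # one utterance, all frames valid
masks = masks.unsqueeze(1)                      # (1, T) -> (1, 1, T)
assert masks.shape == [1, 1, T]
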