PaddlePaddle / DeepSpeech
Commit 498104b0
Authored on Apr 02, 2021 by Hui Zhang

refactor data feed order

Parent: f5477d31
Showing 11 changed files with 141 additions and 87 deletions (+141 −87)
.notebook/dataloader.ipynb                  +2   −2
.notebook/train_test.ipynb                  +5   −5
deepspeech/__init__.py                      +25  −0
deepspeech/exps/deepspeech2/bin/tune.py     +1   −1
deepspeech/io/collator.py                   +6   −7
deepspeech/models/deepspeech2.py            +2   −2
deepspeech/models/u2.py                     +42  −16
deepspeech/modules/decoder.py               +2   −2
deepspeech/modules/loss.py                  +5   −1
deepspeech/modules/mask.py                  +46  −46
tests/network_test.py                       +5   −5
.notebook/dataloader.ipynb
@@ -338,7 +338,7 @@
     }
    ],
    "source": [
-    "for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
+    "for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
     "    print('test', text)\n",
     "    print(\"test raw\", ''.join( chr(i) for i in text[0][:int(text_len[0])] ))\n",
     "    print(\"test raw\", ''.join( chr(i) for i in text[-1][:int(text_len[-1])] ))\n",
@@ -386,4 +386,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
\ No newline at end of file
.notebook/train_test.ipynb
@@ -249,7 +249,7 @@
     }
    ],
    "source": [
-    " for idx, (audio, text, audio_len, text_len) in enumerate(batch_reader()):\n",
+    " for idx, (audio, audio_len, text, text_len) in enumerate(batch_reader()):\n",
     "    print('test', text)\n",
     "    print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[0]))\n",
     "    print(\"test raw\", ''.join(batch_reader.dataset.vocab_list[i] for i in text[-1]))\n",
@@ -835,7 +835,7 @@
     "\n",
     "        return logits, probs, audio_len\n",
     "\n",
-    "    def forward(self, audio, text, audio_len, text_len):\n",
+    "    def forward(self, audio, audio_len, text, text_len):\n",
     "        \"\"\"\n",
     "        audio: shape [B, D, T]\n",
     "        text: shape [B, T]\n",
@@ -877,10 +877,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "audio, text, audio_len, text_len = None, None, None, None\n",
+    "audio, audio_len, text, text_len = None, None, None, None\n",
     "\n",
     "for idx, inputs in enumerate(batch_reader):\n",
-    "    audio, text, audio_len, text_len = inputs\n",
+    "    audio, audio_len, text, text_len = inputs\n",
     "#    print(idx)\n",
     "#    print('a', audio.shape, audio.place)\n",
     "#    print('t', text)\n",
@@ -960,7 +960,7 @@
     }
    ],
    "source": [
-    "outputs = dp_model(audio, text, audio_len, text_len)\n",
+    "outputs = dp_model(audio, audio_len, text, text_len)\n",
     "logits, _, logits_len = outputs\n",
     "print('logits len', logits_len)\n",
     "loss = loss_fn.forward(logits, text, logits_len, text_len)\n",
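The notebook cells above only change the unpacking order of the batch tuple from (audio, text, audio_len, text_len) to (audio, audio_len, text, text_len). A minimal sketch of the new feed order, with `fake_batch_reader` and `fake_model` as hypothetical stand-ins for the project's real reader and model:

# Minimal sketch of the new (audio, audio_len, text, text_len) feed order.
# fake_batch_reader / fake_model are made-up stand-ins, not project code.
import numpy as np

def fake_batch_reader():
    audio = np.zeros((2, 5, 3), dtype=np.float32)              # (B, Tmax, D)
    audio_len = np.array([5, 3], dtype=np.int64)               # (B,)
    text = np.array([[7, 8, 9], [7, 8, -1]], dtype=np.int32)   # (B, Umax)
    text_len = np.array([3, 2], dtype=np.int64)                # (B,)
    yield audio, audio_len, text, text_len

def fake_model(audio, audio_len, text, text_len):
    # lengths now follow the tensor they describe
    return audio.sum(), int(text_len.sum())

for audio, audio_len, text, text_len in fake_batch_reader():
    logits_sum, num_tokens = fake_model(audio, audio_len, text, text_len)
    print(logits_sum, num_tokens)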
deepspeech/__init__.py
@@ -222,6 +222,31 @@ if not hasattr(paddle.Tensor, 'relu'):
     logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
     setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)


+def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
+    return x.astype(other.dtype)
+
+
+if not hasattr(paddle.Tensor, 'type_as'):
+    logger.warn("register user type_as to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'type_as', type_as)
+
+
+def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
+    assert len(args) == 1
+    if isinstance(args[0], str):  # dtype
+        return x.astype(args[0])
+    elif isinstance(args[0], paddle.Tensor):  # Tensor
+        return x.astype(args[0].dtype)
+    else:  # Device
+        return x
+
+
+if not hasattr(paddle.Tensor, 'to'):
+    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'to', to)
+
+
 ########### hcak paddle.nn.functional #############
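The hunk above registers `type_as` and `to` onto `paddle.Tensor` so Torch-style call sites keep working. A small sketch of the same monkey-patch pattern, applied to a toy `ToyTensor` class (an assumption of this note, not project code) so it runs without Paddle installed:

# Sketch of the setattr-based monkey patch, on a toy class instead of
# paddle.Tensor. ToyTensor is a made-up stand-in for illustration only.
class ToyTensor:
    def __init__(self, data, dtype):
        self.data, self.dtype = data, dtype

    def astype(self, dtype):
        return ToyTensor(self.data, dtype)

def type_as(x, other):
    # same idea as the patched paddle.Tensor.type_as: cast to other's dtype
    return x.astype(other.dtype)

if not hasattr(ToyTensor, 'type_as'):
    setattr(ToyTensor, 'type_as', type_as)

a = ToyTensor([1, 2, 3], 'int64')
b = ToyTensor([0.5], 'float32')
print(a.type_as(b).dtype)  # float32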
deepspeech/exps/deepspeech2/bin/tune.py
@@ -103,7 +103,7 @@ def tune(config, args):
             trans.append(''.join([chr(i) for i in ids]))
         return trans

-    audio, text, audio_len, text_len = infer_data
+    audio, audio_len, text, text_len = infer_data
     target_transcripts = ordid2token(text, text_len)
     num_ins += audio.shape[0]
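The tuner now unpacks `infer_data` in the new order and then turns padded id sequences back into strings. A hedged sketch of that id-to-text step, with `ordid2token` reconstructed from the surrounding context lines and made-up sample ids:

# Sketch of the ordid2token-style conversion: each id is a character code,
# text_len trims the padding. Sample values below are made up.
import numpy as np

def ordid2token(texts, texts_len):
    trans = []
    for text, n in zip(texts, texts_len):
        ids = text[:int(n)]
        trans.append(''.join([chr(i) for i in ids]))
    return trans

text = np.array([[104, 105, 0], [111, 107, 33]], dtype=np.int32)  # padded ids
text_len = np.array([2, 3], dtype=np.int64)
print(ordid2token(text, text_len))  # ['hi', 'ok!']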
deepspeech/io/collator.py
@@ -17,6 +17,7 @@ import numpy as np
 from collections import namedtuple

 from deepspeech.io.utility import pad_sequence
+from deepspeech.utils.tensor_utils import IGNORE_ID

 logger = logging.getLogger(__name__)
@@ -29,10 +30,6 @@ class SpeechCollator():
             Padding audio features with zeros to make them have the same shape (or
             a user-defined shape) within one bach.
-            If ``padding_to`` is -1, the maximun shape in the batch will be used
-            as the target shape for padding. Otherwise, `padding_to` will be the
-            target shape (only refers to the second axis).
-            if ``is_training`` is True, text is token ids else is raw string.
         """
         self._is_training = is_training
@@ -48,8 +45,8 @@ class SpeechCollator():
         Returns:
             tuple(audio, text, audio_lens, text_lens): batched data.
                 audio : (B, Tmax, D)
-                text : (B, Umax)
                 audio_lens: (B)
+                text : (B, Umax)
                 text_lens: (B)
         """
         audios = []
@@ -76,7 +73,9 @@ class SpeechCollator():
         padded_audios = pad_sequence(
             audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
-        padded_texts = pad_sequence(texts, padding_value=-1).astype(np.int32)
         audio_lens = np.array(audio_lens).astype(np.int64)
+        # (TODO:Hui Zhang) ctc loss does not support int64 labels
+        padded_texts = pad_sequence(
+            texts, padding_value=IGNORE_ID).astype(np.int32)
         text_lens = np.array(text_lens).astype(np.int64)
-        return padded_audios, padded_texts, audio_lens, text_lens
+        return padded_audios, audio_lens, padded_texts, text_lens
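The collator pads each field and now returns (padded_audios, audio_lens, padded_texts, text_lens). A simplified, self-contained sketch of that collate step; the local `pad_sequence` and `IGNORE_ID` here are assumptions standing in for deepspeech.io.utility.pad_sequence and deepspeech.utils.tensor_utils.IGNORE_ID:

# Simplified sketch of the collate step: pad variable-length features and
# labels, then return them in the new (audio, audio_len, text, text_len) order.
import numpy as np

IGNORE_ID = -1  # assumed padding label that the loss is expected to ignore

def pad_sequence(seqs, padding_value=0.0):
    batch = len(seqs)
    max_len = max(len(s) for s in seqs)
    tail_shape = np.asarray(seqs[0]).shape[1:]
    out = np.full((batch, max_len) + tail_shape, padding_value, dtype=np.float64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out

def collate(audios, texts):
    audio_lens = np.array([len(a) for a in audios]).astype(np.int64)
    text_lens = np.array([len(t) for t in texts]).astype(np.int64)
    padded_audios = pad_sequence(audios, padding_value=0.0).astype(np.float32)
    # labels padded with IGNORE_ID and kept int32 (ctc loss restriction noted above)
    padded_texts = pad_sequence(texts, padding_value=IGNORE_ID).astype(np.int32)
    return padded_audios, audio_lens, padded_texts, text_lens

audios = [np.ones((4, 2)), np.ones((2, 2))]
texts = [[5, 6, 7], [5]]
audio, audio_len, text, text_len = collate(audios, texts)
print(audio.shape, audio_len, text, text_len)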
deepspeech/models/deepspeech2.py
@@ -168,13 +168,13 @@ class DeepSpeech2Model(nn.Layer):
             dropout_rate=0.0,
             reduction=True)

-    def forward(self, audio, text, audio_len, text_len):
+    def forward(self, audio, audio_len, text, text_len):
         """Compute Model loss

         Args:
             audio (Tenosr): [B, T, D]
-            text (Tensor): [B, U]
             audio_len (Tensor): [B]
+            text (Tensor): [B, U]
             text_len (Tensor): [B]

         Returns:
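Because only the positional order of `forward` changed, every call site has to be updated in lockstep; passing the arguments by keyword is one way to make such refactors safe. A toy illustration (ToyModel below is a made-up stand-in that mimics only the new signature):

# Toy illustration of why keyword arguments make the reordering safer.
class ToyModel:
    def forward(self, audio, audio_len, text, text_len):
        return len(audio), int(sum(text_len))

model = ToyModel()
audio, audio_len = [[0.1, 0.2]], [2]
text, text_len = [[3, 4]], [2]

# a positional call must follow the new order ...
print(model.forward(audio, audio_len, text, text_len))
# ... while a keyword call is immune to future reordering
print(model.forward(audio=audio, text=text, audio_len=audio_len, text_len=text_len))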
deepspeech/models/u2.py
@@ -28,7 +28,12 @@ from paddle import jit
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 from paddle.nn.utils.rnn import pad_sequence

+from deepspeech.modules.mask import make_pad_mask
+from deepspeech.modules.mask import mask_finished_preds
+from deepspeech.modules.mask import mask_finished_scores
+from deepspeech.modules.mask import subsequent_mask
 from deepspeech.modules.cmvn import GlobalCMVN
 from deepspeech.modules.encoder import ConformerEncoder
@@ -36,10 +41,6 @@ from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.modules.ctc import CTCDecoder
 from deepspeech.modules.decoder import TransformerDecoder
 from deepspeech.modules.label_smoothing_loss import LabelSmoothingLoss
-from deepspeech.modules.mask import make_pad_mask
-from deepspeech.modules.mask import mask_finished_preds
-from deepspeech.modules.mask import mask_finished_scores
-from deepspeech.modules.mask import subsequent_mask
 from deepspeech.utils import checkpoint
 from deepspeech.utils import layer_tools
@@ -101,6 +102,8 @@ class U2Model(nn.Module):
             speech_lengths: (Batch, )
             text: (Batch, Length)
             text_lengths: (Batch,)
+        Returns:
+            total_loss, attention_loss, ctc_loss
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -109,21 +112,19 @@ class U2Model(nn.Module):
             text.shape, text_lengths.shape)
         # 1. Encoder
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]

         # 2a. Attention-decoder branch
+        loss_att = None
         if self.ctc_weight != 1.0:
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                     text, text_lengths)
-        else:
-            loss_att = None

         # 2b. CTC branch
+        loss_ctc = None
         if self.ctc_weight != 0.0:
             loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                 text_lengths)
-        else:
-            loss_ctc = None

         if loss_ctc is None:
             loss = loss_att
@@ -139,6 +140,17 @@ class U2Model(nn.Module):
                        encoder_mask: paddle.Tensor,
                        ys_pad: paddle.Tensor,
                        ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
+        """Calc attention loss.
+
+        Args:
+            encoder_out (paddle.Tensor): [B, Tmax, D]
+            encoder_mask (paddle.Tensor): [B, 1, Tmax]
+            ys_pad (paddle.Tensor): [B, Umax]
+            ys_pad_lens (paddle.Tensor): [B]
+
+        Returns:
+            Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
+        """
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                             self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
@@ -163,6 +175,20 @@ class U2Model(nn.Module):
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Encoder pass.
+
+        Args:
+            speech (paddle.Tensor): [B, Tmax, D]
+            speech_lengths (paddle.Tensor): [B]
+            decoding_chunk_size (int, optional): chuck size. Defaults to -1.
+            num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1.
+            simulate_streaming (bool, optional): streaming or not. Defaults to False.
+
+        Returns:
+            Tuple[paddle.Tensor, paddle.Tensor]:
+                encoder hiddens (B, Tmax, D),
+                encoder hiddens mask (B, 1, Tmax).
+        """
         # Let's assume B = batch_size
         # 1. Encoder
         if simulate_streaming and decoding_chunk_size > 0:
@@ -205,7 +231,7 @@ class U2Model(nn.Module):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
-        device = speech.device
+        device = speech.place
         batch_size = speech.shape[0]
         # Let's assume B = batch_size and N = beam_size
@@ -223,14 +249,14 @@ class U2Model(nn.Module):
             1, beam_size, 1, 1).view(running_size, 1, maxlen)  # (B*N, 1, max_len)
-        hyps = torch.ones(
-            [running_size, 1], dtype=torch.long, device=device).fill_(self.sos)  # (B*N, 1)
-        scores = paddle.tensor(
-            [0.0] + [-float('inf')] * (beam_size - 1), dtype=torch.float)
+        hyps = paddle.ones(
+            [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
+        # log scale score
+        scores = paddle.to_tensor(
+            [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
         scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(device)  # (B*N, 1)
-        end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
+        end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
         cache: Optional[List[paddle.Tensor]] = None
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
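U2Model.forward computes an attention loss and a CTC loss and then mixes them with ctc_weight; the pre-initialized None values cover the pure-CTC (ctc_weight == 1.0) and pure-attention (ctc_weight == 0.0) corner cases. A paddle-free sketch of that combination; only the "loss_ctc is None" branch is visible in the hunk above, so the weighted sum here follows the usual hybrid CTC/attention formulation and the numbers are made up:

# Sketch of the hybrid-loss combination, with plain floats standing in for
# paddle tensors. Loss values and weight are illustrative only.
def combine_losses(loss_att, loss_ctc, ctc_weight):
    if loss_ctc is None:
        return loss_att
    if loss_att is None:
        return loss_ctc
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att

print(combine_losses(2.0, 4.0, ctc_weight=0.3))   # 0.3*4.0 + 0.7*2.0 = 2.6
print(combine_losses(2.0, None, ctc_weight=0.0))  # attention-only
print(combine_losses(None, 4.0, ctc_weight=1.0))  # CTC-only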
deepspeech/modules/decoder.py
@@ -152,12 +152,12 @@ class TransformerDecoder(nn.Module):
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoded memory mask, (batch, 1, maxlen_in)
             tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
                 dtype=paddle.bool
             cache: cached output list of (batch, max_time_out-1, size)
         Returns:
             y, cache: NN output value and cache per `self.decoders`.
-            y.shape` is (batch, maxlen_out, token)
+            y.shape` is (batch, token)
         """
         x, _ = self.embed(tgt)
         new_cache = []
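The docstring fix reflects that tgt_mask is a square per-step mask of shape (batch, maxlen_out, maxlen_out): a non-pad mask for the target tokens combined with a lower-triangular "subsequent" mask so position i only attends to positions <= i. A numpy sketch of how such a mask is typically built; these helpers are local stand-ins for the project's make_non_pad_mask and subsequent_mask:

# Sketch of a (batch, maxlen_out, maxlen_out) decoder self-attention mask:
# non-pad mask AND lower-triangular subsequent mask (numpy stand-ins).
import numpy as np

def make_non_pad_mask(lengths, max_len):
    steps = np.arange(max_len)                             # (T,)
    return steps[None, :] < np.asarray(lengths)[:, None]   # (B, T) bool

def subsequent_mask(size):
    return np.tril(np.ones((size, size), dtype=bool))      # (T, T) bool

lengths, max_len = [3, 2], 3
tgt_mask = make_non_pad_mask(lengths, max_len)[:, None, :] & subsequent_mask(max_len)
print(tgt_mask.shape)        # (2, 3, 3)
print(tgt_mask.astype(int))  # 1 where attention is allowed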
deepspeech/modules/loss.py
@@ -88,7 +88,10 @@ class LabelSmoothingLoss(nn.Layer):
             size (int): the number of class
             padding_idx (int): padding class id which will be ignored for loss
             smoothing (float): smoothing rate (0.0 means the conventional CE)
-            normalize_length (bool): True, normalize loss by sequence length; False, normalize loss by batch size. Defaults to False.
+            normalize_length (bool):
+                True, normalize loss by sequence length;
+                False, normalize loss by batch size.
+                Defaults to False.
         """
         super().__init__()
         self.size = size
@@ -103,6 +106,7 @@ class LabelSmoothingLoss(nn.Layer):
         The model outputs and data labels tensors are flatten to
         (batch*seqlen, class) shape and a mask is applied to the
         padding part which should not be calculated for loss.
+
         Args:
             x (paddle.Tensor): prediction (batch, seqlen, class)
             target (paddle.Tensor):
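The reformatted docstring spells out what normalize_length controls: whether the summed label-smoothing loss is divided by the number of non-padding tokens or by the batch size. A small numeric illustration of the two denominators; the numbers below are made up:

# Illustration of the normalize_length flag: same summed loss, two denominators.
total_loss = 12.0          # sum of per-token losses over the whole batch (made up)
batch_size = 4
num_real_tokens = 24       # non-padding tokens in the batch (made up)

loss_by_length = total_loss / num_real_tokens   # normalize_length=True  -> 0.5
loss_by_batch = total_loss / batch_size         # normalize_length=False -> 3.0
print(loss_by_length, loss_by_batch)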
deepspeech/modules/mask.py
@@ -50,6 +50,52 @@ def sequence_mask(x_len, max_len=None, dtype='float32'):
     return mask


+def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of padded part.
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+    Returns:
+        paddle.Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0 ,0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = int(lengths.shape[0])
+    max_len = int(lengths.max())
+    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
+    """Make mask tensor containing indices of non-padded part.
+    The sequences in a batch may have different lengths. To enable
+    batch computing, padding is need to make all sequence in same
+    size. To avoid the padding part pass value to context dependent
+    block such as attention or convolution , this padding part is
+    masked.
+    This pad_mask is used in both encoder and decoder.
+    1 for non-padded part and 0 for padded part.
+
+    Args:
+        lengths (paddle.Tensor): Batch of lengths (B,).
+    Returns:
+        paddle.Tensor: mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1 ,1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+    """
+    return ~make_pad_mask(lengths)
+
+
 def subsequent_mask(size: int) -> paddle.Tensor:
     """Create mask for subsequent steps (size, size).

     This mask is used only in decoder which works in an auto-regressive mode.
@@ -170,52 +216,6 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
     return chunk_masks


-def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
-    """Make mask tensor containing indices of padded part.
-    See description of make_non_pad_mask.
-
-    Args:
-        lengths (paddle.Tensor): Batch of lengths (B,).
-    Returns:
-        paddle.Tensor: Mask tensor containing indices of padded part.
-
-    Examples:
-        >>> lengths = [5, 3, 2]
-        >>> make_pad_mask(lengths)
-        masks = [[0, 0, 0, 0 ,0],
-                 [0, 0, 0, 1, 1],
-                 [0, 0, 1, 1, 1]]
-    """
-    batch_size = int(lengths.shape[0])
-    max_len = int(lengths.max())
-    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
-    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
-    seq_length_expand = lengths.unsqueeze(-1)
-    mask = seq_range_expand >= seq_length_expand
-    return mask
-
-
-def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
-    """Make mask tensor containing indices of non-padded part.
-    The sequences in a batch may have different lengths. To enable
-    batch computing, padding is need to make all sequence in same
-    size. To avoid the padding part pass value to context dependent
-    block such as attention or convolution , this padding part is
-    masked.
-    This pad_mask is used in both encoder and decoder.
-    1 for non-padded part and 0 for padded part.
-
-    Args:
-        lengths (paddle.Tensor): Batch of lengths (B,).
-    Returns:
-        paddle.Tensor: mask tensor containing indices of padded part.
-
-    Examples:
-        >>> lengths = [5, 3, 2]
-        >>> make_non_pad_mask(lengths)
-        masks = [[1, 1, 1, 1 ,1],
-                 [1, 1, 1, 0, 0],
-                 [1, 1, 0, 0, 0]]
-    """
-    return ~make_pad_mask(lengths)
-
-
 def mask_finished_scores(score: paddle.Tensor,
                          flag: paddle.Tensor) -> paddle.Tensor:
     """
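make_pad_mask and make_non_pad_mask only move within the file; their behaviour is unchanged. A numpy re-implementation (a stand-in for the paddle version above, not project code) reproduces the docstring examples:

# numpy stand-in for make_pad_mask / make_non_pad_mask, reproducing the
# docstring examples (1 marks padded resp. non-padded positions).
import numpy as np

def make_pad_mask(lengths):
    lengths = np.asarray(lengths)
    max_len = int(lengths.max())
    seq_range = np.arange(max_len)                   # (Tmax,)
    return seq_range[None, :] >= lengths[:, None]    # (B, Tmax) bool

def make_non_pad_mask(lengths):
    return ~make_pad_mask(lengths)

lengths = [5, 3, 2]
print(make_pad_mask(lengths).astype(int))
# [[0 0 0 0 0]
#  [0 0 0 1 1]
#  [0 0 1 1 1]]
print(make_non_pad_mask(lengths).astype(int))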
tests/network_test.py
@@ -46,7 +46,7 @@ if __name__ == '__main__':
         rnn_size=1024,
         use_gru=False,
         share_rnn_weights=False, )
-    logits, probs, logits_len = model(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")
@@ -58,7 +58,7 @@ if __name__ == '__main__':
         rnn_size=1024,
         use_gru=True,
         share_rnn_weights=False, )
-    logits, probs, logits_len = model2(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model2(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")
@@ -70,7 +70,7 @@ if __name__ == '__main__':
         rnn_size=1024,
         use_gru=False,
         share_rnn_weights=True, )
-    logits, probs, logits_len = model3(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model3(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")
@@ -82,7 +82,7 @@ if __name__ == '__main__':
         rnn_size=1024,
         use_gru=True,
         share_rnn_weights=True, )
-    logits, probs, logits_len = model4(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model4(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")
@@ -94,6 +94,6 @@ if __name__ == '__main__':
         rnn_size=1024,
         use_gru=False,
         share_rnn_weights=False, )
-    logits, probs, logits_len = model5(audio, text, audio_len, text_len)
+    logits, probs, logits_len = model5(audio, audio_len, text, text_len)
     print('probs.shape', probs.shape)
     print("-----------------")