Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
220c9443
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
220c9443
编写于
4月 07, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
5e7e582d
变更
11
显示空白变更内容
内联
并排
Showing
11 changed files
with
75 additions
and
30 deletions
+75
-30
deepspeech/__init__.py
deepspeech/__init__.py
+4
-2
deepspeech/models/deepspeech2.py
deepspeech/models/deepspeech2.py
+1
-1
deepspeech/models/u2.py
deepspeech/models/u2.py
+4
-1
deepspeech/modules/conformer_convolution.py
deepspeech/modules/conformer_convolution.py
+1
-1
deepspeech/modules/ctc.py
deepspeech/modules/ctc.py
+2
-2
deepspeech/modules/decoder.py
deepspeech/modules/decoder.py
+7
-3
deepspeech/modules/encoder.py
deepspeech/modules/encoder.py
+5
-2
deepspeech/modules/loss.py
deepspeech/modules/loss.py
+4
-3
deepspeech/modules/mask.py
deepspeech/modules/mask.py
+7
-1
deepspeech/utils/tensor_utils.py
deepspeech/utils/tensor_utils.py
+30
-10
tests/u2_model_test.py
tests/u2_model_test.py
+10
-4
未找到文件。
deepspeech/__init__.py
浏览文件 @
220c9443
...
...
@@ -168,6 +168,8 @@ if not hasattr(paddle.Tensor, 'new_full'):
def
eq
(
xs
:
paddle
.
Tensor
,
ys
:
Union
[
paddle
.
Tensor
,
float
])
->
paddle
.
Tensor
:
if
convert_dtype_to_string
(
xs
.
dtype
)
==
paddle
.
bool
:
xs
=
xs
.
astype
(
paddle
.
int
)
return
xs
.
equal
(
paddle
.
to_tensor
(
ys
,
dtype
=
convert_dtype_to_string
(
xs
.
dtype
),
place
=
xs
.
place
))
...
...
@@ -262,7 +264,7 @@ def masked_fill_(xs: paddle.Tensor,
mask
=
mask
.
broadcast_to
(
bshape
)
trues
=
paddle
.
ones_like
(
xs
)
*
value
ret
=
paddle
.
where
(
mask
,
trues
,
xs
)
paddle
.
assign
(
ret
,
output
=
xs
)
paddle
.
assign
(
ret
.
detach
()
,
output
=
xs
)
if
not
hasattr
(
paddle
.
Tensor
,
'masked_fill_'
):
...
...
@@ -273,7 +275,7 @@ if not hasattr(paddle.Tensor, 'masked_fill_'):
def
fill_
(
xs
:
paddle
.
Tensor
,
value
:
Union
[
float
,
int
]):
val
=
paddle
.
full_like
(
xs
,
value
)
paddle
.
assign
(
val
,
output
=
xs
)
paddle
.
assign
(
val
.
detach
()
,
output
=
xs
)
if
not
hasattr
(
paddle
.
Tensor
,
'fill_'
):
...
...
deepspeech/models/deepspeech2.py
浏览文件 @
220c9443
...
...
@@ -162,8 +162,8 @@ class DeepSpeech2Model(nn.Layer):
assert
(
self
.
encoder
.
output_size
==
rnn_size
*
2
)
self
.
decoder
=
CTCDecoder
(
enc_n_units
=
self
.
encoder
.
output_size
,
odim
=
dict_size
+
1
,
# <blank> is appended after vocab
enc_n_units
=
self
.
encoder
.
output_size
,
blank_id
=
dict_size
,
# last token is <blank>
dropout_rate
=
0.0
,
reduction
=
True
)
...
...
deepspeech/models/u2.py
浏览文件 @
220c9443
...
...
@@ -112,7 +112,10 @@ class U2Model(nn.Module):
text
.
shape
,
text_lengths
.
shape
)
# 1. Encoder
encoder_out
,
encoder_mask
=
self
.
encoder
(
speech
,
speech_lengths
)
encoder_out_lens
=
encoder_mask
.
squeeze
(
1
).
sum
(
1
)
#[B, 1, T] -> [B]
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens
=
encoder_mask
.
squeeze
(
1
).
astype
(
paddle
.
int64
).
sum
(
1
)
#[B, 1, T] -> [B]
# 2a. Attention-decoder branch
loss_att
=
None
...
...
deepspeech/modules/conformer_convolution.py
浏览文件 @
220c9443
...
...
@@ -139,7 +139,7 @@ class ConvolutionModule(nn.Layer):
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache
=
paddle
.
to_tensor
([
0.0
],
dtype
=
x
.
dtype
,
place
=
x
.
plac
e
)
new_cache
=
paddle
.
zeros
([
1
],
dtype
=
x
.
dtyp
e
)
# GLU mechanism
x
=
self
.
pointwise_conv1
(
x
)
# (batch, 2*channel, dim)
...
...
deepspeech/modules/ctc.py
浏览文件 @
220c9443
...
...
@@ -34,16 +34,16 @@ __all__ = ['CTCDecoder']
class
CTCDecoder
(
nn
.
Layer
):
def
__init__
(
self
,
enc_n_units
,
odim
,
enc_n_units
,
blank_id
=
0
,
dropout_rate
:
float
=
0.0
,
reduction
:
bool
=
True
):
"""CTC decoder
Args:
odim ([int]): text vocabulary size
enc_n_units ([int]): encoder output dimension
vocab_size ([int]): text vocabulary size
dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): reduce the CTC loss into a scalar
"""
...
...
deepspeech/modules/decoder.py
浏览文件 @
220c9443
...
...
@@ -26,7 +26,7 @@ from deepspeech.modules.decoder_layer import DecoderLayer
from
deepspeech.modules.embedding
import
PositionalEncoding
from
deepspeech.modules.positionwise_feed_forward
import
PositionwiseFeedForward
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
make_pad_mask
from
deepspeech.modules.mask
import
make_
non_
pad_mask
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -124,7 +124,9 @@ class TransformerDecoder(nn.Module):
# m: (1, L, L)
m
=
subsequent_mask
(
tgt_mask
.
size
(
-
1
)).
unsqueeze
(
0
)
# tgt_mask: (B, L, L)
tgt_mask
=
tgt_mask
&
m
# TODO(Hui Zhang): not support & for tensor
#tgt_mask = tgt_mask & m
tgt_mask
=
tgt_mask
.
logical_and
(
m
)
x
,
_
=
self
.
embed
(
tgt
)
for
layer
in
self
.
decoders
:
...
...
@@ -135,7 +137,9 @@ class TransformerDecoder(nn.Module):
if
self
.
use_output_layer
:
x
=
self
.
output_layer
(
x
)
olens
=
tgt_mask
.
sum
(
1
)
#TODO(Hui Zhang): reduce_sum not support bool type
#olens = tgt_mask.sum(1)
olens
=
tgt_mask
.
astype
(
paddle
.
int
).
sum
(
1
)
return
x
,
olens
def
forward_one_step
(
...
...
deepspeech/modules/encoder.py
浏览文件 @
220c9443
...
...
@@ -155,12 +155,15 @@ class BaseEncoder(nn.Layer):
encoder output tensor, lens and mask
"""
masks
=
make_non_pad_mask
(
xs_lens
).
unsqueeze
(
1
)
# (B, 1, L)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad
=
masks
.
logical_not
()
if
self
.
global_cmvn
is
not
None
:
xs
=
self
.
global_cmvn
(
xs
)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs
,
pos_emb
,
masks
=
self
.
embed
(
xs
,
masks
.
type_as
(
xs
),
offset
=
0
)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
masks
=
masks
.
astype
(
paddle
.
bool
)
#TODO(Hui Zhang): mask_pad = ~masks
mask_pad
=
masks
.
logical_not
()
chunk_masks
=
add_optional_chunk_mask
(
xs
,
masks
,
self
.
use_dynamic_chunk
,
self
.
use_dynamic_left_chunk
,
decoding_chunk_size
,
self
.
static_chunk_size
,
...
...
deepspeech/modules/loss.py
浏览文件 @
220c9443
...
...
@@ -117,13 +117,12 @@ class LabelSmoothingLoss(nn.Layer):
B
,
T
,
D
=
paddle
.
shape
(
x
)
assert
D
==
self
.
size
x
=
x
.
reshape
((
-
1
,
self
.
size
))
target
=
target
.
reshape
(
-
1
)
target
=
target
.
reshape
(
[
-
1
]
)
# use zeros_like instead of torch.no_grad() for true_dist,
# since no_grad() can not be exported by JIT
true_dist
=
paddle
.
full_like
(
x
,
self
.
smoothing
/
(
self
.
size
-
1
))
ignore
=
target
==
self
.
padding_idx
# (B,)
ignore
=
ignore
.
cast
(
target
.
dtype
)
#target = target * (1 - ignore) # avoid -1 index
target
=
target
.
masked_fill
(
ignore
,
0
)
# avoid -1 index
...
...
@@ -131,7 +130,9 @@ class LabelSmoothingLoss(nn.Layer):
kl
=
self
.
criterion
(
F
.
log_softmax
(
x
,
axis
=
1
),
true_dist
)
total
=
len
(
target
)
-
int
(
ignore
.
sum
())
#TODO(Hui Zhang): sum not support bool type
#total = len(target) - int(ignore.sum())
total
=
len
(
target
)
-
int
(
ignore
.
type_as
(
target
).
sum
())
denom
=
total
if
self
.
normalize_length
else
B
#numer = (kl * (1 - ignore)).sum()
numer
=
kl
.
masked_fill
(
ignore
.
unsqueeze
(
1
),
0
).
sum
()
...
...
deepspeech/modules/mask.py
浏览文件 @
220c9443
...
...
@@ -97,6 +97,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
"""
#TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~
return
make_pad_mask
(
lengths
).
logical_not
()
...
...
@@ -119,7 +120,12 @@ def subsequent_mask(size: int) -> paddle.Tensor:
[1, 1, 1]]
"""
ret
=
paddle
.
ones
([
size
,
size
],
dtype
=
paddle
.
bool
)
return
paddle
.
tril
(
ret
)
#TODO(Hui Zhang): tril not support bool
#return paddle.tril(ret)
ret
=
ret
.
astype
(
paddle
.
float
)
ret
=
paddle
.
tril
(
ret
)
ret
=
ret
.
astype
(
paddle
.
bool
)
return
ret
def
subsequent_chunk_mask
(
...
...
deepspeech/utils/tensor_utils.py
浏览文件 @
220c9443
...
...
@@ -115,14 +115,28 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
_sos
=
paddle
.
to_tensor
(
[
sos
],
dtype
=
paddle
.
long
,
stop_gradient
=
True
,
place
=
ys_pad
.
place
)
_eos
=
paddle
.
to_tensor
(
[
eos
],
dtype
=
paddle
.
long
,
stop_gradient
=
True
,
place
=
ys_pad
.
place
)
ys
=
[
y
[
y
!=
ignore_id
]
for
y
in
ys_pad
]
# parse padded ys
ys_in
=
[
paddle
.
cat
([
_sos
,
y
],
dim
=
0
)
for
y
in
ys
]
ys_out
=
[
paddle
.
cat
([
y
,
_eos
],
dim
=
0
)
for
y
in
ys
]
return
pad_sequence
(
ys_in
,
padding_value
=
eos
),
pad_sequence
(
ys_out
,
padding_value
=
ignore_id
)
# TODO(Hui Zhang): using comment code,
#_sos = paddle.to_tensor(
# [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#_eos = paddle.to_tensor(
# [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
#ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
#ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
#ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
#return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
B
=
ys_pad
.
size
(
0
)
_sos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
sos
_eos
=
paddle
.
ones
([
B
,
1
],
dtype
=
ys_pad
.
dtype
)
*
eos
ys_in
=
paddle
.
cat
([
_sos
,
ys_pad
],
dim
=
1
)
mask_pad
=
(
ys_in
==
ignore_id
)
ys_in
=
ys_in
.
masked_fill
(
mask_pad
,
eos
)
ys_out
=
paddle
.
cat
([
ys_pad
,
_eos
],
dim
=
1
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
eos
)
mask_eos
=
(
ys_in
==
ignore_id
)
ys_out
=
ys_out
.
masked_fill
(
mask_eos
,
eos
)
ys_out
=
ys_out
.
masked_fill
(
mask_pad
,
ignore_id
)
return
ys_in
,
ys_out
def
th_accuracy
(
pad_outputs
:
paddle
.
Tensor
,
...
...
@@ -139,7 +153,13 @@ def th_accuracy(pad_outputs: paddle.Tensor,
pad_pred
=
pad_outputs
.
view
(
pad_targets
.
size
(
0
),
pad_targets
.
size
(
1
),
pad_outputs
.
size
(
1
)).
argmax
(
2
)
mask
=
pad_targets
!=
ignore_label
numerator
=
paddle
.
sum
(
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
# pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
numerator
=
(
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
denominator
=
paddle
.
sum
(
mask
)
numerator
=
paddle
.
sum
(
numerator
.
type_as
(
pad_targets
))
#TODO(Hui Zhang): sum not support bool type
# denominator = paddle.sum(mask)
denominator
=
paddle
.
sum
(
mask
.
type_as
(
pad_targets
))
return
float
(
numerator
)
/
float
(
denominator
)
tests/u2_model_test.py
浏览文件 @
220c9443
...
...
@@ -86,8 +86,11 @@ class TestU2Model(unittest.TestCase):
cfg
.
freeze
()
model
=
U2TransformerModel
(
cfg
)
summary
(
model
,
None
)
output
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
print
(
output
)
total_loss
,
attention_loss
,
ctc_loss
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
self
.
assertEqual
(
total_loss
.
numel
(),
1
)
self
.
assertEqual
(
attention_loss
.
numel
(),
1
)
self
.
assertEqual
(
ctc_loss
.
numel
(),
1
)
def
test_conformer
(
self
):
conf_str
=
"""
...
...
@@ -135,8 +138,11 @@ class TestU2Model(unittest.TestCase):
cfg
.
freeze
()
model
=
U2ConformerModel
(
cfg
)
summary
(
model
,
None
)
output
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
print
(
output
)
total_loss
,
attention_loss
,
ctc_loss
=
model
(
self
.
audio
,
self
.
audio_len
,
self
.
text
,
self
.
text_len
)
self
.
assertEqual
(
total_loss
.
numel
(),
1
)
self
.
assertEqual
(
attention_loss
.
numel
(),
1
)
self
.
assertEqual
(
ctc_loss
.
numel
(),
1
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录