PaddlePaddle / DeepSpeech · Commit c8e96d73
Authored Oct 02, 2021 by Hui Zhang

bool logical, sum and multiply op; ctc grad norm; support old and new pd api

Parent: 81f89c53
Showing 10 changed files with 93 additions and 41 deletions (+93 −41)
deepspeech/models/ds2/conv.py      +11 −3
deepspeech/models/ds2/rnn.py        +3 −3
deepspeech/models/u2/u2.py          +9 −3
deepspeech/models/u2_st.py          +6 −2
deepspeech/modules/decoder.py       +6 −2
deepspeech/modules/encoder.py       +2 −1
deepspeech/modules/loss.py         +34 −19
deepspeech/modules/mask.py         +12 −4
deepspeech/utils/tensor_utils.py    +8 −2
tests/mask_test.py                  +2 −2
deepspeech/models/ds2/conv.py

```diff
@@ -41,6 +41,13 @@ def conv_output_size(I, F, P, S):
     return (I - F + 2 * P - S) // S
 
 
+# receptive field calculator
+# https://fomoro.com/research/article/receptive-field-calculator
+# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
+# https://distill.pub/2019/computing-receptive-fields/
+# Rl-1 = Sl * Rl + (Kl - Sl)
+
+
 class ConvBn(nn.Layer):
     """Convolution layer with batch normalization.
 
@@ -106,9 +113,10 @@ class ConvBn(nn.Layer):
         # reset padding part to 0
         masks = make_non_pad_mask(x_len)  #[B, T]
         masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
-        # https://github.com/PaddlePaddle/Paddle/pull/29265
-        # rhs will type promote to lhs
-        x = x * masks
+        # TODO(Hui Zhang): not support bool multiply
+        # masks = masks.type_as(x)
+        masks = masks.astype(x.dtype)
+        x = x.multiply(masks)
 
         return x, x_len
```
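The new lines replace `x * masks` because, on the older Paddle releases this commit keeps supporting, multiplying a float tensor by a bool mask fails rather than type-promoting. A minimal sketch of the pattern with toy shapes (not code from the repo):

```python
import paddle

x = paddle.randn([2, 4])                      # [B, T] float32 activations
masks = paddle.to_tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]]).astype(paddle.bool)

# x = x * masks            # old Paddle: bool rhs does not promote to float
masks = masks.astype(x.dtype)                 # bool -> float32
x = x.multiply(masks)                         # zero the padded frames
```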
deepspeech/models/ds2/rnn.py

```diff
@@ -308,8 +308,8 @@ class RNNStack(nn.Layer):
             x, x_len = rnn(x, x_len)
             masks = make_non_pad_mask(x_len)  #[B, T]
             masks = masks.unsqueeze(-1)  # [B, T, 1]
-            # https://github.com/PaddlePaddle/Paddle/pull/29265
-            # rhs will type promote to lhs
-            x = x * masks
+            # TODO(Hui Zhang): not support bool multiply
+            masks = masks.astype(x.dtype)
+            x = x.multiply(masks)
         return x, x_len
```
deepspeech/models/u2/u2.py

```diff
@@ -164,7 +164,10 @@ class U2BaseModel(nn.Layer):
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
         encoder_time = time.time() - start
         #logger.debug(f"encoder time: {encoder_time}")
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
+        #TODO(Hui Zhang): sum not support bool type
+        #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
+        encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
+            1)  #[B, 1, T] -> [B]
 
         # 2a. Attention-decoder branch
         loss_att = None
@@ -319,7 +322,8 @@ class U2BaseModel(nn.Layer):
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
             # Stop if all batch and all beam produce eos
-            if end_flag.sum() == running_size:
+            # TODO(Hui Zhang): if end_flag.sum() == running_size:
+            if end_flag.cast(paddle.int64).sum() == running_size:
                 break
 
             # 2.1 Forward decoder step
@@ -405,7 +409,9 @@ class U2BaseModel(nn.Layer):
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks, simulate_streaming)
         maxlen = encoder_out.shape[1]
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        # (TODO Hui Zhang): bool no support reduce_sum
+        # encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+        encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int).sum(1)
         ctc_probs = self.ctc.log_softmax(encoder_out)  # (B, maxlen, vocab_size)
         topk_prob, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
```
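Both the training forward pass and the CTC decoding path need per-utterance lengths back out of the bool encoder mask, and old Paddle's reduce_sum rejects bool input, so the mask is cast first. A small sketch of the recovery with a made-up mask (not repo code):

```python
import paddle

# encoder_mask: [B, 1, T] bool, True on real frames
encoder_mask = paddle.to_tensor([[[True, True, True, False]],
                                 [[True, True, False, False]]])

# encoder_out_lens = encoder_mask.squeeze(1).sum(1)   # fails on old Paddle
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(1)  # [B]
print(encoder_out_lens.numpy().tolist())  # [3, 2]
```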
deepspeech/models/u2_st.py

```diff
@@ -165,7 +165,10 @@ class U2STBaseModel(nn.Layer):
         encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
         encoder_time = time.time() - start
         #logger.debug(f"encoder time: {encoder_time}")
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
+        #TODO(Hui Zhang): sum not support bool type
+        #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
+        encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
+            1)  #[B, 1, T] -> [B]
 
         # 2a. ST-decoder branch
         start = time.time()
@@ -362,7 +365,8 @@ class U2STBaseModel(nn.Layer):
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
             # Stop if all batch and all beam produce eos
-            if end_flag.sum() == running_size:
+            # TODO(Hui Zhang): if end_flag.sum() == running_size:
+            if end_flag.cast(paddle.int64).sum() == running_size:
                 break
 
             # 2.1 Forward decoder step
```
deepspeech/modules/decoder.py

```diff
@@ -124,7 +124,9 @@ class TransformerDecoder(nn.Layer):
         # m: (1, L, L)
         m = subsequent_mask(tgt_mask.shape[-1]).unsqueeze(0)
         # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
+        # TODO(Hui Zhang): not support & for tensor
+        # tgt_mask = tgt_mask & m
+        tgt_mask = tgt_mask.logical_and(m)
 
         x, _ = self.embed(tgt)
         for layer in self.decoders:
@@ -135,7 +137,9 @@ class TransformerDecoder(nn.Layer):
         if self.use_output_layer:
             x = self.output_layer(x)
-        olens = tgt_mask.sum(1)
+        # TODO(Hui Zhang): reduce_sum not support bool type
+        # olens = tgt_mask.sum(1)
+        olens = tgt_mask.astype(paddle.int).sum(1)
         return x, olens
 
     def forward_one_step(
```
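`&` on bool tensors is another operator the supported old Paddle lacks, hence `logical_and`. A self-contained sketch of the mask combination with toy sizes (the repo's `subsequent_mask` is stood in for by `paddle.tril`):

```python
import paddle

B, L = 2, 4
tgt_mask = paddle.ones([B, L, L], dtype=paddle.bool)    # (B, L, L)
m = paddle.tril(paddle.ones([L, L]))                    # float lower triangle
m = m.astype(paddle.bool).unsqueeze(0)                  # (1, L, L)

# tgt_mask = tgt_mask & m              # `&` unsupported on old Paddle
tgt_mask = tgt_mask.logical_and(m)     # broadcasts over the batch dim
```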
deepspeech/modules/encoder.py

```diff
@@ -162,7 +162,8 @@ class BaseEncoder(nn.Layer):
         xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
         #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
         masks = masks.astype(paddle.bool)
-        mask_pad = ~masks
+        #TODO(Hui Zhang): mask_pad = ~masks
+        mask_pad = masks.logical_not()
         chunk_masks = add_optional_chunk_mask(
             xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
             decoding_chunk_size, self.static_chunk_size,
```
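The same constraint applies to `~`: old Paddle has no bitwise NOT for bool tensors, so the pad mask is inverted with `logical_not`. In isolation (a toy mask, not repo code):

```python
import paddle

masks = paddle.to_tensor([[True, True, False]])   # True on real frames

# mask_pad = ~masks                # unsupported on old Paddle
mask_pad = masks.logical_not()     # True on padding frames
print(mask_pad.numpy().tolist())   # [[False, False, True]]
```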
deepspeech/modules/loss.py

```diff
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+from functools import partial
+
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
@@ -32,18 +35,19 @@ class CTCLoss(nn.Layer):
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
         self.batch_average = batch_average
         logger.info(
             f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
-        logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         # instance for norm_by_times
         # batch for norm_by_batchsize
         # frame for norm_by_total_logits_len
         assert grad_norm_type in ('instance', 'batch', 'frame', None)
         self.norm_by_times = False
         self.norm_by_batchsize = False
         self.norm_by_total_logits_len = False
-        if grad_norm_type == 'instance':
+        logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
+        if grad_norm_type is None:
+            # no grad norm
+            pass
+        elif grad_norm_type == 'instance':
             self.norm_by_times = True
         elif grad_norm_type == 'batch':
             self.norm_by_batchsize = True
@@ -51,6 +55,22 @@ class CTCLoss(nn.Layer):
             self.norm_by_total_logits_len = True
         else:
             raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
+        self.kwargs = {
+            "norm_by_times": self.norm_by_times,
+            "norm_by_batchsize": self.norm_by_batchsize,
+            "norm_by_total_logits_len": self.norm_by_total_logits_len,
+        }
+
+        # Derive only the args which the func has
+        try:
+            param = inspect.signature(self.loss.forward).parameters
+        except ValueError:
+            # Some function, e.g. built-in function, are failed
+            param = {}
+        _kwargs = {k: v for k, v in self.kwargs.items() if k in param}
+        _notin = {k: v for k, v in self.kwargs.items() if k not in param}
+        logger.info(f"{self.loss} kwargs:{_kwargs}, not support: {_notin}")
+        self.loss_fn = partial(self.loss.forward, **_kwargs)
 
     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -70,14 +90,7 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         ys_pad = ys_pad.astype(paddle.int32)
-        loss = self.loss(
-            logits,
-            ys_pad,
-            hlens,
-            ys_lens,
-            norm_by_times=self.norm_by_times,
-            norm_by_batchsize=self.norm_by_batchsize,
-            norm_by_total_logits_len=self.norm_by_total_logits_len)
+        loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
@@ -118,8 +131,8 @@ class LabelSmoothingLoss(nn.Layer):
             size (int): the number of class
             padding_idx (int): padding class id which will be ignored for loss
             smoothing (float): smoothing rate (0.0 means the conventional CE)
-            normalize_length (bool):
-                True, normalize loss by sequence length;
+            normalize_length (bool):
+                True, normalize loss by sequence length;
                 False, normalize loss by batch size.
                 Defaults to False.
         """
@@ -136,7 +149,7 @@ class LabelSmoothingLoss(nn.Layer):
         The model outputs and data labels tensors are flatten to
         (batch*seqlen, class) shape and a mask is applied to the
         padding part which should not be calculated for loss.
-
+
         Args:
             x (paddle.Tensor): prediction (batch, seqlen, class)
             target (paddle.Tensor):
@@ -152,7 +165,7 @@ class LabelSmoothingLoss(nn.Layer):
         # use zeros_like instead of torch.no_grad() for true_dist,
         # since no_grad() can not be exported by JIT
         true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
-        ignore = (target == self.padding_idx)  # (B,)
+        ignore = target == self.padding_idx  # (B,)
         #TODO(Hui Zhang): target = target * (1 - ignore)  # avoid -1 index
         target = target.masked_fill(ignore, 0)  # avoid -1 index
@@ -163,8 +176,10 @@ class LabelSmoothingLoss(nn.Layer):
         kl = self.criterion(F.log_softmax(x, axis=1), true_dist)
-        total = len(target) - int(ignore.sum())
+        #TODO(Hui Zhang): sum not support bool type
+        #total = len(target) - int(ignore.sum())
+        total = len(target) - int(ignore.type_as(target).sum())
         denom = total if self.normalize_length else B
-        # TODO(Hui Zhang): numer = (kl * (1 - ignore)).sum()
+        #numer = (kl * (1 - ignore)).sum()
         numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
         return numer / denom
```
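The `inspect.signature` block is what makes the loss work against both old and new `paddle.nn.CTCLoss` APIs: the grad-norm kwargs are kept only if the installed version's `forward` actually accepts them. A minimal, self-contained sketch of the trick with a stand-in function (nothing below is Paddle API):

```python
import inspect
from functools import partial


def ctc_forward(logits, labels, norm_by_times=False):
    """Stand-in for paddle.nn.CTCLoss.forward; accepts only norm_by_times."""
    return f"norm_by_times={norm_by_times}"


kwargs = {
    "norm_by_times": True,
    "norm_by_batchsize": False,         # newer-API-only kwarg
    "norm_by_total_logits_len": False,  # newer-API-only kwarg
}

try:
    param = inspect.signature(ctc_forward).parameters
except ValueError:
    param = {}  # builtins may not expose a signature

# forward only the kwargs this version understands
_kwargs = {k: v for k, v in kwargs.items() if k in param}
loss_fn = partial(ctc_forward, **_kwargs)
print(loss_fn(None, None))  # norm_by_times=True
```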
deepspeech/modules/mask.py

```diff
@@ -69,7 +69,8 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
          [1, 1, 1, 0, 0],
          [1, 1, 0, 0, 0]]
     """
-    return ~make_pad_mask(lengths)
+    #return ~make_pad_mask(lengths)
+    return make_pad_mask(lengths).logical_not()
 
 
 def subsequent_mask(size: int) -> paddle.Tensor:
@@ -91,7 +92,12 @@ def subsequent_mask(size: int) -> paddle.Tensor:
          [1, 1, 1]]
     """
     ret = paddle.ones([size, size], dtype=paddle.bool)
-    return paddle.tril(ret)
+    #TODO(Hui Zhang): tril not support bool
+    #return paddle.tril(ret)
+    ret = ret.astype(paddle.float)
+    ret = paddle.tril(ret)
+    ret = ret.astype(paddle.bool)
+    return ret
 
 
 def subsequent_chunk_mask(
@@ -180,13 +186,15 @@ def add_optional_chunk_mask(xs: paddle.Tensor,
             chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
                                                 num_left_chunks)  # (L, L)
             chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
-            chunk_masks = masks & chunk_masks  # (B, L, L)
+            # chunk_masks = masks & chunk_masks  # (B, L, L)
+            chunk_masks = masks.logical_and(chunk_masks)  # (B, L, L)
     elif static_chunk_size > 0:
         num_left_chunks = num_decoding_left_chunks
         chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
                                             num_left_chunks)  # (L, L)
         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
-        chunk_masks = masks & chunk_masks  # (B, L, L)
+        # chunk_masks = masks & chunk_masks  # (B, L, L)
+        chunk_masks = masks.logical_and(chunk_masks)  # (B, L, L)
     else:
         chunk_masks = masks
     return chunk_masks
```
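`subsequent_mask` now routes the bool mask through float because old Paddle's `paddle.tril` rejects bool input. The round trip in isolation (a sketch, not repo code):

```python
import paddle

ret = paddle.ones([3, 3], dtype=paddle.bool)
# ret = paddle.tril(ret)                      # old Paddle: no bool tril
ret = paddle.tril(ret.astype(paddle.float32)).astype(paddle.bool)
print(ret.numpy().astype(int).tolist())
# [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
```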
deepspeech/utils/tensor_utils.py

```diff
@@ -183,7 +183,13 @@ def th_accuracy(pad_outputs: paddle.Tensor,
     pad_pred = pad_outputs.view(
         pad_targets.shape[0], pad_targets.shape[1],
         pad_outputs.shape[1]).argmax(2)
     mask = pad_targets != ignore_label
-    numerator = paddle.sum(
-        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
-    denominator = paddle.sum(mask)
+    #TODO(Hui Zhang): sum not support bool type
+    # numerator = paddle.sum(
+    #     pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
+    numerator = (
+        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
+    numerator = paddle.sum(numerator.type_as(pad_targets))
+    #TODO(Hui Zhang): sum not support bool type
+    # denominator = paddle.sum(mask)
+    denominator = paddle.sum(mask.type_as(pad_targets))
     return float(numerator) / float(denominator)
```
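`th_accuracy` hits the bool-sum limitation twice, once for the hit count and once for the mask total. A worked sketch with made-up labels (`type_as` in the diff is a torch-style helper this repo provides; plain `astype` stands in for it below):

```python
import paddle

pad_pred = paddle.to_tensor([[1, 2, 3], [4, 0, 0]])
pad_targets = paddle.to_tensor([[1, 2, 9], [4, -1, -1]])  # -1 = ignore_label
mask = pad_targets != -1

hits = pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
numerator = paddle.sum(hits.astype(pad_targets.dtype))    # 3 correct tokens
denominator = paddle.sum(mask.astype(pad_targets.dtype))  # 4 scored tokens
print(float(numerator) / float(denominator))              # 0.75
```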
tests/mask_test.py

```diff
@@ -37,13 +37,13 @@ class TestU2Model(unittest.TestCase):
     def test_make_non_pad_mask(self):
         res = make_non_pad_mask(self.lengths)
-        res2 = ~make_pad_mask(self.lengths)
+        res2 = make_pad_mask(self.lengths).logical_not()
         self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist())
         self.assertSequenceEqual(res.numpy().tolist(), res2.numpy().tolist())
 
     def test_make_pad_mask(self):
         res = make_pad_mask(self.lengths)
-        res1 = ~make_non_pad_mask(self.lengths)
+        res1 = make_non_pad_mask(self.lengths).logical_not()
         self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist())
         self.assertSequenceEqual(res.numpy().tolist(), res1.tolist())
```