Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
281d46da
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
281d46da
编写于
4月 16, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix logger
上级
156ccb94
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
13 addition
and
13 deletion
+13
-13
deepspeech/__init__.py
deepspeech/__init__.py
+0
-2
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+6
-4
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+4
-4
deepspeech/models/u2.py
deepspeech/models/u2.py
+3
-3
未找到文件。
deepspeech/__init__.py
浏览文件 @
281d46da
...
...
@@ -410,13 +410,11 @@ def ctc_loss(logits,
input_lengths
,
label_lengths
)
loss_out
=
paddle
.
fluid
.
layers
.
squeeze
(
loss_out
,
[
-
1
])
logger
.
debug
(
f
"warpctc loss:
{
loss_out
}
/
{
loss_out
.
shape
}
"
)
assert
reduction
in
[
'mean'
,
'sum'
,
'none'
]
if
reduction
==
'mean'
:
loss_out
=
paddle
.
mean
(
loss_out
/
label_lengths
)
elif
reduction
==
'sum'
:
loss_out
=
paddle
.
sum
(
loss_out
)
logger
.
debug
(
f
"ctc loss:
{
loss_out
}
"
)
return
loss_out
...
...
deepspeech/exps/u2/model.py
浏览文件 @
281d46da
...
...
@@ -89,8 +89,9 @@ class U2Trainer(Trainer):
if
(
batch_index
+
1
)
%
train_conf
.
accum_grad
==
0
:
if
dist
.
get_rank
()
==
0
and
self
.
visualizer
:
losses_np
.
update
({
"lr"
:
self
.
lr_scheduler
()})
self
.
visualizer
.
add_scalars
(
"step"
,
losses_np
,
self
.
iteration
)
losses_np_v
=
losses_np
.
copy
()
losses_np_v
.
update
({
"lr"
:
self
.
lr_scheduler
()})
self
.
visualizer
.
add_scalars
(
"step"
,
losses_np_v
,
self
.
iteration
)
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_grad
()
self
.
lr_scheduler
.
step
()
...
...
@@ -171,8 +172,9 @@ class U2Trainer(Trainer):
logger
.
info
(
msg
)
if
self
.
visualizer
:
valid_losses
.
update
({
"lr"
:
self
.
lr_scheduler
()})
self
.
visualizer
.
add_scalars
(
'epoch'
,
valid_losses
,
self
.
epoch
)
valid_losses_v
=
valid_losses
.
copy
()
valid_losses_v
.
update
({
"lr"
:
self
.
lr_scheduler
()})
self
.
visualizer
.
add_scalars
(
'epoch'
,
valid_losses_v
,
self
.
epoch
)
return
valid_losses
def
setup_dataloader
(
self
):
...
...
deepspeech/io/dataset.py
浏览文件 @
281d46da
...
...
@@ -297,13 +297,13 @@ class ManifestDataset(Dataset):
else
:
speech_segment
=
SpeechSegment
.
from_file
(
audio_file
,
transcript
)
load_wav_time
=
time
.
time
()
-
start_time
logger
.
debug
(
f
"load wav time:
{
load_wav_time
}
"
)
#
logger.debug(f"load wav time: {load_wav_time}")
# audio augment
start_time
=
time
.
time
()
self
.
_augmentation_pipeline
.
transform_audio
(
speech_segment
)
audio_aug_time
=
time
.
time
()
-
start_time
logger
.
debug
(
f
"audio augmentation time:
{
audio_aug_time
}
"
)
#
logger.debug(f"audio augmentation time: {audio_aug_time}")
start_time
=
time
.
time
()
specgram
,
transcript_part
=
self
.
_speech_featurizer
.
featurize
(
...
...
@@ -311,13 +311,13 @@ class ManifestDataset(Dataset):
if
self
.
_normalizer
:
specgram
=
self
.
_normalizer
.
apply
(
specgram
)
feature_time
=
time
.
time
()
-
start_time
logger
.
debug
(
f
"audio & test feature time:
{
feature_time
}
"
)
#
logger.debug(f"audio & test feature time: {feature_time}")
# specgram augment
start_time
=
time
.
time
()
specgram
=
self
.
_augmentation_pipeline
.
transform_feature
(
specgram
)
feature_aug_time
=
time
.
time
()
-
start_time
logger
.
debug
(
f
"audio feature augmentation time:
{
feature_aug_time
}
"
)
#
logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return
specgram
,
transcript_part
def
_instance_reader_creator
(
self
,
manifest
):
...
...
deepspeech/models/u2.py
浏览文件 @
281d46da
...
...
@@ -159,7 +159,7 @@ class U2BaseModel(nn.Module):
start
=
time
.
time
()
encoder_out
,
encoder_mask
=
self
.
encoder
(
speech
,
speech_lengths
)
encoder_time
=
time
.
time
()
-
start
logger
.
debug
(
f
"encoder time:
{
encoder_time
}
"
)
#
logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens
=
encoder_mask
.
squeeze
(
1
).
cast
(
paddle
.
int64
).
sum
(
...
...
@@ -172,7 +172,7 @@ class U2BaseModel(nn.Module):
loss_att
,
acc_att
=
self
.
_calc_att_loss
(
encoder_out
,
encoder_mask
,
text
,
text_lengths
)
decoder_time
=
time
.
time
()
-
start
logger
.
debug
(
f
"decoder time:
{
decoder_time
}
"
)
#
logger.debug(f"decoder time: {decoder_time}")
# 2b. CTC branch
loss_ctc
=
None
...
...
@@ -181,7 +181,7 @@ class U2BaseModel(nn.Module):
loss_ctc
=
self
.
ctc
(
encoder_out
,
encoder_out_lens
,
text
,
text_lengths
)
ctc_time
=
time
.
time
()
-
start
logger
.
debug
(
f
"ctc time:
{
ctc_time
}
"
)
#
logger.debug(f"ctc time: {ctc_time}")
if
loss_ctc
is
None
:
loss
=
loss_att
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录