Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
1e37e2cc
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1e37e2cc
编写于
4月 20, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multi gpu valid
上级
77e5641a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
83 addition
and
53 deletion
+83
-53
deepspeech/exps/deepspeech2/model.py
deepspeech/exps/deepspeech2/model.py
+4
-4
deepspeech/exps/u2/model.py
deepspeech/exps/u2/model.py
+55
-45
deepspeech/training/trainer.py
deepspeech/training/trainer.py
+19
-1
examples/tiny/s1/local/train.sh
examples/tiny/s1/local/train.sh
+5
-3
未找到文件。
deepspeech/exps/deepspeech2/model.py
浏览文件 @
1e37e2cc
...
...
@@ -67,7 +67,6 @@ class DeepSpeech2Trainer(Trainer):
self
.
iteration
)
self
.
iteration
+=
1
@
mp_tools
.
rank_zero_only
@
paddle
.
no_grad
()
def
valid
(
self
):
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
...
...
@@ -84,11 +83,10 @@ class DeepSpeech2Trainer(Trainer):
valid_losses
[
'val_loss'
].
append
(
float
(
loss
))
if
(
i
+
1
)
%
self
.
config
.
training
.
log_interval
==
0
:
valid_losses
[
'val_history_loss'
]
=
total_loss
/
num_seen_utts
# write visual log
valid_losses
=
{
k
:
np
.
mean
(
v
)
for
k
,
v
in
valid_losses
.
items
()}
valid_losses
[
'val_history_loss'
]
=
total_loss
/
num_seen_utts
# logging
msg
=
f
"Valid: Rank:
{
dist
.
get_rank
()
}
, "
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
...
...
@@ -103,6 +101,8 @@ class DeepSpeech2Trainer(Trainer):
self
.
visualizer
.
add_scalar
(
"valid/{}"
.
format
(
k
),
v
,
self
.
iteration
)
logger
.
info
(
'Rank {} Val info val_loss {}'
.
format
(
dist
.
get_rank
(),
total_loss
/
num_seen_utts
))
return
total_loss
,
num_seen_utts
def
setup_model
(
self
):
...
...
deepspeech/exps/u2/model.py
浏览文件 @
1e37e2cc
...
...
@@ -109,6 +109,43 @@ class U2Trainer(Trainer):
self
.
visualizer
.
add_scalars
(
"step"
,
losses_np_v
,
self
.
iteration
-
1
)
@
paddle
.
no_grad
()
def
valid
(
self
):
self
.
model
.
eval
()
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
*
batch
)
if
paddle
.
isfinite
(
loss
):
num_utts
=
batch
[
0
].
shape
[
0
]
num_seen_utts
+=
num_utts
total_loss
+=
float
(
loss
)
*
num_utts
valid_losses
[
'val_loss'
].
append
(
float
(
loss
))
if
attention_loss
:
valid_losses
[
'val_att_loss'
].
append
(
float
(
attention_loss
))
if
ctc_loss
:
valid_losses
[
'val_ctc_loss'
].
append
(
float
(
ctc_loss
))
if
(
i
+
1
)
%
self
.
config
.
training
.
log_interval
==
0
:
valid_losses
=
{
k
:
np
.
mean
(
v
)
for
k
,
v
in
valid_losses
.
items
()}
valid_losses
[
'val_history_loss'
]
=
total_loss
/
num_seen_utts
# logging
msg
=
f
"Valid: Rank:
{
dist
.
get_rank
()
}
, "
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"batch : {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_losses
.
items
())
logger
.
info
(
msg
)
logger
.
info
(
'Rank {} Val info val_loss {}'
.
format
(
dist
.
get_rank
(),
total_loss
/
num_seen_utts
))
return
total_loss
,
num_seen_utts
def
train
(
self
):
"""The training process control by step."""
# !!!IMPORTANT!!!
...
...
@@ -148,53 +185,26 @@ class U2Trainer(Trainer):
raise
e
total_loss
,
num_seen_utts
=
self
.
valid
()
self
.
save
(
tag
=
self
.
epoch
,
infos
=
{
'val_loss'
:
total_loss
/
num_seen_utts
})
if
dist
.
get_world_size
()
>
1
:
num_seen_utts
=
paddle
.
to_tensor
(
num_seen_utts
)
# the default operator in all_reduce function is sum.
dist
.
all_reduce
(
num_seen_utts
)
total_loss
=
paddle
.
to_tensor
(
total_loss
)
dist
.
all_reduce
(
total_loss
)
cv_loss
=
total_loss
/
num_seen_utts
cv_loss
=
float
(
cv_loss
)
else
:
cv_loss
=
total_loss
/
num_seen_utts
logger
.
info
(
'Epoch {} Val info val_loss {}'
.
format
(
self
.
epoch
,
cv_loss
))
if
self
.
visualizer
:
self
.
visualizer
.
add_scalars
(
'epoch'
,
{
'cv_loss'
:
cv_loss
,
'lr'
:
self
.
lr_scheduler
()},
self
.
epoch
)
self
.
save
(
tag
=
self
.
epoch
,
infos
=
{
'val_loss'
:
cv_loss
})
self
.
new_epoch
()
@
mp_tools
.
rank_zero_only
@
paddle
.
no_grad
()
def
valid
(
self
):
self
.
model
.
eval
()
logger
.
info
(
f
"Valid Total Examples:
{
len
(
self
.
valid_loader
.
dataset
)
}
"
)
valid_losses
=
defaultdict
(
list
)
num_seen_utts
=
1
total_loss
=
0.0
for
i
,
batch
in
enumerate
(
self
.
valid_loader
):
loss
,
attention_loss
,
ctc_loss
=
self
.
model
(
*
batch
)
if
paddle
.
isfinite
(
loss
):
num_utts
=
batch
[
0
].
shape
[
0
]
num_seen_utts
+=
num_utts
total_loss
+=
float
(
loss
)
*
num_utts
valid_losses
=
{
'val_loss'
:
float
(
loss
)}
if
attention_loss
:
valid_losses
[
'val_att_loss'
]
=
float
(
attention_loss
)
if
ctc_loss
:
valid_losses
[
'val_ctc_loss'
]
=
float
(
ctc_loss
)
if
(
i
+
1
)
%
self
.
config
.
training
.
log_interval
==
0
:
valid_losses
[
'val_history_loss'
]
=
total_loss
/
num_seen_utts
# write visual log
valid_losses
=
{
k
:
np
.
mean
(
v
)
for
k
,
v
in
valid_losses
.
items
()}
# logging
msg
=
f
"Valid: Rank:
{
dist
.
get_rank
()
}
, "
msg
+=
"epoch: {}, "
.
format
(
self
.
epoch
)
msg
+=
"step: {}, "
.
format
(
self
.
iteration
)
msg
+=
"batch : {}/{}, "
.
format
(
i
+
1
,
len
(
self
.
valid_loader
))
msg
+=
', '
.
join
(
'{}: {:>.6f}'
.
format
(
k
,
v
)
for
k
,
v
in
valid_losses
.
items
())
logger
.
info
(
msg
)
if
self
.
visualizer
:
valid_losses_v
=
valid_losses
.
copy
()
valid_losses_v
.
update
({
"lr"
:
self
.
lr_scheduler
()})
self
.
visualizer
.
add_scalars
(
'epoch'
,
valid_losses_v
,
self
.
epoch
)
return
total_loss
,
num_seen_utts
def
setup_dataloader
(
self
):
config
=
self
.
config
.
clone
()
config
.
defrost
()
...
...
deepspeech/training/trainer.py
浏览文件 @
1e37e2cc
...
...
@@ -202,7 +202,25 @@ class Trainer():
raise
e
total_loss
,
num_seen_utts
=
self
.
valid
()
self
.
save
(
infos
=
{
'val_loss'
:
total_loss
/
num_seen_utts
})
if
dist
.
get_world_size
()
>
1
:
num_seen_utts
=
paddle
.
to_tensor
(
num_seen_utts
)
# the default operator in all_reduce function is sum.
dist
.
all_reduce
(
num_seen_utts
)
total_loss
=
paddle
.
to_tensor
(
total_loss
)
dist
.
all_reduce
(
total_loss
)
cv_loss
=
total_loss
/
num_seen_utts
cv_loss
=
float
(
cv_loss
)
else
:
cv_loss
=
total_loss
/
num_seen_utts
logger
.
info
(
'Epoch {} Val info val_loss {}'
.
format
(
self
.
epoch
,
cv_loss
))
if
self
.
visualizer
:
self
.
visualizer
.
add_scalars
(
'epoch'
,
{
'cv_loss'
:
cv_loss
,
'lr'
:
self
.
lr_scheduler
()},
self
.
epoch
)
self
.
save
(
infos
=
{
'val_loss'
:
cv_loss
})
self
.
lr_scheduler
.
step
()
self
.
new_epoch
()
...
...
examples/tiny/s1/local/train.sh
浏览文件 @
1e37e2cc
#! /usr/bin/env bash
CUDA_VISIBLE_DEVICES
=
0
\
ngpu
=
$(
echo
${
CUDA_VISIBLE_DEVICES
}
| python
-c
'import sys; a = sys.stdin.read(); print(len(a.split(",")));'
)
echo
"using
$ngpu
gpus..."
python3
-u
${
BIN_DIR
}
/train.py
\
--device
'gpu'
\
--nproc
1
\
--nproc
${
ngpu
}
\
--config
conf/conformer.yaml
\
--output
ckpt
--output
ckpt
-
${
1
}
if
[
$?
-ne
0
]
;
then
echo
"Failed in training!"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录