PaddlePaddle / DeepSpeech / Commits / 2480be8e
Commit 2480be8e
Authored Sep 10, 2021 by Hui Zhang

timer info for st,u2 kaldi

Parent: 28a0a641
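This commit wraps the per-epoch train and eval phases of the u2_kaldi and u2_st trainers in `with Timer(...)` blocks, importing Timer from deepspeech.training.timer. The Timer implementation itself is not part of this diff; the following is a minimal sketch, assuming a context manager that formats the elapsed wall-clock time into the message template and reports it:

# Sketch only: deepspeech/training/timer.py is not shown in this commit.
import time


class Timer:
    """Context manager that reports the wall-clock time of its block.

    `message` is assumed to be a template whose '{}' placeholder receives
    the elapsed time, matching the diff's Timer("Epoch-Train Time Cost: {}")
    call sites.
    """

    def __init__(self, message):
        self.message = message

    def __enter__(self):
        # Record the start time when the `with` block is entered.
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Runs even if the block raises, so the timing line is still
        # emitted on the re-raised training exception below.
        elapsed = time.time() - self.start
        # The real class presumably routes this through the project
        # logger; print() keeps the sketch self-contained.
        print(self.message.format(elapsed))

Because __exit__ runs on both the normal and the exception path, the context-manager form guarantees a timing line per epoch regardless of how the wrapped block exits.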
Showing 2 changed files with 62 additions and 56 deletions (+62 −56).
deepspeech/exps/u2_kaldi/model.py  (+31 −28)
deepspeech/exps/u2_st/model.py     (+31 −28)
deepspeech/exps/u2_kaldi/model.py (view file @ 2480be8e)
@@ -32,6 +32,7 @@ from deepspeech.io.dataloader import BatchDataLoader
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -190,35 +191,37 @@ class U2Trainer(Trainer):
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info('Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
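The identical restructuring is applied to deepspeech/exps/u2_st/model.py below. As a usage illustration of the Timer sketch above (the sleeps are hypothetical stand-ins for the batch loop and self.valid(), not project code), each epoch now emits one timing line for the train phase and one for the eval phase:

# Illustration only, built on the Timer sketch above.
with Timer("Epoch-Train Time Cost: {}"):
    time.sleep(0.1)   # stand-in for the per-batch training loop
with Timer("Eval Time Cost: {}"):
    time.sleep(0.05)  # stand-in for self.valid()
# Expected output shape (exact values vary):
#   Epoch-Train Time Cost: 0.10...
#   Eval Time Cost: 0.05...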
deepspeech/exps/u2_st/model.py (view file @ 2480be8e)
@@ -38,6 +38,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2_st import U2STModel
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import bleu_score
 from deepspeech.utils import ctc_utils
@@ -207,35 +208,37 @@ class U2STTrainer(Trainer):
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info('Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))