PaddlePaddle / DeepSpeech
Commit 38ebec33
Authored Apr 15, 2021 by Hui Zhang
Parent commit: 0d38a670

    using module logger as default
Showing 3 changed files with 36 additions and 114 deletions (+36 −114):

    deepspeech/exps/deepspeech2/model.py    +16 −42
    deepspeech/exps/u2/model.py             +16 −44
    deepspeech/training/trainer.py           +4 −28
deepspeech/exps/deepspeech2/model.py  (view file @ 38ebec33)

@@ -63,7 +63,7 @@ class DeepSpeech2Trainer(Trainer):
             msg += "batch size: {}, ".format(self.config.data.batch_size)
             msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_np.items())
-            self.logger.info(msg)
+            logger.info(msg)
             if dist.get_rank() == 0 and self.visualizer:
                 for k, v in losses_np.items():
@@ -74,8 +74,7 @@ class DeepSpeech2Trainer(Trainer):
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
-        self.logger.info(
-            f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         self.model.eval()
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
@@ -92,7 +91,7 @@ class DeepSpeech2Trainer(Trainer):
             msg += "step: {}, ".format(self.iteration)
             msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in valid_losses.items())
-            self.logger.info(msg)
+            logger.info(msg)
             if self.visualizer:
                 for k, v in valid_losses.items():
@@ -115,7 +114,7 @@ class DeepSpeech2Trainer(Trainer):
         if self.parallel:
             model = paddle.DataParallel(model)
-        layer_tools.print_params(model, self.logger.info)
+        layer_tools.print_params(model, logger.info)
         grad_clip = ClipGradByGlobalNormWithLog(
             config.training.global_grad_clip)
@@ -133,7 +132,7 @@ class DeepSpeech2Trainer(Trainer):
         self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        self.logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup model/optimizer/lr_scheduler!")

     def setup_dataloader(self):
         config = self.config.clone()
@@ -178,7 +177,7 @@ class DeepSpeech2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn)
-        self.logger.info("Setup train/valid Dataloader!")
+        logger.info("Setup train/valid Dataloader!")


 class DeepSpeech2Tester(DeepSpeech2Trainer):
@@ -221,11 +220,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             errors_sum += errors
             len_refs += len_ref
             num_ins += 1
-            self.logger.info(
-                "\nTarget Transcription: %s\nOutput Transcription: %s" %
-                (target, result))
-            self.logger.info("Current error rate [%s] = %f" % (
-                cfg.error_rate_type, error_rate_func(target, result)))
+            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
+                        (target, result))
+            logger.info("Current error rate [%s] = %f" %
+                        (cfg.error_rate_type, error_rate_func(target, result)))

         return dict(
             errors_sum=errors_sum,
@@ -237,8 +235,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def test(self):
-        self.logger.info(
-            f"Test Total Examples: {len(self.test_loader.dataset)}")
+        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
         cfg = self.config
         error_rate_type = None
@@ -250,8 +247,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
             error_rate_type = metrics['error_rate_type']
-            self.logger.info("Error rate [%s] (%d/?) = %f" %
-                             (error_rate_type, num_ins, errors_sum / len_refs))
+            logger.info("Error rate [%s] (%d/?) = %f" %
+                        (error_rate_type, num_ins, errors_sum / len_refs))

         # logging
         msg = "Test: "
@@ -259,7 +256,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         msg += "step: {}, ".format(self.iteration)
         msg += ", Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
-        self.logger.info(msg)
+        logger.info(msg)

     def run_test(self):
         self.resume_or_scratch()
@@ -298,7 +295,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         self.setup_output_dir()
         self.setup_checkpointer()
-        self.setup_logger()
         self.setup_dataloader()
         self.setup_model()
@@ -317,7 +313,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             use_gru=config.model.use_gru,
             share_rnn_weights=config.model.share_rnn_weights)
         self.model = model
-        self.logger.info("Setup model!")
+        logger.info("Setup model!")

     def setup_dataloader(self):
         config = self.config.clone()
@@ -335,7 +331,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=SpeechCollator(keep_transcription_text=True))
-        self.logger.info("Setup test Dataloader!")
+        logger.info("Setup test Dataloader!")

     def setup_output_dir(self):
         """Create a directory used for output.
@@ -350,25 +346,3 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         output_dir.mkdir(parents=True, exist_ok=True)
         self.output_dir = output_dir
-
-    def setup_logger(self):
-        """Initialize a text logger to log the experiment.
-        Each process has its own text logger. The logging message is write to
-        the standard output and a text file named ``worker_n.log`` in the
-        output directory, where ``n`` means the rank of the process.
-        """
-        format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
-        formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
-
-        logger.setLevel("INFO")
-
-        # global logger
-        stdout = True
-        save_path = ""
-        logging.basicConfig(
-            level=logging.DEBUG if stdout else logging.INFO,
-            format=format,
-            datefmt='%Y/%m/%d %H:%M:%S',
-            filename=save_path if not stdout else None)
-        self.logger = logger
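Worth noting about the deleted setup_logger: logging.basicConfig only configures the root logger on its first effective call, and with stdout hard-coded to True and save_path to "", the worker_n.log file promised by the docstring was never actually written. A standalone sketch of the first-call-wins behavior:

    import logging

    logging.basicConfig(level=logging.INFO)   # first call installs a handler
    logging.basicConfig(level=logging.DEBUG)  # no-op: root already configured

    logging.getLogger("demo").debug("dropped")  # still filtered at INFO level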
deepspeech/exps/u2/model.py  (view file @ 38ebec33)

@@ -109,7 +109,7 @@ class U2Trainer(Trainer):
             msg += "accum: {}, ".format(train_conf.accum_grad)
             msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_np.items())
-            self.logger.info(msg)
+            logger.info(msg)

     def train(self):
         """The training process control by step."""
@@ -129,8 +129,7 @@ class U2Trainer(Trainer):
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)
-        self.logger.info(
-            f"Train Total Examples: {len(self.train_loader.dataset)}")
+        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
             self.model.train()
             try:
@@ -145,7 +144,7 @@ class U2Trainer(Trainer):
                     self.train_batch(batch_index, batch, msg)
                     data_start_time = time.time()
             except Exception as e:
-                self.logger.error(e)
+                logger.error(e)
                 raise e
             valid_losses = self.valid()
@@ -156,8 +155,7 @@ class U2Trainer(Trainer):
     @paddle.no_grad()
     def valid(self):
         self.model.eval()
-        self.logger.info(
-            f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
             total_loss, attention_loss, ctc_loss = self.model(*batch)
@@ -175,7 +173,7 @@ class U2Trainer(Trainer):
             msg += "step: {}, ".format(self.iteration)
             msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in valid_losses.items())
-            self.logger.info(msg)
+            logger.info(msg)
             if self.visualizer:
                 for k, v in valid_losses.items():
@@ -239,7 +237,7 @@ class U2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=SpeechCollator(keep_transcription_text=True))
-        self.logger.info("Setup train/valid/test Dataloader!")
+        logger.info("Setup train/valid/test Dataloader!")

     def setup_model(self):
         config = self.config
@@ -253,7 +251,7 @@ class U2Trainer(Trainer):
         if self.parallel:
             model = paddle.DataParallel(model)
-        layer_tools.print_params(model, self.logger.info)
+        layer_tools.print_params(model, logger.info)
         train_config = config.training
         optim_type = train_config.optim
@@ -289,7 +287,7 @@ class U2Trainer(Trainer):
         self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        self.logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup model/optimizer/lr_scheduler!")


 class U2Tester(U2Trainer):
@@ -367,11 +365,10 @@ class U2Tester(U2Trainer):
             num_ins += 1
             if fout:
                 fout.write(result + "\n")
-            self.logger.info(
-                "\nTarget Transcription: %s\nOutput Transcription: %s" %
-                (target, result))
-            self.logger.info("Current error rate [%s] = %f" % (
-                cfg.error_rate_type, error_rate_func(target, result)))
+            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
+                        (target, result))
+            logger.info("Current error rate [%s] = %f" %
+                        (cfg.error_rate_type, error_rate_func(target, result)))

         return dict(
             errors_sum=errors_sum,
@@ -385,8 +382,7 @@ class U2Tester(U2Trainer):
     def test(self):
         assert self.args.result_file
         self.model.eval()
-        self.logger.info(
-            f"Test Total Examples: {len(self.test_loader.dataset)}")
+        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
@@ -398,9 +394,8 @@ class U2Tester(U2Trainer):
             len_refs += metrics['len_refs']
             num_ins += metrics['num_ins']
             error_rate_type = metrics['error_rate_type']
-            self.logger.info(
-                "Error rate [%s] (%d/?) = %f" %
-                (error_rate_type, num_ins, errors_sum / len_refs))
+            logger.info("Error rate [%s] (%d/?) = %f" %
+                        (error_rate_type, num_ins, errors_sum / len_refs))

         # logging
         msg = "Test: "
@@ -408,7 +403,7 @@ class U2Tester(U2Trainer):
         msg += "step: {}, ".format(self.iteration)
         msg += ", Final error rate [%s] (%d/%d) = %f" % (
             error_rate_type, num_ins, num_ins, errors_sum / len_refs)
-        self.logger.info(msg)
+        logger.info(msg)

     def run_test(self):
         self.resume_or_scratch()
@@ -459,7 +454,6 @@ class U2Tester(U2Trainer):
         self.setup_output_dir()
         self.setup_checkpointer()
-        self.setup_logger()
         self.setup_dataloader()
         self.setup_model()
@@ -480,25 +474,3 @@ class U2Tester(U2Trainer):
         output_dir.mkdir(parents=True, exist_ok=True)
         self.output_dir = output_dir
-
-    def setup_logger(self):
-        """Initialize a text logger to log the experiment.
-        Each process has its own text logger. The logging message is write to
-        the standard output and a text file named ``worker_n.log`` in the
-        output directory, where ``n`` means the rank of the process.
-        """
-        format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
-        formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
-
-        logger.setLevel("INFO")
-
-        # global logger
-        stdout = True
-        save_path = ""
-        logging.basicConfig(
-            level=logging.DEBUG if stdout else logging.INFO,
-            format=format,
-            datefmt='%Y/%m/%d %H:%M:%S',
-            filename=save_path if not stdout else None)
-        self.logger = logger
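The u2 hunks mirror the deepspeech2 ones, down to the log-and-reraise in the training loop. As an aside, a bare raise is the more idiomatic form of that pattern, since it re-raises the active exception without rebinding it; a self-contained sketch (run_step is a hypothetical stand-in for train_batch):

    import logging

    logger = logging.getLogger(__name__)

    def run_step(step):  # hypothetical stand-in for Trainer.train_batch
        raise RuntimeError(f"step {step} failed")

    try:
        run_step(7)
    except Exception as e:
        logger.error(e)  # log the failure, as the except blocks above do
        raise            # bare raise preserves the original traceback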
deepspeech/training/trainer.py  (view file @ 38ebec33)

@@ -92,7 +92,7 @@ class Trainer():
         self.visualizer = None
         self.output_dir = None
         self.checkpoint_dir = None
-        self.logger = None
+        logger = None
         self.iteration = 0
         self.epoch = 0
@@ -106,7 +106,6 @@ class Trainer():
         self.setup_output_dir()
         self.dump_config()
         self.setup_visualizer()
-        self.setup_logger()
         self.setup_checkpointer()
         self.setup_dataloader()
@@ -182,8 +181,7 @@ class Trainer():
         if self.parallel:
             self.train_loader.batch_sampler.set_epoch(self.epoch)
-        self.logger.info(
-            f"Train Total Examples: {len(self.train_loader.dataset)}")
+        logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
             self.model.train()
             try:
@@ -198,7 +196,7 @@ class Trainer():
                     self.train_batch(batch_index, batch, msg)
                     data_start_time = time.time()
             except Exception as e:
-                self.logger.error(e)
+                logger.error(e)
                 raise e
             valid_losses = self.valid()
@@ -217,7 +215,7 @@ class Trainer():
             exit(-1)
         finally:
             self.destory()
-        self.logger.info("Training Done.")
+        logger.info("Training Done.")

     def setup_output_dir(self):
         """Create a directory used for output.
@@ -262,28 +260,6 @@ class Trainer():
         self.visualizer = visualizer

-    def setup_logger(self):
-        """Initialize a text logger to log the experiment.
-        Each process has its own text logger. The logging message is write to
-        the standard output and a text file named ``worker_n.log`` in the
-        output directory, where ``n`` means the rank of the process.
-
-        when - how to split the log file by time interval
-            'S' : Seconds
-            'M' : Minutes
-            'H' : Hours
-            'D' : Days
-            'W' : Week day
-            default value: 'D'
-        format - format of the log
-            default format:
-            %(levelname)s: %(asctime)s: %(filename)s:%(lineno)d * %(thread)d %(message)s
-            INFO: 12-09 18:02:42: log.py:40 * 139814749787872 HELLO WORLD
-        backup - how many backup file to keep
-            default value: 7
-        """
-        self.logger = logger
-
     @mp_tools.rank_zero_only
     def dump_config(self):
         """Save the configuration used for this experiment.
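The docstring removed from trainer.py described per-process log files (worker_n.log, split by time interval) that the implementation never produced. For reference, a hedged sketch of what that per-rank file logging could look like, assuming paddle.distributed.get_rank() as already used in the deepspeech2 diff:

    import logging
    from paddle import distributed as dist

    def make_worker_logger(output_dir: str) -> logging.Logger:
        # Sketch only: one file logger per process, named by rank.
        rank = dist.get_rank()  # 0 in single-process runs
        log = logging.getLogger(f"worker_{rank}")
        handler = logging.FileHandler(f"{output_dir}/worker_{rank}.log")
        handler.setFormatter(logging.Formatter(
            fmt='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
            datefmt='%Y/%m/%d %H:%M:%S'))
        log.addHandler(handler)
        log.setLevel(logging.INFO)
        return log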