PaddlePaddle / DeepSpeech
Commit 5ffccbd5
Authored on Oct 26, 2021 by Hui Zhang
exp with eval mode
Parent: e8bc9a2a
Showing 5 changed files with 78 additions and 66 deletions (+78 -66)
deepspeech/exps/deepspeech2/model.py    +68 -60
deepspeech/exps/u2/model.py             +1 -1
deepspeech/exps/u2_kaldi/model.py       +1 -1
deepspeech/exps/u2_st/model.py          +1 -1
deepspeech/training/trainer.py          +7 -3
deepspeech/exps/deepspeech2/model.py
@@ -167,6 +167,11 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
+        self.model = model
+        logger.info("Setup model!")
+
+        if not self.train:
+            return

         grad_clip = ClipGradByGlobalNormWithLog(
             config.training.global_grad_clip)

@@ -180,74 +185,77 @@ class DeepSpeech2Trainer(Trainer):
             weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
-        self.model = model
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        logger.info("Setup model/optimizer/lr_scheduler!")
+        logger.info("Setup optimizer/lr_scheduler!")

     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.collator.keep_transcription_text = False
-
-        config.data.manifest = config.data.train_manifest
-        train_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.dev_manifest
-        dev_dataset = ManifestDataset.from_config(config)
-
-        config.data.manifest = config.data.test_manifest
-        test_dataset = ManifestDataset.from_config(config)
-
-        if self.parallel:
-            batch_sampler = SortagradDistributedBatchSampler(
-                train_dataset,
-                batch_size=config.collator.batch_size,
-                num_replicas=None,
-                rank=None,
-                shuffle=True,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-        else:
-            batch_sampler = SortagradBatchSampler(
-                train_dataset,
-                shuffle=True,
-                batch_size=config.collator.batch_size,
-                drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
-
-        collate_fn_train = SpeechCollator.from_config(config)
-
-        config.collator.augmentation_config = ""
-        collate_fn_dev = SpeechCollator.from_config(config)
-
-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
-        collate_fn_test = SpeechCollator.from_config(config)
-
-        self.train_loader = DataLoader(
-            train_dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers)
-        self.valid_loader = DataLoader(
-            dev_dataset,
-            batch_size=int(config.collator.batch_size),
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_dev,
-            num_workers=config.collator.num_workers)
-        self.test_loader = DataLoader(
-            test_dataset,
-            batch_size=config.decoding.batch_size,
-            shuffle=False,
-            drop_last=False,
-            collate_fn=collate_fn_test,
-            num_workers=config.collator.num_workers)
-        logger.info("Setup train/valid/test Dataloader!")
+        if self.train:
+            # train
+            config.data.manifest = config.data.train_manifest
+            train_dataset = ManifestDataset.from_config(config)
+            if self.parallel:
+                batch_sampler = SortagradDistributedBatchSampler(
+                    train_dataset,
+                    batch_size=config.collator.batch_size,
+                    num_replicas=None,
+                    rank=None,
+                    shuffle=True,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+            else:
+                batch_sampler = SortagradBatchSampler(
+                    train_dataset,
+                    shuffle=True,
+                    batch_size=config.collator.batch_size,
+                    drop_last=True,
+                    sortagrad=config.collator.sortagrad,
+                    shuffle_method=config.collator.shuffle_method)
+
+            config.collator.keep_transcription_text = False
+            collate_fn_train = SpeechCollator.from_config(config)
+            self.train_loader = DataLoader(
+                train_dataset,
+                batch_sampler=batch_sampler,
+                collate_fn=collate_fn_train,
+                num_workers=config.collator.num_workers)
+
+            # dev
+            config.data.manifest = config.data.dev_manifest
+            dev_dataset = ManifestDataset.from_config(config)
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = False
+            collate_fn_dev = SpeechCollator.from_config(config)
+            self.valid_loader = DataLoader(
+                dev_dataset,
+                batch_size=int(config.collator.batch_size),
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_dev,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup train/valid Dataloader!")
+        else:
+            # test
+            config.data.manifest = config.data.test_manifest
+            test_dataset = ManifestDataset.from_config(config)
+            config.collator.augmentation_config = ""
+            config.collator.keep_transcription_text = True
+            collate_fn_test = SpeechCollator.from_config(config)
+            self.test_loader = DataLoader(
+                test_dataset,
+                batch_size=config.decoding.batch_size,
+                shuffle=False,
+                drop_last=False,
+                collate_fn=collate_fn_test,
+                num_workers=config.collator.num_workers)
+            logger.info("Setup test Dataloader!")


 class DeepSpeech2Tester(DeepSpeech2Trainer):
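The bulk of the deepspeech2 change is the restructured setup_dataloader: instead of always building train, dev, and test datasets, it now builds only the loaders the current mode needs. Below is a minimal runnable sketch of that pattern; ToyConfig and ToyTrainer are hypothetical stand-ins, not the DeepSpeech ManifestDataset/SpeechCollator/DataLoader API.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the experiment config used in the diff."""
    train_manifest: str = "data/manifest.train"
    dev_manifest: str = "data/manifest.dev"
    test_manifest: str = "data/manifest.test"


class ToyTrainer:
    """Illustrates the control flow only; loaders are plain lists here."""

    def __init__(self, config: ToyConfig, train: bool):
        self.config = config
        self._train = train

    @property
    def train(self) -> bool:
        return self._train

    def setup_dataloader(self) -> None:
        if self.train:
            # train + dev loaders are only built for training runs
            self.train_loader = [self.config.train_manifest]
            self.valid_loader = [self.config.dev_manifest]
        else:
            # evaluation/decoding runs build the test loader only
            self.test_loader = [self.config.test_manifest]


if __name__ == "__main__":
    tester = ToyTrainer(ToyConfig(), train=False)
    tester.setup_dataloader()
    print(hasattr(tester, "train_loader"))  # False: no training data touched
    print(hasattr(tester, "test_loader"))   # True
```

The payoff is that an evaluation run never pays the cost of constructing training datasets and samplers it will not use.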
deepspeech/exps/u2/model.py
@@ -172,7 +172,7 @@ class U2Trainer(Trainer):
                             dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
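The one-line change here (and the identical changes in u2_kaldi and u2_st below) renames train() to do_train(). This matters because the base Trainer now exposes train as a read-only property (see the trainer.py diff at the end); a subclass method with the same name would shadow that property, so `self.train` would return a bound method instead of the boolean flag. A hedged illustration of the clash, with made-up class names:

```python
class Base:
    def __init__(self):
        self._train = False

    @property
    def train(self):
        return self._train


class Shadowing(Base):
    def train(self):            # same name as the property: shadows it
        print("training ...")


class Renamed(Base):
    def do_train(self):         # renamed, so the property keeps working
        print("training ...")


if __name__ == "__main__":
    print(bool(Shadowing().train))  # True -- a bound method is always truthy
    print(bool(Renamed().train))    # False -- the real _train flag
```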
deepspeech/exps/u2_kaldi/model.py
@@ -173,7 +173,7 @@ class U2Trainer(Trainer):
                             dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
deepspeech/exps/u2_st/model.py
@@ -184,7 +184,7 @@ class U2STTrainer(Trainer):
                             dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts

-    def train(self):
+    def do_train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
deepspeech/training/trainer.py
@@ -134,6 +134,10 @@ class Trainer():
         logger.info(
             f"Benchmark reset batch-size: {self.args.benchmark_batch_size}")

+    @property
+    def train(self):
+        return self._train
+
     @contextmanager
     def eval(self):
         self._train = False

@@ -248,7 +252,7 @@ class Trainer():
             sys.exit(
                 f"Reach benchmark-max-step: {self.args.benchmark_max_step}")

-    def train(self):
+    def do_train(self):
         """The training process control by epoch."""
         self.before_train()

@@ -321,7 +325,7 @@ class Trainer():
         """
         try:
             with Timer("Training Done: {}"):
-                self.train()
+                self.do_train()
         except KeyboardInterrupt:
             exit(-1)
         finally:

@@ -432,7 +436,7 @@ class Trainer():
         beginning of the experiment.
         """
         config_file = self.config_dir / "config.yaml"
-        if self._train and config_file.exists():
+        if self.train and config_file.exists():
             time_stamp = time.strftime("%Y_%m_%d_%H_%M_%s", time.gmtime())
             target_path = self.config_dir / ".".join(
                 [time_stamp, "config.yaml"])
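Taken together, the trainer.py changes add a train property backed by _train, keep eval() as a context manager that clears the flag, and rename the epoch loop to do_train() so it no longer collides with the property. Below is a small runnable sketch of how these pieces interact; it is not the real Trainer, and restoring the flag when eval() exits is an assumption of the sketch (the hunk above only shows the flag being cleared).

```python
from contextlib import contextmanager


class SketchTrainer:
    """Only the train/eval/do_train plumbing; not the DeepSpeech Trainer."""

    def __init__(self):
        self._train = True

    @property
    def train(self):
        return self._train

    @contextmanager
    def eval(self):
        self._train = False      # shown in the hunk above
        try:
            yield
        finally:
            self._train = True   # restoring afterwards is assumed here

    def do_train(self):
        print("epoch loop running, train flag =", self.train)

    def run(self):
        # the real run() wraps this in a Timer and KeyboardInterrupt handling
        self.do_train()


if __name__ == "__main__":
    t = SketchTrainer()
    t.run()                                # train flag = True
    with t.eval():
        print("inside eval():", t.train)   # False -> experiments run in eval mode
    print("after eval():", t.train)        # True
```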