PaddlePaddle / DeepSpeech
Commit 347af638
Authored May 11, 2022 by xiongxinlei

changet vector train.py local_rank to rank, test=doc

Parent: 597d601d
Showing 1 changed file with 6 additions and 6 deletions (+6 -6)
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -54,7 +54,7 @@ def main(args, config):
     # stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
-    local_rank = paddle.distributed.get_rank()
+    rank = paddle.distributed.get_rank()

     # set the random seed, it is the necessary measures for multiprocess training
     seed_everything(config.seed)
@@ -112,10 +112,10 @@ def main(args, config):
             state_dict = paddle.load(
                 os.path.join(args.load_checkpoint, 'model.pdopt'))
             optimizer.set_state_dict(state_dict)
-            if local_rank == 0:
+            if rank == 0:
                 logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
         except FileExistsError:
-            if local_rank == 0:
+            if rank == 0:
                 logger.info('Train from scratch.')

         try:
@@ -219,7 +219,7 @@ def main(args, config):
             timer.count()  # step plus one in timer

             # stage 9-10: print the log information only on 0-rank per log-freq batchs
-            if (batch_idx + 1) % config.log_interval == 0 and local_rank == 0:
+            if (batch_idx + 1) % config.log_interval == 0 and rank == 0:
                 lr = optimizer.get_lr()
                 avg_loss /= config.log_interval
                 avg_acc = num_corrects / num_samples
@@ -250,7 +250,7 @@ def main(args, config):

             # stage 9-11: save the model parameters only on 0-rank per save-freq batchs
             if epoch % config.save_interval == 0 and batch_idx + 1 == steps_per_epoch:
-                if local_rank != 0:
+                if rank != 0:
                     paddle.distributed.barrier(
                     )  # Wait for valid step in main process
                     continue  # Resume trainning on other process
@@ -317,7 +317,7 @@ def main(args, config):
                 paddle.distributed.barrier()  # Main process

     # stage 10: create the final trained model.pdparams with soft link
-    if local_rank == 0:
+    if rank == 0:
         final_model = os.path.join(args.checkpoint_dir, "model.pdparams")
         logger.info(f"we will create the final model: {final_model}")
         if os.path.islink(final_model):
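All six edits rename the same rank guard: paddle.distributed.get_rank() returns the process's rank in the global trainer group, and the renamed variable gates logging, checkpoint saving, and the final soft link to rank 0 while the other ranks wait at a barrier. The sketch below is a minimal, self-contained illustration of that pattern, not the PaddleSpeech training script: the tiny Linear model, the random batches, the step counts, and the local "checkpoints" directory are placeholder assumptions; only the paddle.distributed calls (init_parallel_env, get_world_size, get_rank, barrier) are taken from the hunks above.

# rank_pattern_sketch.py -- illustrative only, not PaddleSpeech code.
# The model, data, step counts and "checkpoints" directory are placeholders.
import os

import paddle
import paddle.distributed as dist


def main():
    # as in stage 1 of the diff: set up the parallel environment before training
    dist.init_parallel_env()
    nranks = dist.get_world_size()  # number of trainer processes
    rank = dist.get_rank()          # this process's rank, 0 .. nranks - 1

    model = paddle.nn.Linear(16, 4)
    if nranks > 1:
        model = paddle.DataParallel(model)  # sync gradients across ranks
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters())

    os.makedirs("checkpoints", exist_ok=True)
    for step in range(1, 11):
        loss = model(paddle.randn([8, 16])).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        # log only on rank 0, mirroring "... and rank == 0" in stage 9-10
        if step % 5 == 0 and rank == 0:
            print(f"step {step}: lr={optimizer.get_lr()}, loss={float(loss)}")

        # save only on rank 0; the other ranks wait, as in stage 9-11
        if step % 5 == 0:
            if nranks > 1 and rank != 0:
                dist.barrier()  # wait while rank 0 saves
                continue
            paddle.save(model.state_dict(),
                        os.path.join("checkpoints", f"step_{step}.pdparams"))
            if nranks > 1:
                dist.barrier()  # let the other ranks resume


if __name__ == "__main__":
    main()

For a multi-GPU run such a script would typically be started through paddle.distributed.launch (for example, python -m paddle.distributed.launch --gpus 0,1 rank_pattern_sketch.py); started as a plain single process it degrades to nranks == 1 and the barriers are skipped.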