Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
ebeb23c2
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ebeb23c2
编写于
9月 17, 2020
作者:
F
frankwhzhang
浏览文件
操作
浏览文件
下载
差异文件
some question
上级
d837c5f0
777be5a0
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
90 addition
and
8 deletion
+90
-8
core/trainers/framework/runner.py
core/trainers/framework/runner.py
+51
-7
core/utils/envs.py
core/utils/envs.py
+20
-1
doc/yaml.md
doc/yaml.md
+2
-0
models/rank/dnn/config.yaml
models/rank/dnn/config.yaml
+17
-0
未找到文件。
core/trainers/framework/runner.py
浏览文件 @
ebeb23c2
...
@@ -100,6 +100,7 @@ class RunnerBase(object):
...
@@ -100,6 +100,7 @@ class RunnerBase(object):
fetch_period
=
int
(
fetch_period
=
int
(
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
".print_interval"
,
20
))
".print_interval"
,
20
))
scope
=
context
[
"model"
][
model_name
][
"scope"
]
scope
=
context
[
"model"
][
model_name
][
"scope"
]
program
=
context
[
"model"
][
model_name
][
"main_program"
]
program
=
context
[
"model"
][
model_name
][
"main_program"
]
reader
=
context
[
"dataset"
][
reader_name
]
reader
=
context
[
"dataset"
][
reader_name
]
...
@@ -139,6 +140,9 @@ class RunnerBase(object):
...
@@ -139,6 +140,9 @@ class RunnerBase(object):
fetch_period
=
int
(
fetch_period
=
int
(
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
".print_interval"
,
20
))
".print_interval"
,
20
))
save_step_interval
=
int
(
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
".save_step_interval"
,
-
1
))
if
context
[
"is_infer"
]:
if
context
[
"is_infer"
]:
metrics
=
model_class
.
get_infer_results
()
metrics
=
model_class
.
get_infer_results
()
else
:
else
:
...
@@ -202,6 +206,24 @@ class RunnerBase(object):
...
@@ -202,6 +206,24 @@ class RunnerBase(object):
metrics_logging
.
insert
(
1
,
seconds
)
metrics_logging
.
insert
(
1
,
seconds
)
begin_time
=
end_time
begin_time
=
end_time
logging
.
info
(
metrics_format
.
format
(
*
metrics_logging
))
logging
.
info
(
metrics_format
.
format
(
*
metrics_logging
))
if
save_step_interval
>=
1
and
batch_id
%
save_step_interval
==
0
and
context
[
"is_infer"
]
==
False
:
if
context
[
"fleet_mode"
].
upper
()
==
"PS"
:
train_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
"main_program"
]
else
:
train_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
"default_main_program"
]
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
"startup_program"
]
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
self
.
save
(
context
,
is_fleet
=
context
[
"is_fleet"
],
epoch_id
=
None
,
batch_id
=
batch_id
)
batch_id
+=
1
batch_id
+=
1
except
fluid
.
core
.
EOFException
:
except
fluid
.
core
.
EOFException
:
reader
.
reset
()
reader
.
reset
()
...
@@ -314,7 +336,7 @@ class RunnerBase(object):
...
@@ -314,7 +336,7 @@ class RunnerBase(object):
exec_strategy
=
_exe_strategy
)
exec_strategy
=
_exe_strategy
)
return
program
return
program
def
save
(
self
,
epoch_id
,
context
,
is_fleet
=
Fals
e
):
def
save
(
self
,
context
,
is_fleet
=
False
,
epoch_id
=
None
,
batch_id
=
Non
e
):
def
need_save
(
epoch_id
,
epoch_interval
,
is_last
=
False
):
def
need_save
(
epoch_id
,
epoch_interval
,
is_last
=
False
):
name
=
"runner."
+
context
[
"runner_name"
]
+
"."
name
=
"runner."
+
context
[
"runner_name"
]
+
"."
total_epoch
=
int
(
envs
.
get_global_env
(
name
+
"epochs"
,
1
))
total_epoch
=
int
(
envs
.
get_global_env
(
name
+
"epochs"
,
1
))
...
@@ -371,7 +393,8 @@ class RunnerBase(object):
...
@@ -371,7 +393,8 @@ class RunnerBase(object):
assert
dirname
is
not
None
assert
dirname
is
not
None
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
logging
.
info
(
"
\t
save epoch_id:%d model into:
\"
%s
\"
"
%
(
epoch_id
,
dirname
))
if
is_fleet
:
if
is_fleet
:
warnings
.
warn
(
warnings
.
warn
(
"Save inference model in cluster training is not recommended! Using save checkpoint instead."
,
"Save inference model in cluster training is not recommended! Using save checkpoint instead."
,
...
@@ -394,14 +417,35 @@ class RunnerBase(object):
...
@@ -394,14 +417,35 @@ class RunnerBase(object):
if
dirname
is
None
or
dirname
==
""
:
if
dirname
is
None
or
dirname
==
""
:
return
return
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
logging
.
info
(
"
\t
save epoch_id:%d model into:
\"
%s
\"
"
%
(
epoch_id
,
dirname
))
if
is_fleet
:
if
is_fleet
:
if
context
[
"fleet"
].
worker_index
()
==
0
:
if
context
[
"fleet"
].
worker_index
()
==
0
:
context
[
"fleet"
].
save_persistables
(
context
[
"exe"
],
dirname
)
context
[
"fleet"
].
save_persistables
(
context
[
"exe"
],
dirname
)
else
:
else
:
fluid
.
io
.
save_persistables
(
context
[
"exe"
],
dirname
)
fluid
.
io
.
save_persistables
(
context
[
"exe"
],
dirname
)
def
save_checkpoint_step
():
name
=
"runner."
+
context
[
"runner_name"
]
+
"."
save_interval
=
int
(
envs
.
get_global_env
(
name
+
"save_step_interval"
,
-
1
))
dirname
=
envs
.
get_global_env
(
name
+
"save_step_path"
,
None
)
if
dirname
is
None
or
dirname
==
""
:
return
dirname
=
os
.
path
.
join
(
dirname
,
str
(
batch_id
))
logging
.
info
(
"
\t
save batch_id:%d model into:
\"
%s
\"
"
%
(
batch_id
,
dirname
))
if
is_fleet
:
if
context
[
"fleet"
].
worker_index
()
==
0
:
context
[
"fleet"
].
save_persistables
(
context
[
"exe"
],
dirname
)
else
:
fluid
.
io
.
save_persistables
(
context
[
"exe"
],
dirname
)
if
isinstance
(
epoch_id
,
int
):
save_persistables
()
save_persistables
()
save_inference_model
()
save_inference_model
()
if
isinstance
(
batch_id
,
int
):
save_checkpoint_step
()
class
SingleRunner
(
RunnerBase
):
class
SingleRunner
(
RunnerBase
):
...
@@ -453,7 +497,7 @@ class SingleRunner(RunnerBase):
...
@@ -453,7 +497,7 @@ class SingleRunner(RunnerBase):
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
"startup_program"
]
"startup_program"
]
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
self
.
save
(
epoch
,
context
)
self
.
save
(
context
=
context
,
epoch_id
=
epoch
)
context
[
"status"
]
=
"terminal_pass"
context
[
"status"
]
=
"terminal_pass"
...
@@ -506,7 +550,7 @@ class PSRunner(RunnerBase):
...
@@ -506,7 +550,7 @@ class PSRunner(RunnerBase):
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
"startup_program"
]
"startup_program"
]
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
self
.
save
(
epoch
,
context
,
True
)
self
.
save
(
context
=
context
,
is_fleet
=
True
,
epoch_id
=
epoch
)
context
[
"status"
]
=
"terminal_pass"
context
[
"status"
]
=
"terminal_pass"
...
@@ -539,7 +583,7 @@ class CollectiveRunner(RunnerBase):
...
@@ -539,7 +583,7 @@ class CollectiveRunner(RunnerBase):
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
startup_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
"startup_program"
]
"startup_program"
]
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
self
.
save
(
epoch
,
context
,
True
)
self
.
save
(
context
=
context
,
is_fleet
=
True
,
epoch_id
=
epoch
)
context
[
"status"
]
=
"terminal_pass"
context
[
"status"
]
=
"terminal_pass"
...
...
core/utils/envs.py
浏览文件 @
ebeb23c2
...
@@ -20,7 +20,7 @@ import socket
...
@@ -20,7 +20,7 @@ import socket
import
sys
import
sys
import
six
import
six
import
traceback
import
traceback
import
six
import
warnings
global_envs
=
{}
global_envs
=
{}
global_envs_flatten
=
{}
global_envs_flatten
=
{}
...
@@ -98,6 +98,25 @@ def set_global_envs(envs):
...
@@ -98,6 +98,25 @@ def set_global_envs(envs):
value
=
os_path_adapter
(
workspace_adapter
(
value
))
value
=
os_path_adapter
(
workspace_adapter
(
value
))
global_envs
[
name
]
=
value
global_envs
[
name
]
=
value
for
runner
in
envs
[
"runner"
]:
if
"save_step_interval"
in
runner
or
"save_step_path"
in
runner
:
phase_name
=
runner
[
"phases"
]
phase
=
[
phase
for
phase
in
envs
[
"phase"
]
if
phase
[
"name"
]
==
phase_name
[
0
]
]
dataset_name
=
phase
[
0
].
get
(
"dataset_name"
)
dataset
=
[
dataset
for
dataset
in
envs
[
"dataset"
]
if
dataset
[
"name"
]
==
dataset_name
]
if
dataset
[
0
].
get
(
"type"
)
==
"QueueDataset"
:
runner
[
"save_step_interval"
]
=
None
runner
[
"save_step_path"
]
=
None
warnings
.
warn
(
"QueueDataset can not support save by step, please not config save_step_interval and save_step_path in your yaml"
)
if
get_platform
()
!=
"LINUX"
:
if
get_platform
()
!=
"LINUX"
:
for
dataset
in
envs
[
"dataset"
]:
for
dataset
in
envs
[
"dataset"
]:
name
=
"."
.
join
([
"dataset"
,
dataset
[
"name"
],
"type"
])
name
=
"."
.
join
([
"dataset"
,
dataset
[
"name"
],
"type"
])
...
...
doc/yaml.md
浏览文件 @
ebeb23c2
...
@@ -27,6 +27,8 @@
...
@@ -27,6 +27,8 @@
| init_model_path | string | 路径 | 否 | 初始化模型地址 |
| init_model_path | string | 路径 | 否 | 初始化模型地址 |
| save_checkpoint_interval | int | >= 1 | 否 | Save参数的轮数间隔 |
| save_checkpoint_interval | int | >= 1 | 否 | Save参数的轮数间隔 |
| save_checkpoint_path | string | 路径 | 否 | Save参数的地址 |
| save_checkpoint_path | string | 路径 | 否 | Save参数的地址 |
| save_step_interval | int | >= 1 | 否 | Step save参数的batch数间隔 |
| save_step_path | string | 路径 | 否 | Step save参数的地址 |
| save_inference_interval | int | >= 1 | 否 | Save预测模型的轮数间隔 |
| save_inference_interval | int | >= 1 | 否 | Save预测模型的轮数间隔 |
| save_inference_path | string | 路径 | 否 | Save预测模型的地址 |
| save_inference_path | string | 路径 | 否 | Save预测模型的地址 |
| save_inference_feed_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的入口变量name |
| save_inference_feed_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的入口变量name |
...
...
models/rank/dnn/config.yaml
浏览文件 @
ebeb23c2
...
@@ -114,6 +114,23 @@ runner:
...
@@ -114,6 +114,23 @@ runner:
print_interval
:
1
print_interval
:
1
phases
:
[
phase1
]
phases
:
[
phase1
]
-
name
:
single_multi_gpu_train
class
:
train
# num of epochs
epochs
:
1
# device to run training or infer
device
:
gpu
selected_gpus
:
"
0,1"
# 选择多卡执行训练
save_checkpoint_interval
:
1
# save model interval of epochs
save_inference_interval
:
4
# save inference
save_step_interval
:
1
save_checkpoint_path
:
"
increment_dnn"
# save checkpoint path
save_inference_path
:
"
inference"
# save inference path
save_step_path
:
"
step_save"
save_inference_feed_varnames
:
[]
# feed vars of save inference
save_inference_fetch_varnames
:
[]
# fetch vars of save inference
print_interval
:
1
phases
:
[
phase1
]
# runner will run all the phase in each epoch
# runner will run all the phase in each epoch
phase
:
phase
:
-
name
:
phase1
-
name
:
phase1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录