Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
12fc8c82
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
12fc8c82
编写于
8月 19, 2020
作者:
C
Chengmo
提交者:
GitHub
8月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix k8s_datasplit & save_model (#167)
* fix save inference
上级
d4a280b5
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
93 addition
and
9 deletion
+93
-9
core/engine/cluster/cloud/k8s_config.ini.template
core/engine/cluster/cloud/k8s_config.ini.template
+1
-0
core/engine/cluster/cloud/mpi_config.ini.template
core/engine/cluster/cloud/mpi_config.ini.template
+1
-0
core/trainer.py
core/trainer.py
+10
-0
core/trainers/framework/dataset.py
core/trainers/framework/dataset.py
+11
-0
core/trainers/framework/runner.py
core/trainers/framework/runner.py
+36
-7
core/utils/dataloader_instance.py
core/utils/dataloader_instance.py
+34
-2
未找到文件。
core/engine/cluster/cloud/k8s_config.ini.template
浏览文件 @
12fc8c82
...
@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/"
...
@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/"
# 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193
# 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193
PADDLE_PADDLEREC_ROLE=WORKER
PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=K8S
use_python3=<$ USE_PYTHON3 $>
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
GLOG_v=0
...
...
core/engine/cluster/cloud/mpi_config.ini.template
浏览文件 @
12fc8c82
...
@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $>
...
@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $>
thirdparty_path=<$ THIRDPARTY_PATH $>
thirdparty_path=<$ THIRDPARTY_PATH $>
PADDLE_PADDLEREC_ROLE=WORKER
PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=MPI
use_python3=<$ USE_PYTHON3 $>
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
GLOG_v=0
...
...
core/trainer.py
浏览文件 @
12fc8c82
...
@@ -107,6 +107,7 @@ class Trainer(object):
...
@@ -107,6 +107,7 @@ class Trainer(object):
self
.
device
=
Device
.
GPU
self
.
device
=
Device
.
GPU
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
self
.
_place
=
fluid
.
CUDAPlace
(
gpu_id
)
self
.
_place
=
fluid
.
CUDAPlace
(
gpu_id
)
print
(
"PaddleRec run on device GPU: {}"
.
format
(
gpu_id
))
self
.
_exe
=
fluid
.
Executor
(
self
.
_place
)
self
.
_exe
=
fluid
.
Executor
(
self
.
_place
)
elif
device
==
"CPU"
:
elif
device
==
"CPU"
:
self
.
device
=
Device
.
CPU
self
.
device
=
Device
.
CPU
...
@@ -146,6 +147,7 @@ class Trainer(object):
...
@@ -146,6 +147,7 @@ class Trainer(object):
elif
engine
.
upper
()
==
"CLUSTER"
:
elif
engine
.
upper
()
==
"CLUSTER"
:
self
.
engine
=
EngineMode
.
CLUSTER
self
.
engine
=
EngineMode
.
CLUSTER
self
.
is_fleet
=
True
self
.
is_fleet
=
True
self
.
which_cluster_type
()
else
:
else
:
raise
ValueError
(
"Not Support Engine {}"
.
format
(
engine
))
raise
ValueError
(
"Not Support Engine {}"
.
format
(
engine
))
self
.
_context
[
"is_fleet"
]
=
self
.
is_fleet
self
.
_context
[
"is_fleet"
]
=
self
.
is_fleet
...
@@ -165,6 +167,14 @@ class Trainer(object):
...
@@ -165,6 +167,14 @@ class Trainer(object):
self
.
_context
[
"is_pslib"
]
=
(
fleet_mode
.
upper
()
==
"PSLIB"
)
self
.
_context
[
"is_pslib"
]
=
(
fleet_mode
.
upper
()
==
"PSLIB"
)
self
.
_context
[
"fleet_mode"
]
=
fleet_mode
self
.
_context
[
"fleet_mode"
]
=
fleet_mode
def
which_cluster_type
(
self
):
cluster_type
=
os
.
getenv
(
"PADDLEREC_CLUSTER_TYPE"
,
"MPI"
)
print
(
"PADDLEREC_CLUSTER_TYPE: {}"
.
format
(
cluster_type
))
if
cluster_type
and
cluster_type
.
upper
()
==
"K8S"
:
self
.
_context
[
"cluster_type"
]
=
"K8S"
else
:
self
.
_context
[
"cluster_type"
]
=
"MPI"
def
which_executor_mode
(
self
):
def
which_executor_mode
(
self
):
executor_mode
=
envs
.
get_runtime_environ
(
"train.trainer.executor_mode"
)
executor_mode
=
envs
.
get_runtime_environ
(
"train.trainer.executor_mode"
)
if
executor_mode
.
upper
()
not
in
[
"TRAIN"
,
"INFER"
]:
if
executor_mode
.
upper
()
not
in
[
"TRAIN"
,
"INFER"
]:
...
...
core/trainers/framework/dataset.py
浏览文件 @
12fc8c82
...
@@ -123,10 +123,21 @@ class QueueDataset(DatasetBase):
...
@@ -123,10 +123,21 @@ class QueueDataset(DatasetBase):
os
.
path
.
join
(
train_data_path
,
x
)
os
.
path
.
join
(
train_data_path
,
x
)
for
x
in
os
.
listdir
(
train_data_path
)
for
x
in
os
.
listdir
(
train_data_path
)
]
]
file_list
.
sort
()
need_split_files
=
False
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
# for local cluster: split files for multi process
need_split_files
=
True
elif
context
[
"engine"
]
==
EngineMode
.
CLUSTER
and
context
[
"cluster_type"
]
==
"K8S"
:
# for k8s mount afs, split files for every node
need_split_files
=
True
if
need_split_files
:
file_list
=
split_files
(
file_list
,
context
[
"fleet"
].
worker_index
(),
file_list
=
split_files
(
file_list
,
context
[
"fleet"
].
worker_index
(),
context
[
"fleet"
].
worker_num
())
context
[
"fleet"
].
worker_num
())
print
(
"File_list: {}"
.
format
(
file_list
))
print
(
"File_list: {}"
.
format
(
file_list
))
dataset
.
set_filelist
(
file_list
)
dataset
.
set_filelist
(
file_list
)
for
model_dict
in
context
[
"phases"
]:
for
model_dict
in
context
[
"phases"
]:
if
model_dict
[
"dataset_name"
]
==
dataset_name
:
if
model_dict
[
"dataset_name"
]
==
dataset_name
:
...
...
core/trainers/framework/runner.py
浏览文件 @
12fc8c82
...
@@ -16,6 +16,7 @@ from __future__ import print_function
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
os
import
os
import
time
import
time
import
warnings
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
...
@@ -284,6 +285,7 @@ class RunnerBase(object):
...
@@ -284,6 +285,7 @@ class RunnerBase(object):
return
(
epoch_id
+
1
)
%
epoch_interval
==
0
return
(
epoch_id
+
1
)
%
epoch_interval
==
0
def
save_inference_model
():
def
save_inference_model
():
# get global env
name
=
"runner."
+
context
[
"runner_name"
]
+
"."
name
=
"runner."
+
context
[
"runner_name"
]
+
"."
save_interval
=
int
(
save_interval
=
int
(
envs
.
get_global_env
(
name
+
"save_inference_interval"
,
-
1
))
envs
.
get_global_env
(
name
+
"save_inference_interval"
,
-
1
))
...
@@ -296,18 +298,44 @@ class RunnerBase(object):
...
@@ -296,18 +298,44 @@ class RunnerBase(object):
if
feed_varnames
is
None
or
fetch_varnames
is
None
or
feed_varnames
==
""
or
fetch_varnames
==
""
or
\
if
feed_varnames
is
None
or
fetch_varnames
is
None
or
feed_varnames
==
""
or
fetch_varnames
==
""
or
\
len
(
feed_varnames
)
==
0
or
len
(
fetch_varnames
)
==
0
:
len
(
feed_varnames
)
==
0
or
len
(
fetch_varnames
)
==
0
:
return
return
fetch_vars
=
[
fluid
.
default_main_program
().
global_block
().
vars
[
varname
]
# check feed var exist
for
varname
in
fetch_varnames
for
var_name
in
feed_varnames
:
]
if
var_name
not
in
fluid
.
default_main_program
().
global_block
(
).
vars
:
raise
ValueError
(
"Feed variable: {} not in default_main_program, global block has follow vars: {}"
.
format
(
var_name
,
fluid
.
default_main_program
().
global_block
()
.
vars
.
keys
()))
# check fetch var exist
fetch_vars
=
[]
for
var_name
in
fetch_varnames
:
if
var_name
not
in
fluid
.
default_main_program
().
global_block
(
).
vars
:
raise
ValueError
(
"Fetch variable: {} not in default_main_program, global block has follow vars: {}"
.
format
(
var_name
,
fluid
.
default_main_program
().
global_block
()
.
vars
.
keys
()))
else
:
fetch_vars
.
append
(
fluid
.
default_main_program
()
.
global_block
().
vars
[
var_name
])
dirname
=
envs
.
get_global_env
(
name
+
"save_inference_path"
,
None
)
dirname
=
envs
.
get_global_env
(
name
+
"save_inference_path"
,
None
)
assert
dirname
is
not
None
assert
dirname
is
not
None
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
if
is_fleet
:
if
is_fleet
:
context
[
"fleet"
].
save_inference_model
(
warnings
.
warn
(
context
[
"exe"
],
dirname
,
feed_varnames
,
fetch_vars
)
"Save inference model in cluster training is not recommended! Using save checkpoint instead."
,
category
=
UserWarning
,
stacklevel
=
2
)
if
context
[
"fleet"
].
worker_index
()
==
0
:
context
[
"fleet"
].
save_inference_model
(
context
[
"exe"
],
dirname
,
feed_varnames
,
fetch_vars
)
else
:
else
:
fluid
.
io
.
save_inference_model
(
dirname
,
feed_varnames
,
fluid
.
io
.
save_inference_model
(
dirname
,
feed_varnames
,
fetch_vars
,
context
[
"exe"
])
fetch_vars
,
context
[
"exe"
])
...
@@ -323,7 +351,8 @@ class RunnerBase(object):
...
@@ -323,7 +351,8 @@ class RunnerBase(object):
return
return
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
if
is_fleet
:
if
is_fleet
:
context
[
"fleet"
].
save_persistables
(
context
[
"exe"
],
dirname
)
if
context
[
"fleet"
].
worker_index
()
==
0
:
context
[
"fleet"
].
save_persistables
(
context
[
"exe"
],
dirname
)
else
:
else
:
fluid
.
io
.
save_persistables
(
context
[
"exe"
],
dirname
)
fluid
.
io
.
save_persistables
(
context
[
"exe"
],
dirname
)
...
...
core/utils/dataloader_instance.py
浏览文件 @
12fc8c82
...
@@ -39,9 +39,21 @@ def dataloader_by_name(readerclass,
...
@@ -39,9 +39,21 @@ def dataloader_by_name(readerclass,
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
files
.
sort
()
need_split_files
=
False
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
# for local cluster: split files for multi process
need_split_files
=
True
elif
context
[
"engine"
]
==
EngineMode
.
CLUSTER
and
context
[
"cluster_type"
]
==
"K8S"
:
# for k8s mount mode, split files for every node
need_split_files
=
True
print
(
"need_split_files: {}"
.
format
(
need_split_files
))
if
need_split_files
:
files
=
split_files
(
files
,
context
[
"fleet"
].
worker_index
(),
files
=
split_files
(
files
,
context
[
"fleet"
].
worker_index
(),
context
[
"fleet"
].
worker_num
())
context
[
"fleet"
].
worker_num
())
print
(
"file_list : {}"
.
format
(
files
))
print
(
"file_list : {}"
.
format
(
files
))
reader
=
reader_class
(
yaml_file
)
reader
=
reader_class
(
yaml_file
)
...
@@ -81,10 +93,20 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
...
@@ -81,10 +93,20 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
files
.
sort
()
need_split_files
=
False
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
# for local cluster: split files for multi process
need_split_files
=
True
elif
context
[
"engine"
]
==
EngineMode
.
CLUSTER
and
context
[
"cluster_type"
]
==
"K8S"
:
# for k8s mount mode, split files for every node
need_split_files
=
True
if
need_split_files
:
files
=
split_files
(
files
,
context
[
"fleet"
].
worker_index
(),
files
=
split_files
(
files
,
context
[
"fleet"
].
worker_index
(),
context
[
"fleet"
].
worker_num
())
context
[
"fleet"
].
worker_num
())
print
(
"file_list: {}"
.
format
(
files
))
sparse
=
get_global_env
(
name
+
"sparse_slots"
,
"#"
)
sparse
=
get_global_env
(
name
+
"sparse_slots"
,
"#"
)
if
sparse
==
""
:
if
sparse
==
""
:
...
@@ -135,10 +157,20 @@ def slotdataloader(readerclass, train, yaml_file, context):
...
@@ -135,10 +157,20 @@ def slotdataloader(readerclass, train, yaml_file, context):
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
files
.
sort
()
need_split_files
=
False
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
if
context
[
"engine"
]
==
EngineMode
.
LOCAL_CLUSTER
:
# for local cluster: split files for multi process
need_split_files
=
True
elif
context
[
"engine"
]
==
EngineMode
.
CLUSTER
and
context
[
"cluster_type"
]
==
"K8S"
:
# for k8s mount mode, split files for every node
need_split_files
=
True
if
need_split_files
:
files
=
split_files
(
files
,
context
[
"fleet"
].
worker_index
(),
files
=
split_files
(
files
,
context
[
"fleet"
].
worker_index
(),
context
[
"fleet"
].
worker_num
())
context
[
"fleet"
].
worker_num
())
print
(
"file_list: {}"
.
format
(
files
))
sparse
=
get_global_env
(
"sparse_slots"
,
"#"
,
namespace
)
sparse
=
get_global_env
(
"sparse_slots"
,
"#"
,
namespace
)
if
sparse
==
""
:
if
sparse
==
""
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录