Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
8b9b9cf9
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b9b9cf9
编写于
4月 21, 2020
作者:
T
tangwei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge yaml two to one
上级
59f4df30
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
31 deletion
+37
-31
fleetrec/core/trainer.py
fleetrec/core/trainer.py
+1
-1
fleetrec/core/utils/envs.py
fleetrec/core/utils/envs.py
+12
-5
fleetrec/examples/ctr-dnn_train.yaml
fleetrec/examples/ctr-dnn_train.yaml
+11
-0
fleetrec/examples/runtime.yaml
fleetrec/examples/runtime.yaml
+0
-13
fleetrec/run.py
fleetrec/run.py
+13
-12
未找到文件。
fleetrec/core/trainer.py
浏览文件 @
8b9b9cf9
...
...
@@ -89,7 +89,7 @@ def user_define_engine(engine_yaml):
_config
=
yaml
.
load
(
rb
.
read
(),
Loader
=
yaml
.
FullLoader
)
assert
_config
is
not
None
envs
.
set_runtime_envions
(
_config
)
envs
.
set_runtime_envi
r
ons
(
_config
)
train_location
=
envs
.
get_global_env
(
"engine.file"
)
train_dirname
=
os
.
path
.
dirname
(
train_location
)
...
...
fleetrec/core/utils/envs.py
浏览文件 @
8b9b9cf9
...
...
@@ -18,13 +18,14 @@ import copy
global_envs
=
{}
def
set_runtime_envions
(
envs
):
def
flatten_environs
(
envs
):
flatten_dict
=
{}
assert
isinstance
(
envs
,
dict
)
def
fatten_env_namespace
(
namespace_nests
,
local_envs
):
if
not
isinstance
(
local_envs
,
dict
):
global_k
=
"."
.
join
(
namespace_nests
)
os
.
environ
[
global_k
]
=
str
(
local_envs
)
flatten_dict
[
global_k
]
=
str
(
local_envs
)
else
:
for
k
,
v
in
local_envs
.
items
():
if
isinstance
(
v
,
dict
):
...
...
@@ -33,18 +34,24 @@ def set_runtime_envions(envs):
fatten_env_namespace
(
nests
,
v
)
else
:
global_k
=
"."
.
join
(
namespace_nests
+
[
k
])
os
.
environ
[
global_k
]
=
str
(
v
)
flatten_dict
[
global_k
]
=
str
(
v
)
for
k
,
v
in
envs
.
items
():
fatten_env_namespace
([
k
],
v
)
return
flatten_dict
def
get_runtime_envion
(
key
):
def
set_runtime_environs
(
environs
):
for
k
,
v
in
environs
.
items
():
os
.
environ
[
k
]
=
v
def
get_runtime_environ
(
key
):
return
os
.
getenv
(
key
,
None
)
def
get_trainer
():
train_mode
=
get_runtime_envion
(
"trainer.trainer"
)
train_mode
=
get_runtime_envi
r
on
(
"trainer.trainer"
)
return
train_mode
...
...
fleetrec/examples/ctr-dnn_train.yaml
浏览文件 @
8b9b9cf9
...
...
@@ -13,6 +13,17 @@
# limitations under the License.
train
:
trainer
:
trainer
:
"
/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads
:
4
# for cluster training
strategy
:
"
async"
communicator
:
send_queue_size
:
4
min_send_grad_num_before_recv
:
4
thread_pool_size
:
5
max_merge_var_num
:
4
epochs
:
10
reader
:
...
...
fleetrec/examples/runtime.yaml
已删除
100644 → 0
浏览文件 @
59f4df30
trainer
:
trainer
:
"
/root/FleetRec/fleetrec/examples/user_define_trainer.py"
threads
:
4
# for cluster training
strategy
:
"
async"
communicator
:
send_queue_size
:
4
min_send_grad_num_before_recv
:
4
thread_pool_size
:
5
max_merge_var_num
:
4
fleetrec/run.py
浏览文件 @
8b9b9cf9
...
...
@@ -13,16 +13,23 @@ clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
def
set_runtime_envs
(
cluster_envs
,
engine_yaml
):
if
engine_yaml
is
not
None
:
def
get_engine_extras
()
:
with
open
(
engine_yaml
,
'r'
)
as
rb
:
_envs
=
yaml
.
load
(
rb
.
read
(),
Loader
=
yaml
.
FullLoader
)
else
:
_envs
=
{}
flattens
=
envs
.
flatten_environs
(
_envs
)
engine_extras
=
{}
for
k
,
v
in
flattens
.
items
():
if
k
.
startswith
(
"train.trainer."
):
engine_extras
[
k
]
=
v
return
engine_extras
if
cluster_envs
is
None
:
cluster_envs
=
{}
envs
.
set_runtime_envions
(
cluster_envs
)
envs
.
set_runtime_envions
(
_envs
)
envs
.
set_runtime_environs
(
cluster_envs
)
envs
.
set_runtime_environs
(
get_engine_extras
())
need_print
=
{}
for
k
,
v
in
os
.
environ
.
items
():
...
...
@@ -62,9 +69,8 @@ def cluster_engine(args):
cluster_envs
=
{}
cluster_envs
[
"trainer.trainer"
]
=
"ClusterTrainer"
cluster_envs
[
"trainer.engine"
]
=
"cluster"
set_runtime_envs
(
cluster_envs
,
args
.
engine_extras
)
set_runtime_envs
(
cluster_envs
,
args
.
model
)
envs
.
set_runtime_envions
(
cluster_envs
)
trainer
=
TrainerFactory
.
create
(
args
.
model
)
return
trainer
...
...
@@ -146,11 +152,6 @@ if __name__ == "__main__":
if
args
.
engine
.
upper
()
not
in
clusters
:
raise
ValueError
(
"argument engine: {} error, must in {}"
.
format
(
args
.
engine
,
clusters
))
if
args
.
engine_extras
is
not
None
:
if
not
os
.
path
.
exists
(
args
.
engine_extras
)
or
not
os
.
path
.
isfile
(
args
.
engine_extras
):
raise
ValueError
(
"argument engine_extras: {} error, must specify an existed YAML file"
.
format
(
args
.
engine_extras
))
which_engine
=
get_engine
(
args
.
engine
)
engine
=
which_engine
(
args
)
engine
.
run
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录