Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
02ec66fd
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02ec66fd
编写于
4月 15, 2020
作者:
T
tangwei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update setup.py
上级
7c852a31
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
32 addition
and
18 deletion
+32
-18
fleetrec/core/factory.py
fleetrec/core/factory.py
+3
-3
fleetrec/core/utils/envs.py
fleetrec/core/utils/envs.py
+12
-2
fleetrec/run.py
fleetrec/run.py
+17
-13
未找到文件。
fleetrec/core/factory.py
浏览文件 @
02ec66fd
...
...
@@ -33,15 +33,15 @@ class TrainerFactory(object):
def
_build_trainer
(
config
,
yaml_path
):
print
(
envs
.
pretty_print_envs
(
envs
.
get_global_envs
()))
train_mode
=
envs
.
get_
global_env
(
"train.trainer"
)
train_mode
=
envs
.
get_
runtime_envion
(
"train.trainer"
)
if
train_mode
==
"SingleTraining"
:
trainer
=
SingleTrainer
(
yaml_path
)
elif
train_mode
==
"ClusterTraining"
:
trainer
=
ClusterTrainer
(
yaml_path
)
elif
train_mode
==
"CtrTrain
er
"
:
elif
train_mode
==
"CtrTrain
ing
"
:
trainer
=
CtrPaddleTrainer
(
config
)
elif
train_mode
==
"UserDefineTrain
er
"
:
elif
train_mode
==
"UserDefineTrain
ing
"
:
train_location
=
envs
.
get_global_env
(
"train.location"
)
train_dirname
=
os
.
path
.
dirname
(
train_location
)
base_name
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
train_location
))[
0
]
...
...
fleetrec/core/utils/envs.py
浏览文件 @
02ec66fd
...
...
@@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
copy
global_envs
=
{}
def
set_runtime_envions
(
envs
):
assert
isinstance
(
envs
,
dict
)
for
k
,
v
in
envs
.
items
():
os
.
environ
[
k
]
=
v
def
get_runtime_envion
(
key
):
return
os
.
getenv
(
key
,
None
)
def
set_global_envs
(
envs
):
assert
isinstance
(
envs
,
dict
)
...
...
@@ -87,4 +98,3 @@ def lazy_instance(package, class_name):
model_package
=
__import__
(
package
,
globals
(),
locals
(),
package
.
split
(
"."
))
instance
=
getattr
(
model_package
,
class_name
)
return
instance
fleetrec/run.py
浏览文件 @
02ec66fd
...
...
@@ -11,11 +11,10 @@ def run(model_yaml):
trainer
.
run
()
def
single_engine
(
model_yaml
):
single_envs
=
{}
single_envs
[
"singleTraning"
]
=
True
def
single_engine
(
single_envs
,
model_yaml
):
print
(
envs
.
pretty_print_envs
(
single_envs
,
(
"Single Envs"
,
"Value"
)))
envs
.
set_runtime_envions
(
single_envs
)
run
(
model_yaml
)
...
...
@@ -47,25 +46,30 @@ if __name__ == "__main__":
if
args
.
engine
==
"Single"
:
print
(
"use SingleTraining to run model: {}"
.
format
(
args
.
model
))
single_engine
(
args
.
model
)
single_envs
=
{}
single_envs
[
"train.trainer"
]
=
"SingleTraining"
single_engine
(
single_envs
,
args
.
model
)
elif
args
.
engine
==
"LocalCluster"
:
print
(
"use 1X1 ClusterTraining at localhost to run model: {}"
.
format
(
args
.
model
))
cluster_envs
=
{}
cluster_envs
[
"server_num"
]
=
1
cluster_envs
[
"worker_num"
]
=
1
cluster_envs
[
"start_port"
]
=
36001
cluster_envs
[
"log_dir"
]
=
"logs"
cluster_envs
[
"train.server_num"
]
=
1
cluster_envs
[
"train.worker_num"
]
=
1
cluster_envs
[
"train.start_port"
]
=
36001
cluster_envs
[
"train.log_dir"
]
=
"logs"
cluster_envs
[
"train.trainer"
]
=
"SingleTraining"
local_cluster_engine
(
cluster_envs
,
args
.
model
)
elif
args
.
engine
==
"LocalMPI"
:
print
(
"use 1X1 MPI ClusterTraining at localhost to run model: {}"
.
format
(
args
.
model
))
cluster_envs
=
{}
cluster_envs
[
"server_num"
]
=
1
cluster_envs
[
"worker_num"
]
=
1
cluster_envs
[
"start_port"
]
=
36001
cluster_envs
[
"log_dir"
]
=
"logs"
cluster_envs
[
"train.server_num"
]
=
1
cluster_envs
[
"train.worker_num"
]
=
1
cluster_envs
[
"train.start_port"
]
=
36001
cluster_envs
[
"train.log_dir"
]
=
"logs"
cluster_envs
[
"train.trainer"
]
=
"CtrTraining"
local_mpi_engine
(
cluster_envs
,
args
.
model
)
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录