Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
2e91f58f
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2e91f58f
编写于
4月 16, 2020
作者:
T
tangwei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix import
上级
3eba3369
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
14 addition
and
10 deletion
+14
-10
fleetrec/core/factory.py
fleetrec/core/factory.py
+4
-4
fleetrec/core/trainer.py
fleetrec/core/trainer.py
+6
-1
fleetrec/core/trainers/ctr_coding_trainer.py
fleetrec/core/trainers/ctr_coding_trainer.py
+1
-1
fleetrec/run.py
fleetrec/run.py
+3
-4
未找到文件。
fleetrec/core/factory.py
浏览文件 @
2e91f58f
...
...
@@ -25,7 +25,7 @@ class TrainerFactory(object):
pass
@
staticmethod
def
_build_trainer
(
config
,
yaml_path
):
def
_build_trainer
(
yaml_path
):
print
(
envs
.
pretty_print_envs
(
envs
.
get_global_envs
()))
train_mode
=
envs
.
get_global_env
(
"train.trainer"
)
...
...
@@ -40,8 +40,8 @@ class TrainerFactory(object):
from
fleetrec.core.trainers.cluster_trainer
import
ClusterTrainer
trainer
=
ClusterTrainer
(
yaml_path
)
elif
train_mode
==
"CtrTraining"
:
from
fleetrec.core.trainers.ctr_
modul
_trainer
import
CtrPaddleTrainer
trainer
=
CtrPaddleTrainer
(
config
)
from
fleetrec.core.trainers.ctr_
coding
_trainer
import
CtrPaddleTrainer
trainer
=
CtrPaddleTrainer
(
yaml_path
)
elif
train_mode
==
"UserDefineTraining"
:
train_location
=
envs
.
get_global_env
(
"train.location"
)
train_dirname
=
os
.
path
.
dirname
(
train_location
)
...
...
@@ -63,7 +63,7 @@ class TrainerFactory(object):
raise
ValueError
(
"fleetrec's config only support yaml"
)
envs
.
set_global_envs
(
_config
)
trainer
=
TrainerFactory
.
_build_trainer
(
_config
,
config
)
trainer
=
TrainerFactory
.
_build_trainer
(
config
)
return
trainer
...
...
fleetrec/core/trainer.py
浏览文件 @
2e91f58f
...
...
@@ -14,6 +14,8 @@
import
abc
import
time
import
yaml
from
paddle
import
fluid
...
...
@@ -28,7 +30,10 @@ class Trainer(object):
self
.
_exe
=
fluid
.
Executor
(
self
.
_place
)
self
.
_exector_context
=
{}
self
.
_context
=
{
'status'
:
'uninit'
,
'is_exit'
:
False
}
self
.
_config
=
config
self
.
_config_yaml
=
config
with
open
(
config
,
'r'
)
as
rb
:
self
.
_config
=
yaml
.
load
(
rb
.
read
(),
Loader
=
yaml
.
FullLoader
)
def
regist_context_processor
(
self
,
status_name
,
processor
):
"""
...
...
fleetrec/core/trainers/ctr_coding_trainer.py
浏览文件 @
2e91f58f
...
...
@@ -62,7 +62,7 @@ class CtrPaddleTrainer(Trainer):
reader_class
=
envs
.
get_global_env
(
"class"
,
None
,
namespace
)
abs_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
reader
=
os
.
path
.
join
(
abs_dir
,
'../utils'
,
'reader_instance.py'
)
pipe_cmd
=
"python {} {} {} {}"
.
format
(
reader
,
reader_class
,
"TRAIN"
,
self
.
_config
)
pipe_cmd
=
"python {} {} {} {}"
.
format
(
reader
,
reader_class
,
"TRAIN"
,
self
.
_config
_yaml
)
train_data_path
=
envs
.
get_global_env
(
"train_data_path"
,
None
,
namespace
)
dataset
=
fluid
.
DatasetFactory
().
create_dataset
()
...
...
fleetrec/run.py
浏览文件 @
2e91f58f
...
...
@@ -17,7 +17,6 @@ def run(model_yaml):
def
single_engine
(
single_envs
,
model_yaml
):
print
(
envs
.
pretty_print_envs
(
single_envs
,
(
"Single Envs"
,
"Value"
)))
envs
.
set_runtime_envions
(
single_envs
)
run
(
model_yaml
)
...
...
@@ -33,8 +32,8 @@ def local_cluster_engine(cluster_envs, model_yaml):
def
local_mpi_engine
(
cluster_envs
,
model_yaml
):
from
fleetrec.core.engine.local_mpi_engine
import
LocalMPIEngine
print
(
envs
.
pretty_print_envs
(
cluster_envs
,
(
"Local MPI Cluster Envs"
,
"Value"
)))
print
(
envs
.
pretty_print_envs
(
cluster_envs
,
(
"Local MPI Cluster Envs"
,
"Value"
)))
envs
.
set_runtime_envions
(
cluster_envs
)
launch
=
LocalMPIEngine
(
cluster_envs
,
model_yaml
)
launch
.
run
()
...
...
@@ -79,7 +78,7 @@ if __name__ == "__main__":
if
not
mpi_path
:
raise
RuntimeError
(
"can not find mpirun, please check environment"
)
cluster_envs
=
{
"mpirun"
:
mpi_path
,
"train.trainer"
:
"CtrTraining"
}
cluster_envs
=
{
"mpirun"
:
mpi_path
,
"train.trainer"
:
"CtrTraining"
,
"log_dir"
:
"logs"
}
local_mpi_engine
(
cluster_envs
,
args
.
model
)
elif
args
.
engine
.
upper
()
==
"LOCAL_CLUSTER"
:
print
(
"use 1X1 ClusterTraining at localhost to run model: {}"
.
format
(
args
.
model
))
...
...
@@ -100,7 +99,7 @@ if __name__ == "__main__":
if
not
mpi_path
:
raise
RuntimeError
(
"can not find mpirun, please check environment"
)
cluster_envs
=
{
"mpirun"
:
mpi_path
,
"train.trainer"
:
"CtrTraining"
}
cluster_envs
=
{
"mpirun"
:
mpi_path
,
"train.trainer"
:
"CtrTraining"
,
"log_dir"
:
"logs"
}
local_mpi_engine
(
cluster_envs
,
args
.
model
)
elif
args
.
engine
.
upper
()
==
"CLUSTER"
:
print
(
"launch ClusterTraining with cluster to run model: {}"
.
format
(
args
.
model
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录