Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
c6a3a9fd
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c6a3a9fd
编写于
6月 10, 2020
作者:
T
tangwei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix windows adapter
上级
fc724787
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
21 addition
and
34 deletion
+21
-34
core/factory.py
core/factory.py
+1
-1
core/reader.py
core/reader.py
+2
-2
core/trainer.py
core/trainer.py
+5
-15
core/utils/envs.py
core/utils/envs.py
+6
-7
doc/design.md
doc/design.md
+1
-1
run.py
run.py
+6
-8
未找到文件。
core/factory.py
浏览文件 @
c6a3a9fd
...
...
@@ -59,7 +59,7 @@ class TrainerFactory(object):
@
staticmethod
def
create
(
config
):
_config
=
envs
.
load_yaml
(
config
)
envs
.
set_global_envs
(
_config
,
True
)
envs
.
set_global_envs
(
_config
)
trainer
=
TrainerFactory
.
_build_trainer
(
config
)
return
trainer
...
...
core/reader.py
浏览文件 @
c6a3a9fd
...
...
@@ -26,7 +26,7 @@ class ReaderBase(dg.MultiSlotDataGenerator):
def
__init__
(
self
,
config
):
dg
.
MultiSlotDataGenerator
.
__init__
(
self
)
_config
=
envs
.
load_yaml
(
config
)
envs
.
set_global_envs
(
_config
,
True
)
envs
.
set_global_envs
(
_config
)
@
abc
.
abstractmethod
def
init
(
self
):
...
...
@@ -44,7 +44,7 @@ class SlotReader(dg.MultiSlotDataGenerator):
def
__init__
(
self
,
config
):
dg
.
MultiSlotDataGenerator
.
__init__
(
self
)
_config
=
envs
.
load_yaml
(
config
)
envs
.
set_global_envs
(
_config
,
True
)
envs
.
set_global_envs
(
_config
)
def
init
(
self
,
sparse_slots
,
dense_slots
,
padding
=
0
):
from
operator
import
mul
...
...
core/trainer.py
浏览文件 @
c6a3a9fd
...
...
@@ -16,7 +16,6 @@ import abc
import
os
import
time
import
sys
import
yaml
import
traceback
from
paddle
import
fluid
...
...
@@ -74,11 +73,14 @@ class Trainer(object):
phase_names
=
envs
.
get_global_env
(
"runner."
+
self
.
_runner_name
+
".phases"
,
None
)
_config
=
envs
.
load_yaml
(
config
)
phases
=
[]
if
phase_names
is
None
:
phases
=
envs
.
get_global_env
(
"phase"
)
phases
=
_config
.
get
(
"phase"
)
else
:
for
phase
in
envs
.
get_global_env
(
"phase"
):
for
phase
in
_config
.
get
(
"phase"
):
if
phase
[
"name"
]
in
phase_names
:
phases
.
append
(
phase
)
...
...
@@ -244,15 +246,3 @@ class Trainer(object):
self
.
context_process
(
self
.
_context
)
if
self
.
_context
[
'is_exit'
]:
break
def
user_define_engine
(
engine_yaml
):
_config
=
envs
.
load_yaml
(
engine_yaml
)
envs
.
set_runtime_environs
(
_config
)
train_location
=
envs
.
get_global_env
(
"engine.file"
)
train_dirname
=
os
.
path
.
dirname
(
train_location
)
base_name
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
train_location
))[
0
]
sys
.
path
.
append
(
train_dirname
)
trainer_class
=
envs
.
lazy_instance_by_fliename
(
base_name
,
"UserDefineTraining"
)
return
trainer_class
core/utils/envs.py
浏览文件 @
c6a3a9fd
...
...
@@ -20,9 +20,8 @@ import socket
import
sys
import
traceback
import
yaml
global_envs
=
{}
global_envs_flatten
=
{}
def
flatten_environs
(
envs
,
separator
=
"."
):
...
...
@@ -68,7 +67,7 @@ def get_fleet_mode():
return
fleet_mode
def
set_global_envs
(
envs
,
adapter
):
def
set_global_envs
(
envs
):
assert
isinstance
(
envs
,
dict
)
def
fatten_env_namespace
(
namespace_nests
,
local_envs
):
...
...
@@ -92,10 +91,9 @@ def set_global_envs(envs, adapter):
fatten_env_namespace
([],
envs
)
if
adapter
:
workspace_adapter
()
os_path_adapter
()
reader_adapter
()
workspace_adapter
()
os_path_adapter
()
reader_adapter
()
def
get_global_env
(
env_name
,
default_value
=
None
,
namespace
=
None
):
...
...
@@ -134,6 +132,7 @@ def workspace_adapter():
workspace
=
global_envs
.
get
(
"workspace"
)
if
not
workspace
:
return
workspace
=
paddlerec_adapter
(
workspace
)
for
name
,
value
in
global_envs
.
items
():
...
...
doc/design.md
浏览文件 @
c6a3a9fd
...
...
@@ -197,7 +197,7 @@ class Reader(dg.MultiSlotDataGenerator):
def
__init__
(
self
,
config
):
dg
.
MultiSlotDataGenerator
.
__init__
(
self
)
_config
=
envs
.
load_yaml
(
config
)
envs
.
set_global_envs
(
_config
,
True
)
envs
.
set_global_envs
(
_config
)
@
abc
.
abstractmethod
def
init
(
self
):
...
...
run.py
浏览文件 @
c6a3a9fd
...
...
@@ -110,7 +110,6 @@ def get_modes(running_config):
def
get_engine
(
args
,
running_config
,
mode
):
transpiler
=
get_transpiler
()
_envs
=
envs
.
load_yaml
(
args
.
model
)
engine_class
=
"."
.
join
([
"runner"
,
mode
,
"class"
])
engine_device
=
"."
.
join
([
"runner"
,
mode
,
"device"
])
...
...
@@ -122,11 +121,14 @@ def get_engine(args, running_config, mode):
mode
,
engine_class
))
device
=
running_config
.
get
(
engine_device
,
None
)
engine
=
engine
.
upper
()
device
=
device
.
upper
()
if
device
is
None
:
print
(
"not find device be specified in yaml, set CPU as default"
)
device
=
"CPU"
if
device
.
upper
()
==
"GPU"
:
if
device
==
"GPU"
:
selected_gpus
=
running_config
.
get
(
device_gpu_choices
,
None
)
if
selected_gpus
is
None
:
...
...
@@ -142,7 +144,6 @@ def get_engine(args, running_config, mode):
if
selected_gpus_num
>
1
:
engine
=
"LOCAL_CLUSTER"
engine
=
engine
.
upper
()
if
engine
not
in
engine_choices
:
raise
ValueError
(
"{} can not be chosen in {}"
.
format
(
engine_class
,
engine_choices
))
...
...
@@ -180,9 +181,7 @@ def set_runtime_envs(cluster_envs, engine_yaml):
def
single_train_engine
(
args
):
_envs
=
envs
.
load_yaml
(
args
.
model
)
run_extras
=
get_all_inters_from_yaml
(
args
.
model
,
[
"runner."
])
mode
=
envs
.
get_runtime_environ
(
"mode"
)
trainer_class
=
"."
.
join
([
"runner"
,
mode
,
"trainer_class"
])
fleet_class
=
"."
.
join
([
"runner"
,
mode
,
"fleet_mode"
])
...
...
@@ -435,7 +434,7 @@ def local_mpi_engine(args):
def
get_abs_model
(
model
):
if
model
.
startswith
(
"paddlerec."
):
dir
=
envs
.
pa
th
_adapter
(
model
)
dir
=
envs
.
pa
ddlerec
_adapter
(
model
)
path
=
os
.
path
.
join
(
dir
,
"config.yaml"
)
else
:
if
not
os
.
path
.
isfile
(
model
):
...
...
@@ -453,13 +452,12 @@ if __name__ == "__main__":
envs
.
set_runtime_environs
({
"PACKAGE_BASE"
:
abs_dir
})
args
=
parser
.
parse_args
()
model_name
=
args
.
model
.
split
(
'.'
)[
-
1
]
args
.
model
=
get_abs_model
(
args
.
model
)
if
not
validation
.
yaml_validation
(
args
.
model
):
sys
.
exit
(
-
1
)
engine_registry
()
engine_registry
()
running_config
=
get_all_inters_from_yaml
(
args
.
model
,
[
"mode"
,
"runner."
])
modes
=
get_modes
(
running_config
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录