Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
95f2364b
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
95f2364b
编写于
5月 27, 2020
作者:
X
xjqbest
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix
上级
fc94d505
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
57 addition
and
12 deletion
+57
-12
core/trainers/single_trainer.py
core/trainers/single_trainer.py
+55
-7
core/utils/dataloader_instance.py
core/utils/dataloader_instance.py
+1
-5
models/rank/dnn/config.yaml
models/rank/dnn/config.yaml
+1
-0
未找到文件。
core/trainers/single_trainer.py
浏览文件 @
95f2364b
...
...
@@ -70,7 +70,6 @@ class SingleTrainer(TranspileTrainer):
if
sparse_slots
is
None
and
dense_slots
is
None
:
pipe_cmd
=
"python {} {} {} {}"
.
format
(
reader
,
reader_class
,
"fake"
,
self
.
_config_yaml
)
else
:
if
sparse_slots
is
None
:
sparse_slots
=
"#"
...
...
@@ -98,7 +97,7 @@ class SingleTrainer(TranspileTrainer):
break
return
dataset
def
_get_dataloader
(
self
,
dataset_name
):
def
_get_dataloader
(
self
,
dataset_name
,
dataloader
):
name
=
"dataset."
+
dataset_name
+
"."
sparse_slots
=
envs
.
get_global_env
(
name
+
"sparse_slots"
)
dense_slots
=
envs
.
get_global_env
(
name
+
"dense_slots"
)
...
...
@@ -106,9 +105,7 @@ class SingleTrainer(TranspileTrainer):
batch_size
=
envs
.
get_global_env
(
name
+
"batch_size"
)
reader_class
=
envs
.
get_global_env
(
"data_convertor"
)
abs_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
#reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
if
sparse_slots
is
None
and
dense_slots
is
None
:
#reader_class = envs.get_global_env("class")
reader
=
dataloader_instance
.
dataloader_by_name
(
reader_class
,
dataset_name
,
self
.
_config_yaml
)
reader_class
=
envs
.
lazy_instance_by_fliename
(
reader_class
,
"TrainReader"
)
reader_ins
=
reader_class
(
self
.
_config_yaml
)
...
...
@@ -181,10 +178,10 @@ class SingleTrainer(TranspileTrainer):
model_path
=
model_dict
[
"model"
].
replace
(
"{workspace}"
,
envs
.
path_adapter
(
self
.
_env
[
"workspace"
]))
model
=
envs
.
lazy_instance_by_fliename
(
model_path
,
"Model"
)(
self
.
_env
)
model
.
_data_var
=
model
.
input_data
(
dataset_name
=
model_dict
[
"dataset_name"
])
#model._init_slots(name=model_dict["name"])
if
envs
.
get_global_env
(
"dataset."
+
dataset_name
+
".type"
)
==
"DataLoader"
:
model
.
_init_dataloader
()
model
.
net
(
model
.
_data_var
)
self
.
_get_dataloader
(
dataset_name
,
model
.
_data_loader
)
model
.
net
(
model
.
_data_var
,
is_infer
=
model_dict
[
"is_infer"
])
optimizer
=
model
.
_build_optimizer
(
opt_name
,
opt_lr
,
opt_strategy
)
optimizer
.
minimize
(
model
.
_cost
)
self
.
_model
[
model_dict
[
"name"
]][
0
]
=
train_program
...
...
@@ -215,6 +212,8 @@ class SingleTrainer(TranspileTrainer):
self
.
_executor_dataloader_train
(
model_dict
)
else
:
self
.
_executor_dataset_train
(
model_dict
)
with
fluid
.
scope_guard
(
self
.
_model
[
model_name
][
2
]):
self
.
save
(
self
,
j
)
end_time
=
time
.
time
()
seconds
=
end_time
-
begin_time
print
(
"epoch {} done, time elasped: {}"
.
format
(
j
,
seconds
))
...
...
@@ -270,7 +269,6 @@ class SingleTrainer(TranspileTrainer):
batch_id
=
0
scope
=
self
.
_model
[
model_name
][
2
]
program
=
self
.
_model
[
model_name
][
0
]
#print(metrics_varnames)
with
fluid
.
scope_guard
(
scope
):
try
:
while
True
:
...
...
@@ -287,3 +285,53 @@ class SingleTrainer(TranspileTrainer):
def
terminal
(
self
,
context
):
context
[
'is_exit'
]
=
True
def
save
(
self
,
epoch_id
,
is_fleet
=
False
):
def
need_save
(
epoch_id
,
epoch_interval
,
is_last
=
False
):
if
is_last
:
return
True
if
epoch_id
==
-
1
:
return
False
return
epoch_id
%
epoch_interval
==
0
def
save_inference_model
():
save_interval
=
envs
.
get_global_env
(
"epoch.save_inference_interval"
,
-
1
)
if
not
need_save
(
epoch_id
,
save_interval
,
False
):
return
feed_varnames
=
envs
.
get_global_env
(
"epoch.save_inference_feed_varnames"
,
None
)
fetch_varnames
=
envs
.
get_global_env
(
"epoch.save_inference_fetch_varnames"
,
None
)
if
feed_varnames
is
None
or
fetch_varnames
is
None
:
return
fetch_vars
=
[
fluid
.
default_main_program
().
global_block
().
vars
[
varname
]
for
varname
in
fetch_varnames
]
dirname
=
envs
.
get_global_env
(
"epoch.save_inference_path"
,
None
)
assert
dirname
is
not
None
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
if
is_fleet
:
fleet
.
save_inference_model
(
self
.
_exe
,
dirname
,
feed_varnames
,
fetch_vars
)
else
:
fluid
.
io
.
save_inference_model
(
dirname
,
feed_varnames
,
fetch_vars
,
self
.
_exe
)
self
.
inference_models
.
append
((
epoch_id
,
dirname
))
def
save_persistables
():
save_interval
=
envs
.
get_global_env
(
"epoch.save_checkpoint_interval"
,
-
1
)
if
not
need_save
(
epoch_id
,
save_interval
,
False
):
return
dirname
=
envs
.
get_global_env
(
"epoch.save_checkpoint_path"
,
None
)
assert
dirname
is
not
None
dirname
=
os
.
path
.
join
(
dirname
,
str
(
epoch_id
))
if
is_fleet
:
fleet
.
save_persistables
(
self
.
_exe
,
dirname
)
else
:
fluid
.
io
.
save_persistables
(
self
.
_exe
,
dirname
)
self
.
increment_models
.
append
((
epoch_id
,
dirname
))
save_persistables
()
save_inference_model
()
core/utils/dataloader_instance.py
浏览文件 @
95f2364b
...
...
@@ -23,10 +23,6 @@ def dataloader_by_name(readerclass, dataset_name, yaml_file):
reader_class
=
lazy_instance_by_fliename
(
readerclass
,
"TrainReader"
)
name
=
"dataset."
+
dataset_name
+
"."
data_path
=
get_global_env
(
name
+
"data_path"
)
#else:
# reader_name = "SlotReader"
# namespace = "evaluate.reader"
# data_path = get_global_env("test_data_path", None, namespace)
if
data_path
.
startswith
(
"paddlerec::"
):
package_base
=
get_runtime_environ
(
"PACKAGE_BASE"
)
...
...
@@ -71,7 +67,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
data_path
=
os
.
path
.
join
(
package_base
,
data_path
.
split
(
"::"
)[
1
])
files
=
[
str
(
data_path
)
+
"/%s"
%
x
for
x
in
os
.
listdir
(
data_path
)]
sparse
=
get_global_env
(
name
+
"sparse_slots"
)
dense
=
get_global_env
(
name
+
"dense_slots"
)
padding
=
get_global_env
(
name
+
"padding"
,
0
)
...
...
models/rank/dnn/config.yaml
浏览文件 @
95f2364b
...
...
@@ -51,3 +51,4 @@ executor:
model
:
"
{workspace}/model.py"
dataset_name
:
dataset_2
thread_num
:
1
is_infer
:
False
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录