Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
74ee4f3f
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
74ee4f3f
编写于
4月 30, 2020
作者:
T
tangwei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add optimizer config, add workspace
上级
3006e6b2
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
60 addition
and
20 deletion
+60
-20
demo/ctr-dnn_train.yaml
demo/ctr-dnn_train.yaml
+7
-14
fleet_rec/core/factory.py
fleet_rec/core/factory.py
+2
-0
fleet_rec/core/model.py
fleet_rec/core/model.py
+28
-0
fleet_rec/core/utils/envs.py
fleet_rec/core/utils/envs.py
+21
-0
models/rank/dnn/model.py
models/rank/dnn/model.py
+2
-6
未找到文件。
demo/ctr-dnn_train.yaml
浏览文件 @
74ee4f3f
...
...
@@ -18,14 +18,15 @@ train:
strategy
:
"
async"
epochs
:
10
workspace
:
"
fleetrec.models.rank.dnn"
reader
:
batch_size
:
2
class
:
"
fleetrec.models.rank.criteo_reader
"
train_data_path
:
"
fleetrec::models/rank/dnn
/data/train"
class
:
"
{workspace}/../criteo_reader.py
"
train_data_path
:
"
{workspace}
/data/train"
model
:
models
:
"
fleetrec.models.rank.dnn.model
"
models
:
"
{workspace}/model.py
"
hyper_parameters
:
sparse_inputs_slots
:
27
sparse_feature_number
:
1000001
...
...
@@ -33,22 +34,14 @@ train:
dense_input_dim
:
13
fc_sizes
:
[
512
,
256
,
128
,
32
]
learning_rate
:
0.001
optimizer
:
adam
save
:
increment
:
dirname
:
"
models_for_
increment"
dirname
:
"
increment"
epoch_interval
:
2
save_last
:
True
inference
:
dirname
:
"
models_for_
inference"
dirname
:
"
inference"
epoch_interval
:
4
feed_varnames
:
[
"
C1"
,
"
C2"
,
"
C3"
]
fetch_varnames
:
"
predict"
save_last
:
True
evaluate
:
batch_size
:
32
train_thread_num
:
12
reader
:
"
reader.py"
fleet_rec/core/factory.py
浏览文件 @
74ee4f3f
...
...
@@ -67,6 +67,8 @@ class TrainerFactory(object):
raise
ValueError
(
"fleetrec's config only support yaml"
)
envs
.
set_global_envs
(
_config
)
envs
.
update_workspace
()
trainer
=
TrainerFactory
.
_build_trainer
(
config
)
return
trainer
...
...
fleet_rec/core/model.py
浏览文件 @
74ee4f3f
import
abc
import
paddle.fluid
as
fluid
from
fleetrec.core.utils
import
envs
class
Model
(
object
):
"""R
"""
...
...
@@ -33,11 +37,35 @@ class Model(object):
def
get_fetch_period
(
self
):
return
self
.
_fetch_interval
def
_build_optimizer
(
self
,
name
,
lr
):
name
=
name
.
upper
()
optimizers
=
[
"SGD"
,
"ADAM"
,
"ADAGRAD"
]
if
name
not
in
optimizers
:
raise
ValueError
(
"configured optimizer can only supported SGD/Adam/Adagrad"
)
if
name
==
"SGD"
:
optimizer_i
=
fluid
.
optimizer
.
Adam
(
lr
,
lazy_mode
=
True
)
elif
name
==
"ADAM"
:
optimizer_i
=
fluid
.
optimizer
.
Adam
(
lr
,
lazy_mode
=
True
)
elif
name
==
"ADAGRAD"
:
optimizer_i
=
fluid
.
optimizer
.
Adam
(
lr
,
lazy_mode
=
True
)
else
:
raise
ValueError
(
"configured optimizer can only supported SGD/Adam/Adagrad"
)
return
optimizer_i
def
optimizer
(
self
):
learning_rate
=
envs
.
get_global_env
(
"hyper_parameters.learning_rate"
,
None
,
self
.
_namespace
)
optimizer
=
envs
.
get_global_env
(
"hyper_parameters.optimizer"
,
None
,
self
.
_namespace
)
return
self
.
_build_optimizer
(
optimizer
,
learning_rate
)
@
abc
.
abstractmethod
def
train_net
(
self
):
"""R
"""
pass
@
abc
.
abstractmethod
def
infer_net
(
self
):
pass
fleet_rec/core/utils/envs.py
浏览文件 @
74ee4f3f
...
...
@@ -46,9 +46,11 @@ def set_runtime_environs(environs):
for
k
,
v
in
environs
.
items
():
os
.
environ
[
k
]
=
str
(
v
)
def
get_runtime_environ
(
key
):
return
os
.
getenv
(
key
,
None
)
def
get_trainer
():
train_mode
=
get_runtime_environ
(
"train.trainer.trainer"
)
return
train_mode
...
...
@@ -83,6 +85,25 @@ def get_global_envs():
return
global_envs
def
update_workspace
():
workspace
=
global_envs
.
get
(
"train.workspace"
,
None
)
if
not
workspace
:
return
workspace
=
""
# is fleet inner models
if
workspace
.
startswith
(
"fleetrec."
):
fleet_package
=
get_runtime_environ
(
"PACKAGE_BASE"
)
workspace_dir
=
workspace
.
split
(
"fleetrec."
)[
1
].
replace
(
"."
,
"/"
)
path
=
os
.
path
.
join
(
fleet_package
,
workspace_dir
)
else
:
path
=
workspace
for
name
,
value
in
global_envs
.
items
():
if
isinstance
(
value
,
str
):
value
=
value
.
replace
(
"{workspace}"
,
path
)
global_envs
[
name
]
=
value
def
pretty_print_envs
(
envs
,
header
=
None
):
spacing
=
5
max_k
=
45
...
...
models/rank/dnn/model.py
浏览文件 @
74ee4f3f
...
...
@@ -63,12 +63,9 @@ class Model(ModelBase):
feed_list
=
self
.
_data_var
,
capacity
=
64
,
use_double_buffer
=
False
,
iterable
=
False
)
def
net
(
self
):
trainer
=
envs
.
get_trainer
()
is_distributed
=
True
if
trainer
==
"CtrTrainer"
else
False
is_distributed
=
True
if
envs
.
get_trainer
()
==
"CtrTrainer"
else
False
sparse_feature_number
=
envs
.
get_global_env
(
"hyper_parameters.sparse_feature_number"
,
None
,
self
.
_namespace
)
sparse_feature_dim
=
envs
.
get_global_env
(
"hyper_parameters.sparse_feature_dim"
,
None
,
self
.
_namespace
)
sparse_feature_dim
=
9
if
trainer
==
"CtrTrainer"
else
sparse_feature_dim
def
embedding_layer
(
input
):
emb
=
fluid
.
layers
.
embedding
(
...
...
@@ -106,8 +103,7 @@ class Model(ModelBase):
size
=
2
,
act
=
"softmax"
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
scale
=
1
/
math
.
sqrt
(
fcs
[
-
1
].
shape
[
1
]))),
)
scale
=
1
/
math
.
sqrt
(
fcs
[
-
1
].
shape
[
1
]))))
self
.
predict
=
predict
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录