Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
情韵~
PaddleRec
提交
9f64b843
P
PaddleRec
项目概览
情韵~
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9f64b843
编写于
5月 06, 2020
作者:
C
chengmo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update & ifx
上级
12c654fe
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
47 addition
and
22 deletion
+47
-22
fleet_rec/core/factory.py
fleet_rec/core/factory.py
+14
-6
fleet_rec/core/trainers/tdm_single_trainer.py
fleet_rec/core/trainers/tdm_single_trainer.py
+1
-1
fleet_rec/run.py
fleet_rec/run.py
+1
-1
models/recall/tdm/config.yaml
models/recall/tdm/config.yaml
+8
-8
models/recall/tdm/model.py
models/recall/tdm/model.py
+17
-1
models/recall/tdm/tdm_reader.py
models/recall/tdm/tdm_reader.py
+6
-5
未找到文件。
fleet_rec/core/factory.py
浏览文件 @
9f64b843
...
...
@@ -19,15 +19,22 @@ import yaml
from
fleetrec.core.utils
import
envs
trainer_abs
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"trainers"
)
trainer_abs
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"trainers"
)
trainers
=
{}
def
trainer_registry
():
trainers
[
"SingleTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"single_trainer.py"
)
trainers
[
"ClusterTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"cluster_trainer.py"
)
trainers
[
"CtrCodingTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"ctr_coding_trainer.py"
)
trainers
[
"CtrModulTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"ctr_modul_trainer.py"
)
trainers
[
"SingleTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"single_trainer.py"
)
trainers
[
"ClusterTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"cluster_trainer.py"
)
trainers
[
"CtrCodingTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"ctr_coding_trainer.py"
)
trainers
[
"CtrModulTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"ctr_modul_trainer.py"
)
trainers
[
"TDMSingleTrainer"
]
=
os
.
path
.
join
(
trainer_abs
,
"tdm_single_trainer.py"
)
trainer_registry
()
...
...
@@ -46,7 +53,8 @@ class TrainerFactory(object):
if
trainer_abs
is
None
:
if
not
os
.
path
.
isfile
(
train_mode
):
raise
IOError
(
"trainer {} can not be recognized"
.
format
(
train_mode
))
raise
IOError
(
"trainer {} can not be recognized"
.
format
(
train_mode
))
trainer_abs
=
train_mode
train_mode
=
"UserDefineTrainer"
...
...
fleet_rec/core/trainers/tdm_trainer.py
→
fleet_rec/core/trainers/tdm_
single_
trainer.py
浏览文件 @
9f64b843
...
...
@@ -30,7 +30,7 @@ logger = logging.getLogger("fluid")
logger
.
setLevel
(
logging
.
INFO
)
class
T
dm
SingleTrainer
(
SingleTrainer
):
class
T
DM
SingleTrainer
(
SingleTrainer
):
def
processor_register
(
self
):
self
.
regist_context_processor
(
'uninit'
,
self
.
instance
)
self
.
regist_context_processor
(
'init_pass'
,
self
.
init
)
...
...
fleet_rec/run.py
浏览文件 @
9f64b843
...
...
@@ -202,7 +202,7 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
(
description
=
'fleet-rec run'
)
parser
.
add_argument
(
"-m"
,
"--model"
,
type
=
str
)
parser
.
add_argument
(
"-e"
,
"--engine"
,
type
=
str
,
choices
=
[
"single"
,
"local_cluster"
,
"cluster"
])
choices
=
[
"single"
,
"local_cluster"
,
"cluster"
,
"tdm_single"
])
parser
.
add_argument
(
"-d"
,
"--device"
,
type
=
str
,
choices
=
[
"cpu"
,
"gpu"
],
default
=
"cpu"
)
...
...
models/recall/tdm/config.yaml
浏览文件 @
9f64b843
...
...
@@ -17,20 +17,20 @@ train:
# for cluster training
strategy
:
"
async"
epochs
:
10
epochs
:
4
workspace
:
"
fleetrec.models.recall.tdm"
reader
:
batch_size
:
32
class
:
"
{workspace}/tdm_reader.py"
train_data_path
:
"
{workspace}/data/train
_data
"
test_data_path
:
"
{workspace}/data/test
_data
"
train_data_path
:
"
{workspace}/data/train"
test_data_path
:
"
{workspace}/data/test"
model
:
models
:
"
{workspace}/model.py"
hyper_parameters
:
node_emb_size
:
64
input_emb_size
:
64
input_emb_size
:
768
neg_sampling_list
:
[
1
,
2
,
3
,
4
]
output_positive
:
True
topK
:
1
...
...
@@ -52,10 +52,10 @@ train:
persistables_model_path
:
"
"
load_tree
:
True
tree_layer_path
:
"
"
tree_travel_path
:
"
"
tree_info_path
:
"
"
tree_emb_path
:
"
"
tree_layer_path
:
"
{workspace}/tree/layer_list.txt
"
tree_travel_path
:
"
{workspace}/tree/travel_list.npy
"
tree_info_path
:
"
{workspace}/tree/tree_info.npy
"
tree_emb_path
:
"
{workspace}/tree/tree_emb.npy
"
save_init_model
:
True
init_model_path
:
"
"
...
...
models/recall/tdm/model.py
浏览文件 @
9f64b843
...
...
@@ -45,7 +45,7 @@ class Model(ModelBase):
self
.
node_emb_size
=
envs
.
get_global_env
(
"hyper_parameters.node_emb_size"
,
64
,
self
.
_namespace
)
self
.
input_emb_size
=
envs
.
get_global_env
(
"hyper_parameters.input_emb_size"
,
64
,
self
.
_namespace
)
"hyper_parameters.input_emb_size"
,
768
,
self
.
_namespace
)
self
.
act
=
envs
.
get_global_env
(
"hyper_parameters.act"
,
"tanh"
,
self
.
_namespace
)
self
.
neg_sampling_list
=
envs
.
get_global_env
(
...
...
@@ -61,6 +61,7 @@ class Model(ModelBase):
def
train_net
(
self
):
self
.
train_input
()
self
.
tdm_net
()
self
.
create_info
()
self
.
avg_loss
()
self
.
metrics
()
...
...
@@ -174,11 +175,26 @@ class Model(ModelBase):
mask_index
.
stop_gradient
=
True
self
.
mask_cost
=
fluid
.
layers
.
gather_nd
(
cost
,
mask_index
)
softmax_prob
=
fluid
.
layers
.
unsqueeze
(
input
=
softmax_prob
,
axes
=
[
1
])
self
.
mask_prob
=
fluid
.
layers
.
gather_nd
(
softmax_prob
,
mask_index
)
self
.
mask_label
=
fluid
.
layers
.
gather_nd
(
labels_reshape
,
mask_index
)
self
.
_predict
=
self
.
mask_prob
def
create_info
(
self
):
fluid
.
default_startup_program
().
global_block
().
create_var
(
name
=
"TDM_Tree_Info"
,
dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
INT32
,
shape
=
[
self
.
node_nums
,
3
+
self
.
child_nums
],
persistable
=
True
,
initializer
=
fluid
.
initializer
.
ConstantInitializer
(
0
))
fluid
.
default_main_program
().
global_block
().
create_var
(
name
=
"TDM_Tree_Info"
,
dtype
=
fluid
.
core
.
VarDesc
.
VarType
.
INT32
,
shape
=
[
self
.
node_nums
,
3
+
self
.
child_nums
],
persistable
=
True
)
def
avg_loss
(
self
):
avg_cost
=
fluid
.
layers
.
reduce_mean
(
self
.
mask_cost
)
self
.
_cost
=
avg_cost
...
...
models/recall/tdm/tdm_reader.py
浏览文件 @
9f64b843
...
...
@@ -18,16 +18,17 @@
from
__future__
import
print_function
from
fleetrec.core.reader
import
Reader
from
fleetrec.core.utils
import
envs
class
TrainReader
(
reader
):
class
TrainReader
(
Reader
):
def
init
(
self
):
pass
def
reader
(
self
,
line
):
def
generate_sample
(
self
,
line
):
"""
Read the data line by line and process it as a dictionary
"""
def
iterato
r
():
def
reade
r
():
"""
This function needs to be implemented by the user, based on data format
"""
...
...
@@ -38,4 +39,4 @@ class TrainReader(reader):
feature_name
=
[
"input_emb"
,
"item_label"
]
yield
zip
(
feature_name
,
[
input_emb
]
+
[
item_label
])
return
R
eader
return
r
eader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录