PaddlePaddle/PaddleRec — commit 1a2cafdc (unverified)

Merge pull request #162 from 123malin/gnn

Gnn

Authored by wuzhihua on Aug 05, 2020; committed via GitHub on Aug 05, 2020.
Parents: d02319d8, 5a5abe15

Showing 8 changed files with 283 additions and 52 deletions (+283 -52).
core/metric.py (+2 -2)
core/trainers/framework/runner.py (+0 -1)
models/recall/gnn/config.yaml (+13 -11)
models/recall/gnn/data/download.py (+7 -2)
models/recall/gnn/data/preprocess.py (+11 -21)
models/recall/gnn/data_prepare.sh (+20 -10)
models/recall/gnn/model.py (+6 -5)
models/recall/gnn/readme.md (+224 -0)
core/metric.py (+2 -2)

```
@@ -69,8 +69,8 @@ class Metric(object):
         global_metrics = dict()
         for key in self._global_metric_state_vars:
             varname, dtype = self._global_metric_state_vars[key]
-            global_metrics[key] = self.get_global_metric_state(fleet, scope,
-                                                               varname)
+            global_metrics[key] = self._get_global_metric_state(fleet, scope,
+                                                                varname)
         return self._calculate(global_metrics)
```
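For orientation, the pattern this method implements is: gather each registered state variable into a dict, then reduce that dict to final metric values. Below is a self-contained toy sketch of that flow; the class, the stored values, and the plain dict lookup standing in for the fleet/scope fetch are all illustrative, not PaddleRec's real API.

```
# Toy sketch only: _global_metric_state_vars maps a metric key to
# (variable name, dtype); each value is fetched, then all are combined.
class ToyMetric(object):
    def __init__(self):
        self._global_metric_state_vars = {
            "ins_cnt": ("total_ins", "float32"),
            "recall_cnt": ("recall_ins", "float32"),
        }
        self._store = {"total_ins": 200.0, "recall_ins": 96.0}  # fake state

    def _get_global_metric_state(self, varname):
        return self._store[varname]  # stands in for a fleet/scope fetch

    def _calculate(self, global_metrics):
        return global_metrics["recall_cnt"] / global_metrics["ins_cnt"]

    def calc_global_metrics(self):
        global_metrics = dict()
        for key in self._global_metric_state_vars:
            varname, dtype = self._global_metric_state_vars[key]
            global_metrics[key] = self._get_global_metric_state(varname)
        return self._calculate(global_metrics)

print(ToyMetric().calc_global_metrics())  # 0.48, cf. the Recall@20 sample logs below
```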
core/trainers/framework/runner.py (+0 -1)

```
@@ -520,7 +520,6 @@ class SingleInferRunner(RunnerBase):
     def run(self, context):
         self._dir_check(context)
-        self.epoch_model_name_list.sort()
         for index, epoch_name in enumerate(self.epoch_model_name_list):
             for model_dict in context["phases"]:
                 model_class = context["model"][model_dict["name"]]["model"]
```
models/recall/gnn/config.yaml (+13 -11)

```
@@ -42,30 +42,32 @@ hyper_parameters:
   gnn_propogation_steps: 1
 
 # select runner by name
-mode: train_runner
+mode: [single_cpu_train, single_cpu_infer]
 # config of each runner.
 # runner is a kind of paddle training class, which wraps the train/infer process.
 runner:
-- name: train_runner
+- name: single_cpu_train
   class: train
   # num of epochs
-  epochs: 2
+  epochs: 5
   # device to run training or infer
   device: cpu
   save_checkpoint_interval: 1 # save model interval of epochs
   save_inference_interval: 1 # save inference
-  save_checkpoint_path: "increment" # save checkpoint path
-  save_inference_path: "inference" # save inference path
+  save_checkpoint_path: "increment_gnn" # save checkpoint path
+  save_inference_path: "inference_gnn" # save inference path
   save_inference_feed_varnames: [] # feed vars of save inference
   save_inference_fetch_varnames: [] # fetch vars of save inference
   init_model_path: "" # load model path
   print_interval: 1
-- name: infer_runner
+  phases: [phase1]
+- name: single_cpu_infer
   class: infer
   # device to run training or infer
   device: cpu
   print_interval: 1
-  init_model_path: "increment/0" # load model path
+  init_model_path: "increment_gnn" # load model path
+  phases: [phase2]
 
 # runner will run all the phase in each epoch
 phase:
@@ -73,7 +75,7 @@ phase:
   model: "{workspace}/model.py" # user-defined model
   dataset_name: dataset_train # select dataset by name
   thread_num: 1
-#- name: phase2
-#  model: "{workspace}/model.py" # user-defined model
-#  dataset_name: dataset_infer # select dataset by name
-#  thread_num: 1
+- name: phase2
+  model: "{workspace}/model.py" # user-defined model
+  dataset_name: dataset_infer # select dataset by name
+  thread_num: 1
```
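The main change above is that `mode` now holds a list of runner names executed in order, and each runner declares the `phases` it drives. A rough sketch of how such a config could be resolved (PyYAML; the loop is illustrative, not PaddleRec's actual dispatcher):

```
# Illustrative only: resolve the new-style `mode` list against the
# runner definitions in models/recall/gnn/config.yaml.
import yaml  # pip install pyyaml

with open("models/recall/gnn/config.yaml") as f:
    config = yaml.safe_load(f)

runners = {r["name"]: r for r in config["runner"]}
for name in config["mode"]:  # [single_cpu_train, single_cpu_infer]
    r = runners[name]
    print("would run %s: class=%s device=%s phases=%s"
          % (name, r["class"], r["device"], r.get("phases")))
```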
models/recall/gnn/data/download.py (+7 -2)

```
@@ -57,5 +57,10 @@ def _download_file(url, savepath, print_progress):
             progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
 
 
-_download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
-               "./train-item-views.csv", True)
+if sys.argv[1] == "diginetica":
+    _download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
+                   "./train-item-views.csv", True)
+elif sys.argv[1] == "yoochoose":
+    _download_file("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
+                   "./yoochoose-clicks.dat", True)
```
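With this change the script downloads whichever dataset its first argument names. A condensed restatement of the dispatch (the `DATASETS` table is our illustration; the URLs are the ones in the diff):

```
# Sketch of the new download.py tail; run as e.g. `python download.py diginetica`.
import sys

DATASETS = {
    "diginetica": ("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
                   "./train-item-views.csv"),
    "yoochoose": ("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
                  "./yoochoose-clicks.dat"),
}
url, path = DATASETS[sys.argv[1]]
print("would download %s -> %s" % (url, path))
```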
models/recall/gnn/data/preprocess.py (+11 -21)

```
@@ -41,39 +41,29 @@ with open(dataset, "r") as f:
     curdate = None
     for data in reader:
         sessid = data['session_id']
-        if curdate and not curid == sessid:
-            date = ''
-            if opt.dataset == 'yoochoose':
-                date = time.mktime(
-                    time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
-            else:
-                date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
-            sess_date[curid] = date
-        curid = sessid
         if opt.dataset == 'yoochoose':
             item = data['item_id']
+            date = time.mktime(
+                time.strptime(data['timestamp'][:19], '%Y-%m-%dT%H:%M:%S'))
         else:
             item = data['item_id'], int(data['timeframe'])
-        curdate = ''
-        if opt.dataset == 'yoochoose':
-            curdate = data['timestamp']
-        else:
-            curdate = data['eventdate']
+            date = time.mktime(time.strptime(data['eventdate'], '%Y-%m-%d'))
+        if sessid not in sess_date:
+            sess_date[sessid] = date
+        elif date > sess_date[sessid]:
+            sess_date[sessid] = date
         if sessid in sess_clicks:
             sess_clicks[sessid] += [item]
         else:
             sess_clicks[sessid] = [item]
         ctr += 1
-    date = ''
-    if opt.dataset == 'yoochoose':
-        date = time.mktime(time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
-    else:
-        date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
-    for i in list(sess_clicks):
-        sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1))
-        sess_clicks[i] = [c[0] for c in sorted_clicks]
-    sess_date[curid] = date
+    if opt.dataset != 'yoochoose':
+        for i in list(sess_clicks):
+            sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1))
+            sess_clicks[i] = [c[0] for c in sorted_clicks]
 print("-- Reading data @ %ss" % datetime.datetime.now())
 
 # Filter out length 1 sessions
@@ -160,7 +150,7 @@ def obtian_tra():
         train_dates += [date]
         train_seqs += [outseq]
     print(item_ctr)     # 43098, 37484
-    with open("./diginetica/config.txt", "w") as fout:
+    with open("./config.txt", "w") as fout:
         fout.write(str(item_ctr) + "\n")
     return train_ids, train_dates, train_seqs
```
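The core behavioral change in the first hunk: a session's date is now the latest timestamp observed for that session, computed per row, instead of being carried across consecutive rows in `curdate`/`curid`. A standalone illustration with synthetic (session_id, date) pairs:

```
# Synthetic illustration of the new sess_date bookkeeping: keep the most
# recent click date per session, regardless of row order.
rows = [("s1", 100.0), ("s2", 105.0), ("s1", 103.0), ("s2", 101.0)]

sess_date = {}
for sessid, date in rows:
    if sessid not in sess_date:
        sess_date[sessid] = date
    elif date > sess_date[sessid]:
        sess_date[sessid] = date

print(sess_date)  # {'s1': 103.0, 's2': 105.0}
```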
models/recall/gnn/data_prepare.sh (+20 -10)

```
@@ -15,21 +15,31 @@
 # limitations under the License.
 
 set -e
-echo "begin to download data"
-cd data && python download.py
-mkdir diginetica
-python preprocess.py --dataset diginetica
+
+dataset=$1
+src=$1
+
+if [[ $src == "yoochoose1_4" || $src == "yoochoose1_64" ]]; then
+    src="yoochoose"
+elif [[ $src == "diginetica" ]]; then
+    src="diginetica"
+else
+    echo "Usage: sh data_prepare.sh [diginetica|yoochoose1_4|yoochoose1_64]"
+    exit 1
+fi
+
+echo "begin to download data"
+cd data && python download.py $src
+mkdir $dataset
+python preprocess.py --dataset $src
 echo "begin to convert data (binary -> txt)"
-python convert_data.py --data_dir diginetica
-cat diginetica/train.txt | wc -l >> diginetica/config.txt
+python convert_data.py --data_dir $dataset
+cat ${dataset}/train.txt | wc -l >> config.txt
 rm -rf train && mkdir train
-mv diginetica/train.txt train
+mv ${dataset}/train.txt train
 rm -rf test && mkdir test
-mv diginetica/test.txt test
+mv ${dataset}/test.txt test
-mv diginetica/config.txt ./config.txt
```
models/recall/gnn/model.py (+6 -5)

```
@@ -20,6 +20,7 @@ import paddle.fluid.layers as layers
 from paddlerec.core.utils import envs
 from paddlerec.core.model import ModelBase
+from paddlerec.core.metrics import RecallK
 
 
 class Model(ModelBase):
@@ -235,16 +236,16 @@ class Model(ModelBase):
         softmax = layers.softmax_with_cross_entropy(
             logits=logits, label=inputs[6])  # [batch_size, 1]
         self.loss = layers.reduce_mean(softmax)  # [1]
-        self.acc = layers.accuracy(input=logits, label=inputs[6], k=20)
+        acc = RecallK(input=logits, label=inputs[6], k=20)
 
         self._cost = self.loss
         if is_infer:
-            self._infer_results['acc'] = self.acc
-            self._infer_results['loss'] = self.loss
+            self._infer_results['P@20'] = acc
+            self._infer_results['LOSS'] = self.loss
             return
 
         self._metrics["LOSS"] = self.loss
-        self._metrics["train_acc"] = self.acc
+        self._metrics["Train_P@20"] = acc
 
     def optimizer(self):
         step_per_epoch = self.corpus_size // self.train_batch_size
```
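The metric change swaps top-20 `layers.accuracy` for a `RecallK` metric surfaced as P@20 / Train_P@20. As a reference point only, Recall@k asks whether the true item appears among the k highest-scoring logits; a NumPy sketch, not PaddleRec's RecallK implementation:

```
import numpy as np

def recall_at_k(logits, labels, k=20):
    # logits: [batch, num_items]; labels: [batch] true item ids
    topk = np.argsort(-logits, axis=1)[:, :k]     # indices of top-k scores
    hits = (topk == labels[:, None]).any(axis=1)  # true item in top-k?
    return hits.mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(200, 1000))
labels = rng.integers(0, 1000, size=200)
print(recall_at_k(logits, labels))  # ~0.02 for random scores (k / num_items)
```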
models/recall/gnn/readme.md (new file, mode 100644, +224 -0)
# GNN

Below is a brief directory structure and description for this example:

```
├── data # sample data
    ├── train
        ├── train.txt
    ├── test
        ├── test.txt
    ├── download.py
    ├── convert_data.py
    ├── preprocess.py
├── __init__.py
├── README.md # documentation
├── model.py # model file
├── config.yaml # configuration file
├── data_prepare.sh # one-click data processing script
├── reader.py # reader for training data
├── evaluate_reader.py # reader for inference data
```
Note: before reading this example, we recommend first going through the [PaddleRec beginner tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md).
---
## Contents

- [Model overview](#model-overview)
- [Data preparation](#data-preparation)
- [Runtime environment](#runtime-environment)
- [Quick start](#quick-start)
- [Reproducing the paper](#reproducing-the-paper)
- [Advanced usage](#advanced-usage)
- [FAQ](#faq)
## Model overview

For an introduction to the SR-GNN model, see the paper [Session-based Recommendation with Graph Neural Networks](https://arxiv.org/abs/1811.00855).

The paper tackles session-based recommendation, roughly in four steps (a small sketch of step 1 follows this list):

1. Model every session sequence as a directed graph.
2. Learn a latent vector for each node (item) via a GNN.
3. Derive each session's embedding through an attention architecture.
4. Predict over the full item table with a softmax layer.
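A minimal sketch of step 1, assuming only NumPy: turn one session's click sequence into the out/in adjacency matrices (normalized by out-/in-degree) that SR-GNN's GNN layer consumes. The names below are illustrative, not taken from model.py.

```
import numpy as np

def session_to_graph(session):
    nodes = list(dict.fromkeys(session))      # unique items, order preserved
    idx = {item: i for i, item in enumerate(nodes)}
    n = len(nodes)
    a_out = np.zeros((n, n), dtype=np.float32)
    for u, v in zip(session, session[1:]):    # one edge per consecutive click
        a_out[idx[u], idx[v]] = 1.0
    out_deg = np.maximum(a_out.sum(axis=1, keepdims=True), 1.0)
    in_deg = np.maximum(a_out.sum(axis=0, keepdims=True), 1.0)
    return a_out / out_deg, (a_out / in_deg).T  # normalized A_out, A_in

a_out, a_in = session_to_graph([10, 11, 12, 12, 13, 14])
print(a_out.shape)  # (5, 5): five unique items in the session
```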
In this example we reproduce the paper's results: P@20 reaches 50.7 on the DIGINETICA dataset. We also recommend the [IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124382).

By default this model uses the demo dataset; to validate accuracy, see the [Reproducing the paper](#reproducing-the-paper) section.

Features supported by this model:

Training: single-machine CPU, single-machine single-GPU, single-machine multi-GPU, locally simulated parameter-server training, and incremental training. For configuration, see [launching training](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/train.md).

Inference: single-machine CPU, single-machine single-GPU. For configuration, see [PaddleRec offline inference](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/predict.md).
## Data preparation

Data processing in this example consists of three steps:

- Step 1: download the raw dataset. This example provides two open datasets, DIGINETICA and Yoochoose; either one can be used to train the model. The download command and raw data formats are shown below. With the diginetica dataset, the command produces the raw data file train-item-views.csv under the data directory; with the yoochoose dataset, it produces yoochoose-clicks.dat.

```
cd data && python download.py diginetica # or yoochoose
```

> The [Yoochoose](https://2015.recsyschallenge.com/challenge.html) dataset comes from RecSys Challenge 2015. The raw data contains the following fields:
1. Session ID – the id of the session. In one session there are one or many clicks.
2. Timestamp – the time when the click occurred.
3. Item ID – the unique identifier of the item.
4. Category – the category of the item.

> The [DIGINETICA](https://competitions.codalab.org/competitions/11161#learn_the_details-data2) dataset comes from the CIKM Cup 2016 _Personalized E-Commerce Search Challenge_. The raw data contains the following fields:
1. sessionId - the id of the session. In one session there are one or many clicks.
2. userId - the id of the user, with anonymized user ids.
3. itemId - the unique identifier of the item.
4. timeframe - time since the first query in a session, in milliseconds.
5. eventdate - calendar date.
- Step 2: preprocess the data (a toy illustration of the filtering appears after the command below).
  1. Merge the raw records by session_id to obtain each session's date and its ordered click list.
  2. Filter out sessions of length 1, and items clicked fewer than 5 times.
  3. Split train/test sets: sessions from the most recent seven days of the dataset form the test set, and all earlier data forms the training set.

```
cd data && python preprocess.py --dataset diginetica # or yoochoose
```
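A toy illustration (plain Python, synthetic data) of the two filters in step 2; the real preprocess.py applies them to the parsed CSV rows.

```
from collections import Counter

sess_clicks = {
    "s1": [10, 11, 12],
    "s2": [10],          # length-1 session: dropped
    "s3": [11, 12, 13],
}

# 1) Drop sessions with a single click.
sess_clicks = {s: c for s, c in sess_clicks.items() if len(c) > 1}

# 2) Drop items clicked fewer than 5 times overall.
MIN_ITEM_CLICKS = 5
item_counts = Counter(i for clicks in sess_clicks.values() for i in clicks)
sess_clicks = {
    s: [i for i in clicks if item_counts[i] >= MIN_ITEM_CLICKS]
    for s, clicks in sess_clicks.items()
}
print(sess_clicks)  # every toy item falls below the threshold here
```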
- Step 3: arrange the data. Place the training files under data/train and the test files under data/test.

```
cat data/diginetica/train.txt | wc -l >> data/config.txt # or yoochoose1_4 or yoochoose1_64
rm -rf data/train/*
rm -rf data/test/*
mv data/diginetica/train.txt data/train
mv data/diginetica/test.txt data/test
```
After processing, data/train holds the training data and data/test holds the test data, in the following format:

```
#session\tlabel
10,11,12,12,13,14\t15
```

data/config.txt stores dataset statistics: the first line is the total number of items in the training set (used to set the model's vocabulary size), and the second line is the training set size. A minimal parsing sketch follows the command below.

For convenience, we provide a one-click data processing script:

```
sh data_prepare.sh diginetica # or yoochoose1_4 or yoochoose1_64
```
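A minimal sketch of consuming data/config.txt as just described; the variable names match the hyperparameters discussed under [Reproducing the paper](#reproducing-the-paper).

```
# Line 1 -> item count (vocabulary size / sparse_feature_number),
# line 2 -> training-set size (corpus_size).
with open("data/config.txt") as f:
    sparse_feature_number = int(f.readline())  # e.g. 43098 for diginetica
    corpus_size = int(f.readline())            # e.g. 719470 for diginetica
print(sparse_feature_number, corpus_size)
```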
## Runtime environment

PaddlePaddle>=1.7.2

python 2.7/3.5/3.6/3.7

PaddleRec >=0.1

os: windows/linux/macos
## Quick start

### Single-machine training

CPU environment

Set the device, epochs, and other options in config.yaml:
```
# select runner by name
mode: [single_cpu_train, single_cpu_infer]
# config of each runner.
# runner is a kind of paddle training class, which wraps the train/infer process.
runner:
- name: single_cpu_train
class: train
# num of epochs
epochs: 2
# device to run training or infer
device: cpu
save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 1 # save inference
save_checkpoint_path: "increment_gnn" # save checkpoint path
save_inference_path: "inference_gnn" # save inference path
save_inference_feed_varnames: [] # feed vars of save inference
save_inference_fetch_varnames: [] # fetch vars of save inference
init_model_path: "" # load model path
print_interval: 1
phases: [phase1]
```
### Single-machine inference

CPU environment

Set epochs, device, and other parameters in config.yaml:
```
- name: single_cpu_infer
class: infer
# device to run training or infer
device: cpu
print_interval: 1
init_model_path: "increment_gnn" # load model path
phases: [phase2]
```
### Run
```
python -m paddlerec.run -m paddlerec.models.recall.gnn
```
### Results

Training output on the sample data:
```
Running SingleStartup.
Running SingleRunner.
batch: 1, LOSS: [10.67443], InsCnt: [200.], RecallCnt: [0.], Acc(Recall@20): [0.]
batch: 2, LOSS: [10.672471], InsCnt: [300.], RecallCnt: [0.], Acc(Recall@20): [0.]
batch: 3, LOSS: [10.672463], InsCnt: [400.], RecallCnt: [1.], Acc(Recall@20): [0.0025]
batch: 4, LOSS: [10.670724], InsCnt: [500.], RecallCnt: [2.], Acc(Recall@20): [0.004]
batch: 5, LOSS: [10.66949], InsCnt: [600.], RecallCnt: [2.], Acc(Recall@20): [0.00333333]
batch: 6, LOSS: [10.670102], InsCnt: [700.], RecallCnt: [2.], Acc(Recall@20): [0.00285714]
batch: 7, LOSS: [10.671348], InsCnt: [800.], RecallCnt: [2.], Acc(Recall@20): [0.0025]
...
epoch 0 done, use time: 2926.6897077560425, global metrics: LOSS=[6.0788856], InsCnt=719400.0 RecallCnt=224033.0 Acc(Recall@20)=0.3114164581595774
...
epoch 4 done, use time: 3083.101449728012, global metrics: LOSS=[4.249889], InsCnt=3597000.0 RecallCnt=2070666.0 Acc(Recall@20)=0.5756647206005004
```
Inference output on the sample data:
```
Running SingleInferStartup.
Running SingleInferRunner.
load persistables from increment_gnn/2
batch: 1, InsCnt: [200.], RecallCnt: [96.], Acc(Recall@20): [0.48], LOSS: [5.7198644]
batch: 2, InsCnt: [300.], RecallCnt: [153.], Acc(Recall@20): [0.51], LOSS: [5.4096317]
batch: 3, InsCnt: [400.], RecallCnt: [210.], Acc(Recall@20): [0.525], LOSS: [5.300991]
batch: 4, InsCnt: [500.], RecallCnt: [258.], Acc(Recall@20): [0.516], LOSS: [5.6269655]
batch: 5, InsCnt: [600.], RecallCnt: [311.], Acc(Recall@20): [0.5183333], LOSS: [5.39276]
batch: 6, InsCnt: [700.], RecallCnt: [352.], Acc(Recall@20): [0.50285715], LOSS: [5.633842]
batch: 7, InsCnt: [800.], RecallCnt: [406.], Acc(Recall@20): [0.5075], LOSS: [5.342844]
batch: 8, InsCnt: [900.], RecallCnt: [465.], Acc(Recall@20): [0.51666665], LOSS: [4.918761]
...
Infer phase2 of epoch 0 done, use time: 549.1640813350677, global metrics: InsCnt=60800.0 RecallCnt=31083.0 Acc(Recall@20)=0.511233552631579, LOSS=[5.8957024]
```
## Reproducing the paper

To reproduce the paper's results with the full original dataset, adjust the following hyperparameters in config.yaml:

- batch_size: set batch_size of the dataset_train dataset in config.yaml to 100.
- epochs: set epochs of the runner in config.yaml to 5.
- sparse_feature_number: differs per training dataset (diginetica or yoochoose): 43098 for diginetica, 37484 for yoochoose. See the first line of the data/config.txt file produced by data processing.
- corpus_size: also dataset-dependent: 719470 for diginetica, 5917745 for yoochoose. See the second line of data/config.txt.

Training on CPU for 5 epochs gives test Recall@20: 0.51367.

To run after these changes, set 'workspace' in config.yaml to the directory containing config.yaml, then execute:

```
python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: point directly at the absolute path of the local config
```
## Advanced usage

## FAQ