Commit 3063c441 authored by malin10

gnn

Parent 8180c70c
@@ -42,7 +42,7 @@ hyper_parameters:
   gnn_propogation_steps: 1

 # select runner by name
-mode: train_runner
+mode: [train_runner, infer_runner]
 # config of each runner.
 # runner is a kind of paddle training class, which wraps the train/infer process.
 runner:
@@ -54,18 +54,20 @@ runner:
   device: cpu
   save_checkpoint_interval: 1 # save model interval of epochs
   save_inference_interval: 1 # save inference
-  save_checkpoint_path: "increment" # save checkpoint path
-  save_inference_path: "inference" # save inference path
+  save_checkpoint_path: "increment_gnn" # save checkpoint path
+  save_inference_path: "inference_gnn" # save inference path
   save_inference_feed_varnames: [] # feed vars of save inference
   save_inference_fetch_varnames: [] # fetch vars of save inference
   init_model_path: "" # load model path
   print_interval: 1
+  phases: [phase1]
 - name: infer_runner
   class: infer
   # device to run training or infer
   device: cpu
   print_interval: 1
-  init_model_path: "increment/0" # load model path
+  init_model_path: "increment_gnn" # load model path
+  phases: [phase2]
 # runner will run all the phase in each epoch
 phase:
@@ -73,7 +75,7 @@ phase:
   model: "{workspace}/model.py" # user-defined model
   dataset_name: dataset_train # select dataset by name
   thread_num: 1
-# - name: phase2
-#   model: "{workspace}/model.py" # user-defined model
-#   dataset_name: dataset_infer # select dataset by name
-#   thread_num: 1
+- name: phase2
+  model: "{workspace}/model.py" # user-defined model
+  dataset_name: dataset_infer # select dataset by name
+  thread_num: 1
@@ -57,5 +57,10 @@ def _download_file(url, savepath, print_progress):
         progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

-_download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
-               "./train-item-views.csv", True)
+if sys.argv[1] == "diginetica":
+    _download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
+                   "./train-item-views.csv", True)
+elif sys.argv[1] == "yoochoose":
+    _download_file(
+        "https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
+        "./yoochoose-clicks.dat", True)
@@ -41,39 +41,29 @@ with open(dataset, "r") as f:
     curdate = None
     for data in reader:
         sessid = data['session_id']
         if curdate and not curid == sessid:
-            date = ''
-            if opt.dataset == 'yoochoose':
-                date = time.mktime(
-                    time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
-            else:
-                date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
             sess_date[curid] = date
         curid = sessid
+        date = ''
         if opt.dataset == 'yoochoose':
             item = data['item_id']
+            date = time.mktime(
+                time.strptime(data['timestamp'][:19], '%Y-%m-%dT%H:%M:%S'))
         else:
             item = data['item_id'], int(data['timeframe'])
-        curdate = ''
-        if opt.dataset == 'yoochoose':
-            curdate = data['timestamp']
-        else:
-            curdate = data['eventdate']
+            date = time.mktime(time.strptime(data['eventdate'], '%Y-%m-%d'))
+        if sessid not in sess_date:
+            sess_date[sessid] = date
+        elif date > sess_date[sessid]:
+            sess_date[sessid] = date
         if sessid in sess_clicks:
             sess_clicks[sessid] += [item]
         else:
             sess_clicks[sessid] = [item]
         ctr += 1
-    date = ''
-    if opt.dataset == 'yoochoose':
-        date = time.mktime(time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
-    else:
-        date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
+    if opt.dataset != 'yoochoose':
         for i in list(sess_clicks):
             sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1))
             sess_clicks[i] = [c[0] for c in sorted_clicks]
-    sess_date[curid] = date
     print("-- Reading data @ %ss" % datetime.datetime.now())

 # Filter out length 1 sessions
@@ -160,7 +150,7 @@ def obtian_tra():
         train_dates += [date]
         train_seqs += [outseq]
     print(item_ctr) # 43098, 37484
-    with open("./diginetica/config.txt", "w") as fout:
+    with open("./config.txt", "w") as fout:
         fout.write(str(item_ctr) + "\n")
     return train_ids, train_dates, train_seqs
@@ -15,21 +15,31 @@
 # limitations under the License.

 set -e

-echo "begin to download data"
-cd data && python download.py
-mkdir diginetica
-python preprocess.py --dataset diginetica
+dataset=$1
+src=$1
+
+if [[ $src == "yoochoose1_4" || $src == "yoochoose1_64" ]];then
+    src="yoochoose"
+elif [[ $src == "diginetica" ]];then
+    src="diginetica"
+else
+    echo "Usage: sh data_prepare.sh [diginetica|yoochoose1_4|yoochoose1_64]"
+    exit 1
+fi
+
+echo "begin to download data"
+cd data && python download.py $src
+mkdir $dataset
+python preprocess.py --dataset $src
 echo "begin to convert data (binary -> txt)"
-python convert_data.py --data_dir diginetica
-cat diginetica/train.txt | wc -l >> diginetica/config.txt
+python convert_data.py --data_dir $dataset
+cat ${dataset}/train.txt | wc -l >> config.txt

 rm -rf train && mkdir train
-mv diginetica/train.txt train
+mv ${dataset}/train.txt train

 rm -rf test && mkdir test
-mv diginetica/test.txt test
-mv diginetica/config.txt ./config.txt
+mv ${dataset}/test.txt test
@@ -20,6 +20,7 @@ import paddle.fluid.layers as layers

 from paddlerec.core.utils import envs
 from paddlerec.core.model import ModelBase
+from paddlerec.core.metrics import RecallK


 class Model(ModelBase):
@@ -235,16 +236,16 @@ class Model(ModelBase):
         softmax = layers.softmax_with_cross_entropy(
             logits=logits, label=inputs[6]) # [batch_size, 1]
         self.loss = layers.reduce_mean(softmax) # [1]
-        self.acc = layers.accuracy(input=logits, label=inputs[6], k=20)
+        acc = RecallK(input=logits, label=inputs[6], k=20)

         self._cost = self.loss
         if is_infer:
-            self._infer_results['acc'] = self.acc
-            self._infer_results['loss'] = self.loss
+            self._infer_results['P@20'] = acc
+            self._infer_results['LOSS'] = self.loss
             return

         self._metrics["LOSS"] = self.loss
-        self._metrics["train_acc"] = self.acc
+        self._metrics["Train_P@20"] = acc

     def optimizer(self):
         step_per_epoch = self.corpus_size // self.train_batch_size
# GNN
## Quick Start
Every built-in model in PaddleRec ships with sample data, so users can quickly validate the model and the environment on it before moving to real data, reducing later debugging cost. To train on the built-in dataset, run:
```
python -m paddlerec.run -m paddlerec.models.recall.gnn
```
## Data Preparation
- Step 1: Download the raw dataset. This example provides two open-source datasets, DIGINETICA and Yoochoose; either one can be used to train the model.
```
cd data && python download.py diginetica # or yoochoose
```
> The [Yoochoose](https://2015.recsyschallenge.com/challenge.html) dataset comes from RecSys Challenge 2015. The raw data contains the following fields:
1. Session ID – the id of the session. In one session there are one or many clicks.
2. Timestamp – the time when the click occurred.
3. Item ID – the unique identifier of the item.
4. Category – the category of the item.
> The [DIGINETICA](https://competitions.codalab.org/competitions/11161#learn_the_details-data2) dataset comes from the CIKM Cup 2016 _Personalized E-Commerce Search Challenge_. The raw data contains the following fields:
1. sessionId - the id of the session. In one session there are one or many clicks.
2. userId - the id of the user, with anonymized user ids.
3. itemId - the unique identifier of the item.
4. timeframe - time since the first query in a session, in milliseconds.
5. eventdate - calendar date.
- Step 2: Preprocess the data.
```
cd data && python preprocess.py --dataset diginetica # or yoochoose
```
1. Merge the raw records by session_id to obtain each session's date and its ordered click list.
2. Filter out sessions of length 1, and filter out items clicked fewer than 5 times.
3. Split into training and test sets: sessions falling within the last seven days of the dataset form the test set, and all earlier data forms the training set (see the sketch below for the cut-off rule).
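The seven-day rule in item 3 boils down to a timestamp cut-off on each session's date. A minimal sketch of that rule, assuming a `sess_date` dict from session id to unix timestamp like the one preprocess.py builds (the names and toy values here are illustrative, not the script's exact identifiers):
```
# Toy data: session_id -> unix timestamp of the session's date.
sess_date = {"s1": 1451606400.0, "s2": 1462060800.0, "s3": 1462492800.0}

maxdate = max(sess_date.values())
splitdate = maxdate - 7 * 86400  # cut-off: seven days before the latest date
train_sess = [s for s, d in sess_date.items() if d < splitdate]
test_sess = [s for s, d in sess_date.items() if d >= splitdate]
print(train_sess, test_sess)  # s1 goes to train; s2, s3 fall in the last week
```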
- Step 3: Organize the data. Put all training files under data/train and all test files under data/test.
```
cat data/diginetica/train.txt | wc -l >> data/config.txt # or yoochoose1_4 or yoochoose1_64
rm -rf data/train/*
rm -rf data/test/*
mv data/diginetica/train.txt data/train
mv data/diginetica/test.txt data/test
```
After processing, data/train holds the training data, data/test holds the test data, and data/config.txt holds the dataset statistics used to configure the model hyperparameters.
For convenience, we also provide a one-click data preparation script:
```
sh data_prepare.sh diginetica # or yoochoose1_4 or yoochoose1_64
```
## Configuration
To reproduce the paper's results on real data, a few more steps are required. All PaddleRec settings are made by editing the config.yaml file in the model directory:
1. Real-data setup. Dataset-related options live under the `dataset` field of config.yaml, and the data location is set via `data_path`. Setting workspace to the absolute path of the current project directory completes this step.
2. Hyperparameters (a consolidated sketch follows this list):
    - batch_size: set the batch_size of the dataset_train dataset in config.yaml to 100.
    - epochs: set the epochs of the runner in config.yaml to 5.
    - sparse_feature_number: differs per training dataset: 43098 for diginetica, 37484 for yoochoose. See the first line of the data/config.txt file produced by data preparation.
    - corpus_size: differs per training dataset: 719470 for diginetica, 5917745 for yoochoose. See the second line of data/config.txt.
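Taken together, the touched fields land in config.yaml roughly as follows for diginetica (a sketch, not the full file: the surrounding keys mirror the `dataset`/`runner` layout visible in the diff above, and the workspace path is a placeholder you must replace):
```
workspace: "/home/user/PaddleRec/models/recall/gnn"  # placeholder: absolute project path

hyper_parameters:
  sparse_feature_number: 43098  # line 1 of data/config.txt (37484 for yoochoose)
  corpus_size: 719470           # line 2 of data/config.txt (5917745 for yoochoose)

dataset:
- name: dataset_train
  batch_size: 100               # per step 2 above
  data_path: "{workspace}/data/train"

runner:
- name: train_runner
  class: train
  epochs: 5                     # per step 2 above
```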
## Training
After completing the [Configuration](#configuration) steps above, start training with:
```
python -m paddlerec.run -m ./config.yaml
```
## Evaluation
Before testing, complete the following configuration steps:
1. Set mode in config.yaml to infer_runner.
2. Switch the phase in config.yaml to phase_infer, commenting out phase_trainer as prompted.
3. Set the batch_size of the dataset_infer dataset in config.yaml to 100.
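For orientation, the evaluation-related fragment of config.yaml would then look roughly like this (a sketch following the phase1/phase2 naming visible in the diff above; treat it as illustrative rather than the exact file):
```
mode: infer_runner              # step 1: run only the infer runner

dataset:
- name: dataset_infer
  batch_size: 100               # step 3
  data_path: "{workspace}/data/test"

phase:
# - name: phase1                # step 2: training phase commented out
#   model: "{workspace}/model.py"
#   dataset_name: dataset_train
#   thread_num: 1
- name: phase2
  model: "{workspace}/model.py"
  dataset_name: dataset_infer
  thread_num: 1
```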
With these three steps done, run the following command to evaluate:
```
python -m paddlerec.run -m ./config.yaml
```