Unverified commit 8a6f5942 authored by hutuxian, committed by GitHub

Merge pull request #1870 from hutuxian/gnn

gnn network train infer
# SR-GNN
The following is a brief directory structure and description of this example:
```text
.
├── README.md            # documentation
├── train.py             # training script
├── infer.py             # inference script
├── network.py           # network definition
├── cluster_train.py     # multi-node training
├── cluster_train.sh     # multi-node training launch script
├── reader.py            # data reading utilities
├── data/
    ├── download.sh      # data download script
    ├── preprocess.py    # data preprocessing
```
## Introduction
The SR-GNN model is described in the paper [Session-based Recommendation with Graph Neural Networks](https://arxiv.org/abs/1811.00855).

It addresses the session-based recommendation problem, and the approach proceeds in roughly four steps (see the sketch at the end of this section for step 1):

1. Model all session sequences as directed graphs.
2. Use a GNN to learn a latent vector representation for each node (item).
3. Use an attention architecture to obtain an embedding for each session.
4. Use a softmax layer to predict over the full item vocabulary.

We reproduce the results of the paper: on the DIGINETICA dataset, P@20 reaches 50.7.
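
As an illustration of step 1, the following is a minimal, standalone sketch of how a single session is turned into a directed item graph with normalized incoming/outgoing adjacency matrices. It mirrors the adjacency construction in `reader.Data.make_data`, but ignores batching, padding, and masking:

```python
import numpy as np

def session_to_graph(seq):
    """Build the normalized in/out adjacency matrices of one session's directed item graph."""
    nodes = np.unique(seq)                   # unique items in the session (sorted)
    adj = np.zeros((len(nodes), len(nodes)))
    for a, b in zip(seq[:-1], seq[1:]):      # one directed edge per consecutive click pair
        u = np.where(nodes == a)[0][0]
        v = np.where(nodes == b)[0][0]
        adj[u][v] = 1
    deg_in = np.maximum(adj.sum(axis=0), 1)  # clamp zero degrees to avoid division by zero
    deg_out = np.maximum(adj.sum(axis=1), 1)
    adj_in = adj / deg_in                    # edge i->j weighted by the in-degree of target j
    adj_out = adj.T / deg_out                # transposed edge j->i weighted by the out-degree of source j
    return nodes, adj_in, adj_out

nodes, adj_in, adj_out = session_to_graph([1, 2, 3, 2])
print(nodes)    # [1 2 3]
print(adj_in)
print(adj_out)
```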
## Data Download and Preprocessing
We use the DIGINETICA dataset from http://cikm2016.cs.iupui.edu/cikm-cup. Follow the steps below to obtain the dataset and run a simple preprocessing pass.
* Step 1: Run the following command to download and unpack the DIGINETICA dataset:
```
cd data && sh download.sh
```
* Step 2: Generate the training set, test set, and config file:
```
python preprocess.py
cd ..
```
After these steps, a diginetica folder is created under data, containing three files: config.txt, test.txt, and train.txt.

The generated data has the format (session_list, label_list), where session_list is a list of sessions, each element itself being a list that represents one session, and label_list is a list whose element at each position is the label of the corresponding session in session_list.

Example: session_list=[[1,2,3], [4], [7,9]] means the list contains three sessions: the first session is the item sequence 1,2,3; the second contains only the single item 4; the third is the item sequence 7,9. label_list = [6, 9, 1] means that the predicted label of the session [1,2,3] should be 6, and similarly for the other two sessions.
Notes:
* If you want to use data from your own business scenario, you only need to make it conform to the format described above.
* In this example, train.txt and test.txt are both binary (pickled) files; a small loading sketch is shown below.
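
As a quick sanity check, here is a minimal sketch for loading and inspecting the generated training file, assuming the (session_list, label_list) pickle format produced by preprocess.py and the default data path used by train.py:

```python
import pickle

# train.txt stores a pickled (session_list, label_list) tuple, as described above.
with open("data/diginetica/train.txt", "rb") as f:
    session_list, label_list = pickle.load(f)

print(len(session_list), len(label_list))  # number of training samples
print(session_list[0], label_list[0])      # the first session and its label
```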
## Training
Refer to the commands below for training in different scenarios. Parameters such as batch_size and lr (learning rate) can also be specified; detailed descriptions of all options can be viewed by running:
```
python train.py -h
```
Single-machine, single-GPU training:
``` bash
CUDA_VISIBLE_DEVICES=1 python -u train.py --use_cuda 1 > log.txt 2>&1 &
```
Single-machine CPU training:
``` bash
python -u train.py --use_cuda 0 > log.txt 2>&1 &
```
Note that the single-card training above can be accelerated with the Parallel Executor by passing --use_parallel 1; see the example command below.
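
For reference, a single-GPU run that sets the main hyperparameters explicitly might look like the following (all flags are defined in train.py):

``` bash
CUDA_VISIBLE_DEVICES=1 python -u train.py --use_cuda 1 --use_parallel 1 --batch_size 100 --hidden_size 100 --lr 0.001 --epoch_num 30 > log.txt 2>&1 &
```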
## Sample Training Output
The training log on a single Tesla K40m GPU is shown below (actual output may vary):
```text
W0308 16:08:24.249840 1785 device_context.cc:263] Please NOTE: device: 0, CUDA Capability: 35, Driver API Version: 9.0, Runtime API Version: 8.0
W0308 16:08:24.249974 1785 device_context.cc:271] device: 0, cuDNN Version: 7.0.
2019-03-08 16:08:38,079 - INFO - load data complete
2019-03-08 16:08:38,080 - INFO - begin train
2019-03-08 16:09:07,605 - INFO - step: 500, loss: 10.2052, train_acc: 0.0088
2019-03-08 16:09:36,940 - INFO - step: 1000, loss: 9.7192, train_acc: 0.0320
2019-03-08 16:10:08,617 - INFO - step: 1500, loss: 8.9290, train_acc: 0.1350
...
2019-03-08 16:16:01,151 - INFO - model saved in ./saved_model/epoch_0
...
```
## Inference
Run the following command to start inference. The first and last epoch checkpoints to evaluate can be specified with the --start_index and --last_index arguments.
```
CUDA_VISIBLE_DEVICES=3 python infer.py
```
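
For example, to evaluate the checkpoints saved for epochs 0 through 10 (the defaults of --start_index and --last_index in infer.py):

```
CUDA_VISIBLE_DEVICES=3 python infer.py --start_index 0 --last_index 10
```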
## Sample Inference Output
```text
W0308 16:41:56.847339 31709 device_context.cc:263] Please NOTE: device: 0, CUDA Capability: 35, Driver API Version: 9.0, Runtime API Version: 8.0
W0308 16:41:56.847705 31709 device_context.cc:271] device: 0, cuDNN Version: 7.0.
2019-03-08 16:42:20,420 - INFO - TEST --> loss: 5.8865, Recall@20: 0.4525
2019-03-08 16:42:45,153 - INFO - TEST --> loss: 5.5314, Recall@20: 0.5010
2019-03-08 16:43:10,233 - INFO - TEST --> loss: 5.5128, Recall@20: 0.5047
...
```
#!/bin/bash
#The gdown.pl script comes from: https://github.com/circulosmeos/gdown.pl
./gdown.pl https://drive.google.com/open?id=0B7XZSACQf0KdenRmMk8yVUU5LWc dataset-train-diginetica.zip
unzip dataset-train-diginetica.zip "train-item-views.csv"
sed -i '1d' train-item-views.csv
sed -i '1i session_id;user_id;item_id;timeframe;eventdate' train-item-views.csv
#!/usr/bin/env perl
#
# Google Drive direct download of big files
# ./gdown.pl 'gdrive file url' ['desired file name']
#
# v1.0 by circulosmeos 04-2014.
# v1.1 by circulosmeos 01-2017.
# v1.2, v1.3, v1.4 by circulosmeos 01-2019, 02-2019.
# //circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files
# Distributed under GPL 3 (//www.gnu.org/licenses/gpl-3.0.html)
#
use strict;
use POSIX;
my $TEMP='gdown.cookie.temp';
my $COMMAND;
my $confirm;
my $check;
sub execute_command();
my $URL=shift;
die "\n./gdown.pl 'gdrive file url' [desired file name]\n\n" if $URL eq '';
my $FILENAME=shift;
$FILENAME='gdown.'.strftime("%Y%m%d%H%M%S", localtime).'.'.substr(rand,2) if $FILENAME eq '';
if ($URL=~m#^https?://drive.google.com/file/d/([^/]+)#) {
$URL="https://docs.google.com/uc?id=$1&export=download";
}
elsif ($URL=~m#^https?://drive.google.com/open\?id=([^/]+)#) {
$URL="https://docs.google.com/uc?id=$1&export=download";
}
execute_command();
while (-s $FILENAME < 100000) { # only if the file isn't the download yet
open fFILENAME, '<', $FILENAME;
$check=0;
foreach (<fFILENAME>) {
if (/href="(\/uc\?export=download[^"]+)/) {
$URL='https://docs.google.com'.$1;
$URL=~s/&amp;/&/g;
$confirm='';
$check=1;
last;
}
if (/confirm=([^;&]+)/) {
$confirm=$1;
$check=1;
last;
}
if (/"downloadUrl":"([^"]+)/) {
$URL=$1;
$URL=~s/\\u003d/=/g;
$URL=~s/\\u0026/&/g;
$confirm='';
$check=1;
last;
}
}
close fFILENAME;
die "Couldn't download the file :-(\n" if ($check==0);
$URL=~s/confirm=([^;&]+)/confirm=$confirm/ if $confirm ne '';
execute_command();
}
unlink $TEMP;
sub execute_command() {
$COMMAND="wget --progress=dot:giga --no-check-certificate --load-cookie $TEMP --save-cookie $TEMP \"$URL\"";
$COMMAND.=" -O \"$FILENAME\"" if $FILENAME ne '';
system ( $COMMAND );
return 1;
}
#!/usr/bin/env python36
# -*- coding: utf-8 -*-
"""
Created on July, 2018
@author: Tangrizzly
"""
import argparse
import time
import csv
import pickle
import operator
import datetime
import os
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataset',
default='sample',
help='dataset name: diginetica/yoochoose/sample')
opt = parser.parse_args()
print(opt)
dataset = 'sample_train-item-views.csv'
if opt.dataset == 'diginetica':
dataset = 'train-item-views.csv'
elif opt.dataset == 'yoochoose':
dataset = 'yoochoose-clicks.dat'
print("-- Starting @ %ss" % datetime.datetime.now())
with open(dataset, "r") as f:
if opt.dataset == 'yoochoose':
reader = csv.DictReader(f, delimiter=',')
else:
reader = csv.DictReader(f, delimiter=';')
sess_clicks = {}
sess_date = {}
ctr = 0
curid = -1
curdate = None
for data in reader:
sessid = data['session_id']
if curdate and not curid == sessid:
date = ''
if opt.dataset == 'yoochoose':
date = time.mktime(
time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
else:
date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
sess_date[curid] = date
curid = sessid
if opt.dataset == 'yoochoose':
item = data['item_id']
else:
item = data['item_id'], int(data['timeframe'])
curdate = ''
if opt.dataset == 'yoochoose':
curdate = data['timestamp']
else:
curdate = data['eventdate']
if sessid in sess_clicks:
sess_clicks[sessid] += [item]
else:
sess_clicks[sessid] = [item]
ctr += 1
date = ''
if opt.dataset == 'yoochoose':
date = time.mktime(time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S'))
else:
date = time.mktime(time.strptime(curdate, '%Y-%m-%d'))
for i in list(sess_clicks):
sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1))
sess_clicks[i] = [c[0] for c in sorted_clicks]
sess_date[curid] = date
print("-- Reading data @ %ss" % datetime.datetime.now())
# Filter out length 1 sessions
for s in list(sess_clicks):
if len(sess_clicks[s]) == 1:
del sess_clicks[s]
del sess_date[s]
# Count number of times each item appears
iid_counts = {}
for s in sess_clicks:
seq = sess_clicks[s]
for iid in seq:
if iid in iid_counts:
iid_counts[iid] += 1
else:
iid_counts[iid] = 1
sorted_counts = sorted(iid_counts.items(), key=operator.itemgetter(1))
length = len(sess_clicks)
for s in list(sess_clicks):
curseq = sess_clicks[s]
filseq = list(filter(lambda i: iid_counts[i] >= 5, curseq))
if len(filseq) < 2:
del sess_clicks[s]
del sess_date[s]
else:
sess_clicks[s] = filseq
# Split out test set based on dates
dates = list(sess_date.items())
maxdate = dates[0][1]
for _, date in dates:
if maxdate < date:
maxdate = date
# 7 days for test
splitdate = 0
if opt.dataset == 'yoochoose':
splitdate = maxdate - 86400 * 1 # the number of seconds for a day:86400
else:
splitdate = maxdate - 86400 * 7
print('Splitting date', splitdate) # Yoochoose: ('Split date', 1411930799.0)
tra_sess = filter(lambda x: x[1] < splitdate, dates)
tes_sess = filter(lambda x: x[1] > splitdate, dates)
# Sort sessions by date
tra_sess = sorted(
tra_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ]
tes_sess = sorted(
tes_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ]
print(len(tra_sess)) # 186670 # 7966257
print(len(tes_sess)) # 15979 # 15324
print(tra_sess[:3])
print(tes_sess[:3])
print("-- Splitting train set and test set @ %ss" % datetime.datetime.now())
# Choosing item count >=5 gives approximately the same number of items as reported in paper
item_dict = {}
# Convert training sessions to sequences and renumber items to start from 1
def obtian_tra():
train_ids = []
train_seqs = []
train_dates = []
item_ctr = 1
for s, date in tra_sess:
seq = sess_clicks[s]
outseq = []
for i in seq:
if i in item_dict:
outseq += [item_dict[i]]
else:
outseq += [item_ctr]
item_dict[i] = item_ctr
item_ctr += 1
if len(outseq) < 2: # Doesn't occur
continue
train_ids += [s]
train_dates += [date]
train_seqs += [outseq]
print(item_ctr) # 43098, 37484
with open("./diginetica/config.txt", "w") as fout:
fout.write(str(item_ctr) + "\n")
return train_ids, train_dates, train_seqs
# Convert test sessions to sequences, ignoring items that do not appear in training set
def obtian_tes():
test_ids = []
test_seqs = []
test_dates = []
for s, date in tes_sess:
seq = sess_clicks[s]
outseq = []
for i in seq:
if i in item_dict:
outseq += [item_dict[i]]
if len(outseq) < 2:
continue
test_ids += [s]
test_dates += [date]
test_seqs += [outseq]
return test_ids, test_dates, test_seqs
tra_ids, tra_dates, tra_seqs = obtian_tra()
tes_ids, tes_dates, tes_seqs = obtian_tes()
def process_seqs(iseqs, idates):
out_seqs = []
out_dates = []
labs = []
ids = []
for id, seq, date in zip(range(len(iseqs)), iseqs, idates):
for i in range(1, len(seq)):
tar = seq[-i]
labs += [tar]
out_seqs += [seq[:-i]]
out_dates += [date]
ids += [id]
return out_seqs, out_dates, labs, ids
tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates)
te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates)
tra = (tr_seqs, tr_labs)
tes = (te_seqs, te_labs)
print(len(tr_seqs))
print(len(te_seqs))
print(tr_seqs[:3], tr_dates[:3], tr_labs[:3])
print(te_seqs[:3], te_dates[:3], te_labs[:3])
all = 0
for seq in tra_seqs:
all += len(seq)
for seq in tes_seqs:
all += len(seq)
print('avg length: ', all / (len(tra_seqs) + len(tes_seqs) * 1.0))
if opt.dataset == 'diginetica':
if not os.path.exists('diginetica'):
os.makedirs('diginetica')
pickle.dump(tra, open('diginetica/train.txt', 'wb'))
pickle.dump(tes, open('diginetica/test.txt', 'wb'))
pickle.dump(tra_seqs, open('diginetica/all_train_seq.txt', 'wb'))
elif opt.dataset == 'yoochoose':
if not os.path.exists('yoochoose1_4'):
os.makedirs('yoochoose1_4')
if not os.path.exists('yoochoose1_64'):
os.makedirs('yoochoose1_64')
pickle.dump(tes, open('yoochoose1_4/test.txt', 'wb'))
pickle.dump(tes, open('yoochoose1_64/test.txt', 'wb'))
split4, split64 = int(len(tr_seqs) / 4), int(len(tr_seqs) / 64)
print(len(tr_seqs[-split4:]))
print(len(tr_seqs[-split64:]))
tra4, tra64 = (tr_seqs[-split4:], tr_labs[-split4:]), (tr_seqs[-split64:],
tr_labs[-split64:])
seq4, seq64 = tra_seqs[tr_ids[-split4]:], tra_seqs[tr_ids[-split64]:]
pickle.dump(tra4, open('yoochoose1_4/train.txt', 'wb'))
pickle.dump(seq4, open('yoochoose1_4/all_train_seq.txt', 'wb'))
pickle.dump(tra64, open('yoochoose1_64/train.txt', 'wb'))
pickle.dump(seq64, open('yoochoose1_64/all_train_seq.txt', 'wb'))
else:
if not os.path.exists('sample'):
os.makedirs('sample')
pickle.dump(tra, open('sample/train.txt', 'wb'))
pickle.dump(tes, open('sample/test.txt', 'wb'))
pickle.dump(tra_seqs, open('sample/all_train_seq.txt', 'wb'))
print('Done.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import logging
import numpy as np
import os
import paddle
import paddle.fluid as fluid
import reader
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
    parser = argparse.ArgumentParser(description="PaddlePaddle SR-GNN example")
parser.add_argument(
'--model_path', type=str, default='./saved_model/', help="path of model parameters")
parser.add_argument(
'--test_path', type=str, default='./data/diginetica/test.txt', help='dir of test file')
parser.add_argument(
'--use_cuda', type=int, default=1, help='whether to use gpu')
parser.add_argument(
'--batch_size', type=int, default=100, help='input batch size')
parser.add_argument(
        '--start_index', type=int, default=0, help='start index')
parser.add_argument(
        '--last_index', type=int, default=10, help='end index')
return parser.parse_args()
def infer(epoch_num):
args = parse_args()
batch_size = args.batch_size
test_data = reader.Data(args.test_path, False)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
model_path = args.model_path + "epoch_" + str(epoch_num)
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
model_path, exe)
feeder = fluid.DataFeeder(
feed_list=feed_names, place=place, program=infer_program)
loss_sum = 0.0
acc_sum = 0.0
count = 0
for data in test_data.reader(batch_size, batch_size, False):
res = exe.run(infer_program,
feed=feeder.feed(data),
fetch_list=fetch_targets)
loss_sum += res[0]
acc_sum += res[1]
count += 1
logger.info("TEST --> loss: %.4lf, Recall@20: %.4lf" %
(loss_sum / count, acc_sum / count))
if __name__ == "__main__":
args = parse_args()
for index in range(args.start_index, args.last_index + 1):
infer(index)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import math
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def network(batch_size, items_num, hidden_size, step):
stdv = 1.0 / math.sqrt(hidden_size)
items = layers.data(
name="items",
shape=[batch_size, items_num, 1],
dtype="int64",
append_batch_size=False) #[bs, uniq_max, 1]
seq_index = layers.data(
name="seq_index",
shape=[batch_size, items_num],
dtype="int32",
append_batch_size=False) #[-1(seq_max)*batch_size, 1]
last_index = layers.data(
name="last_index",
shape=[batch_size],
dtype="int32",
append_batch_size=False) #[batch_size, 1]
adj_in = layers.data(
name="adj_in",
shape=[batch_size, items_num, items_num],
dtype="float32",
append_batch_size=False)
adj_out = layers.data(
name="adj_out",
shape=[batch_size, items_num, items_num],
dtype="float32",
append_batch_size=False)
mask = layers.data(
name="mask",
shape=[batch_size, -1, 1],
dtype="float32",
append_batch_size=False)
label = layers.data(
name="label",
shape=[batch_size, 1],
dtype="int64",
append_batch_size=False)
items_emb = layers.embedding(
input=items,
param_attr=fluid.ParamAttr(
name="emb",
initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv)),
size=[items_num, hidden_size]) #[batch_size, uniq_max, h]
pre_state = items_emb
for i in range(step):
pre_state = layers.reshape(
x=pre_state, shape=[batch_size, -1, hidden_size])
state_in = layers.fc(
input=pre_state,
name="state_in",
size=hidden_size,
act=None,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv))) #[batch_size, uniq_max, h]
state_out = layers.fc(
input=pre_state,
name="state_out",
size=hidden_size,
act=None,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv))) #[batch_size, uniq_max, h]
state_adj_in = layers.matmul(adj_in, state_in) #[batch_size, uniq_max, h]
state_adj_out = layers.matmul(adj_out, state_out) #[batch_size, uniq_max, h]
gru_input = layers.concat([state_adj_in, state_adj_out], axis=2)
gru_input = layers.reshape(x=gru_input, shape=[-1, hidden_size * 2])
gru_fc = layers.fc(
input=gru_input,
name="gru_fc",
size=3 * hidden_size,
bias_attr=False)
pre_state, _, _ = fluid.layers.gru_unit(
input=gru_fc,
hidden=layers.reshape(
x=pre_state, shape=[-1, hidden_size]),
size=3 * hidden_size)
final_state = pre_state
seq_index = layers.reshape(seq_index, shape=[-1])
seq = layers.gather(final_state, seq_index) #[batch_size*-1(seq_max), h]
last = layers.gather(final_state, last_index) #[batch_size, h]
seq = layers.reshape(
seq, shape=[batch_size, -1, hidden_size]) #[batch_size, -1(seq_max), h]
last = layers.reshape(
last, shape=[batch_size, hidden_size]) #[batch_size, h]
seq_fc = layers.fc(
input=seq,
name="seq_fc",
size=hidden_size,
bias_attr=False,
act=None,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv))) #[batch_size, -1(seq_max), h]
last_fc = layers.fc(
input=last,
name="last_fc",
size=hidden_size,
bias_attr=False,
act=None,
num_flatten_dims=1,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv))) #[bathc_size, h]
seq_fc_t = layers.transpose(
seq_fc, perm=[1, 0, 2]) #[-1(seq_max), batch_size, h]
add = layers.elementwise_add(
seq_fc_t, last_fc) #[-1(seq_max), batch_size, h]
b = layers.create_parameter(
shape=[hidden_size],
dtype='float32',
default_initializer=fluid.initializer.Constant(value=0.0)) #[h]
add = layers.elementwise_add(add, b) #[-1(seq_max), batch_size, h]
add_sigmoid = layers.sigmoid(add) #[-1(seq_max), batch_size, h]
add_sigmoid = layers.transpose(
add_sigmoid, perm=[1, 0, 2]) #[batch_size, -1(seq_max), h]
weight = layers.fc(
input=add_sigmoid,
name="weight_fc",
size=1,
act=None,
num_flatten_dims=2,
bias_attr=False,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv))) #[batch_size, -1, 1]
weight *= mask
weight_mask = layers.elementwise_mul(seq, weight, axis=0)
global_attention = layers.reduce_sum(weight_mask, dim=1)
final_attention = layers.concat(
[global_attention, last_fc], axis=1) #[batch_size, 2*h]
final_attention_fc = layers.fc(
input=final_attention,
name="fina_attention_fc",
size=hidden_size,
bias_attr=False,
act=None,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv))) #[batch_size, h]
all_vocab = layers.create_global_var(
shape=[items_num - 1, 1],
value=0,
dtype="int64",
persistable=True,
name="all_vocab")
all_emb = layers.embedding(
input=all_vocab,
param_attr=fluid.ParamAttr(
name="emb",
initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv)),
size=[items_num, hidden_size]) #[all_vocab, h]
logits = layers.matmul(
x=final_attention_fc, y=all_emb,
transpose_y=True) #[batch_size, all_vocab]
softmax = layers.softmax_with_cross_entropy(
logits=logits, label=label) #[batch_size, 1]
loss = layers.reduce_mean(softmax) # [1]
acc = layers.accuracy(input=logits, label=label, k=20)
return loss, acc
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import copy
import random
import pickle
class Data():
def __init__(self, path, shuffle=False):
data = pickle.load(open(path, 'rb'))
self.shuffle = shuffle
self.length = len(data[0])
        self.input = list(zip(data[0], data[1]))  # materialize: zip() is lazy in Python 3 and cannot be shuffled or sliced
def make_data(self, cur_batch, batch_size):
cur_batch = [list(e) for e in cur_batch]
max_seq_len = 0
for e in cur_batch:
max_seq_len = max(max_seq_len, len(e[0]))
last_id = []
for e in cur_batch:
last_id.append(len(e[0]) - 1)
e[0] += [0] * (max_seq_len - len(e[0]))
max_uniq_len = 0
for e in cur_batch:
max_uniq_len = max(max_uniq_len, len(np.unique(e[0])))
items, adj_in, adj_out, seq_index, last_index = [], [], [], [], []
mask, label = [], []
id = 0
for e in cur_batch:
node = np.unique(e[0])
items.append(node.tolist() + (max_uniq_len - len(node)) * [0])
adj = np.zeros((max_uniq_len, max_uniq_len))
for i in np.arange(len(e[0]) - 1):
if e[0][i + 1] == 0:
break
u = np.where(node == e[0][i])[0][0]
v = np.where(node == e[0][i + 1])[0][0]
adj[u][v] = 1
u_deg_in = np.sum(adj, 0)
u_deg_in[np.where(u_deg_in == 0)] = 1
adj_in.append(
np.divide(adj, u_deg_in)
) #maybe should add a transpose, but the result shows no difference
u_deg_out = np.sum(adj, 1)
u_deg_out[np.where(u_deg_out == 0)] = 1
adj_out.append(np.divide(adj.transpose(), u_deg_out))
seq_index.append(
[np.where(node == i)[0][0] + id * max_uniq_len for i in e[0]])
last_index.append(
np.where(node == e[0][last_id[id]])[0][0] + id * max_uniq_len)
label.append(e[1] - 1)
mask.append([[1] * (last_id[id] + 1) + [0] *
(max_seq_len - last_id[id] - 1)])
id += 1
items = np.array(items).astype("int64").reshape((batch_size, -1, 1))
seq_index = np.array(seq_index).astype("int32").reshape(
(batch_size, -1))
last_index = np.array(last_index).astype("int32").reshape(
(batch_size, 1))
adj_in = np.array(adj_in).astype("float32").reshape(
(batch_size, max_uniq_len, max_uniq_len))
adj_out = np.array(adj_out).astype("float32").reshape(
(batch_size, max_uniq_len, max_uniq_len))
mask = np.array(mask).astype("float32").reshape((batch_size, -1, 1))
label = np.array(label).astype("int64").reshape((batch_size, 1))
return zip(items, seq_index, last_index, adj_in, adj_out, mask, label)
def reader(self, batch_size, batch_group_size, train=True):
if self.shuffle:
random.shuffle(self.input)
group_remain = self.length % batch_group_size
for bg_id in range(0, self.length - group_remain, batch_group_size):
cur_bg = copy.deepcopy(self.input[bg_id:bg_id + batch_group_size])
if train:
cur_bg = sorted(cur_bg, key=lambda x: len(x[0]), reverse=True)
for i in range(0, batch_group_size, batch_size):
cur_batch = cur_bg[i:i + batch_size]
yield self.make_data(cur_batch, batch_size)
#deal with the remaining, discard at most batch_size data
if group_remain < batch_size:
return
remain_data = copy.deepcopy(self.input[-group_remain:])
if train:
remain_data = sorted(
remain_data, key=lambda x: len(x[0]), reverse=True)
for i in range(0, batch_group_size, batch_size):
if i + batch_size <= len(remain_data):
cur_batch = remain_data[i:i + batch_size]
yield self.make_data(cur_batch, batch_size)
def read_config(path):
with open(path, "r") as fin:
item_num = int(fin.readline())
return item_num
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import os
from functools import partial
import logging
import paddle
import paddle.fluid as fluid
import argparse
import network
import reader
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("gnn")
parser.add_argument(
'--train_path', type=str, default='./data/diginetica/train.txt', help='dir of training data')
parser.add_argument(
'--config_path', type=str, default='./data/diginetica/config.txt', help='dir of config')
parser.add_argument(
'--model_path', type=str, default='./saved_model', help="path of model parameters")
parser.add_argument(
'--epoch_num', type=int, default=30, help='number of epochs to train for')
parser.add_argument(
'--batch_size', type=int, default=100, help='input batch size')
parser.add_argument(
'--hidden_size', type=int, default=100, help='hidden state size')
parser.add_argument(
'--l2', type=float, default=1e-5, help='l2 penalty')
parser.add_argument(
'--lr', type=float, default=0.001, help='learning rate')
parser.add_argument(
        '--step', type=int, default=1, help='gnn propagation steps')
parser.add_argument(
'--lr_dc', type=float, default=0.1, help='learning rate decay rate')
parser.add_argument(
        '--lr_dc_step', type=int, default=3, help='number of epochs after which the learning rate decays')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether to use gpu')
parser.add_argument(
'--use_parallel', type=int, default=1, help='whether to use parallel executor')
return parser.parse_args()
def train():
args = parse_args()
batch_size = args.batch_size
items_num = reader.read_config(args.config_path)
loss, acc = network.network(batch_size, items_num, args.hidden_size,
args.step)
data_reader = reader.Data(args.train_path, True)
logger.info("load data complete")
use_cuda = True if args.use_cuda else False
use_parallel = True if args.use_parallel else False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
step_per_epoch = data_reader.length // batch_size
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay(
learning_rate=args.lr,
decay_steps=step_per_epoch * args.lr_dc_step,
decay_rate=args.lr_dc),
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.l2))
optimizer.minimize(loss)
exe.run(fluid.default_startup_program())
all_vocab = fluid.global_scope().var("all_vocab").get_tensor()
all_vocab.set(
np.arange(1, items_num).astype("int64").reshape((-1, 1)), place)
feed_list = [
"items", "seq_index", "last_index", "adj_in", "adj_out", "mask", "label"
]
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
if use_parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name)
else:
train_exe = exe
logger.info("begin train")
loss_sum = 0.0
acc_sum = 0.0
global_step = 0
PRINT_STEP = 500
for i in range(args.epoch_num):
epoch_sum = []
for data in data_reader.reader(batch_size, batch_size * 20, True):
res = train_exe.run(feed=feeder.feed(data),
fetch_list=[loss.name, acc.name])
loss_sum += res[0]
acc_sum += res[1]
epoch_sum.append(res[0])
global_step += 1
if global_step % PRINT_STEP == 0:
logger.info("global_step: %d, loss: %.4lf, train_acc: %.4lf" % (
global_step, loss_sum / PRINT_STEP, acc_sum / PRINT_STEP))
loss_sum = 0.0
acc_sum = 0.0
logger.info("epoch loss: %.4lf" % (np.mean(epoch_sum)))
save_dir = args.model_path + "/epoch_" + str(i)
fetch_vars = [loss, acc]
fluid.io.save_inference_model(save_dir, feed_list, fetch_vars, exe)
logger.info("model saved in " + save_dir)
if __name__ == "__main__":
train()