提交 33e68ab4 编写于 作者: F frankwhzhang 提交者: Yi Liu

add train multiple negative and infer (#1422)

* fix readme2.0

* add tagspace infer
上级 e8160b1c
...@@ -21,7 +21,7 @@ GRU4REC模型的介绍可以参阅论文[Session-based Recommendations with Recu ...@@ -21,7 +21,7 @@ GRU4REC模型的介绍可以参阅论文[Session-based Recommendations with Recu
论文的贡献在于首次将RNN(GRU)运用于session-based推荐,相比传统的KNN和矩阵分解,效果有明显的提升。 论文的贡献在于首次将RNN(GRU)运用于session-based推荐,相比传统的KNN和矩阵分解,效果有明显的提升。
论文的核心思想在一个session中,用户点击一系列item的行为看做一个序列,用来训练RNN模型。预测阶段,给定已知的点击序列作为输入,预测下一个可能点击的item。 论文的核心思想在一个session中,用户点击一系列item的行为看做一个序列,用来训练RNN模型。预测阶段,给定已知的点击序列作为输入,预测下一个可能点击的item。
session-based推荐应用场景非常广泛,比如用户的商品浏览、新闻点击、地点签到等序列数据。 session-based推荐应用场景非常广泛,比如用户的商品浏览、新闻点击、地点签到等序列数据。
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
. .
├── README.md # 文档 ├── README.md # 文档
├── train.py # 训练脚本 ├── train.py # 训练脚本
├── infer.py # 预测脚本
├── utils # 通用函数 ├── utils # 通用函数
├── small_train.txt # 小样本训练集 ├── small_train.txt # 小样本训练集
└── small_test.txt # 小样本测试集 └── small_test.txt # 小样本测试集
...@@ -26,7 +27,6 @@ TagSpace模型的介绍可以参阅论文[#TagSpace: Semantic Embeddings from Ha ...@@ -26,7 +27,6 @@ TagSpace模型的介绍可以参阅论文[#TagSpace: Semantic Embeddings from Ha
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again." "3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
``` ```
## 训练 ## 训练
'--use_cuda 1' 表示使用gpu, 缺省表示使用cpu '--use_cuda 1' 表示使用gpu, 缺省表示使用cpu
...@@ -41,10 +41,9 @@ CPU 环境 ...@@ -41,10 +41,9 @@ CPU 环境
python train.py small_train.txt small_test.txt python train.py small_train.txt small_test.txt
``` ```
## 未来工作 ## 预测
添加预测部分
添加多种负例采样方式
```
CUDA_VISIBLE_DEVICES=0 python infer.py model/ 1 10 small_train.txt small_test.txt --use_cuda 1
```
import sys
import time
import math
import unittest
import contextlib
import numpy as np
import six
import paddle.fluid as fluid
import paddle
import utils
def infer(test_reader, vocab_tag, use_cuda, model_path):
""" inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
with fluid.scope_guard(fluid.core.Scope()):
infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
model_path, exe)
t0 = time.time()
step_id = 0
true_num = 0
all_num = 0
size = len(vocab_tag)
value = []
for data in test_reader():
step_id += 1
lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place)
lod_tag = utils.to_lodtensor([dat[1] for dat in data], place)
lod_pos_tag = utils.to_lodtensor([dat[2] for dat in data], place)
para = exe.run(
infer_program,
feed={
"text": lod_text_seq,
"pos_tag": lod_tag},
fetch_list=fetch_vars,
return_numpy=False)
value.append(para[0]._get_float_element(0))
if step_id % size == 0 and step_id > 1:
all_num += 1
true_pos = [dat[2] for dat in data][0][0]
if value.index(max(value)) == int(true_pos):
true_num += 1
value = []
if step_id % 1000 == 0:
print(step_id, 1.0 * true_num / all_num)
t1 = time.time()
if __name__ == "__main__":
if len(sys.argv) != 6:
print(
"Usage: %s model_dir start_epoch last_epoch(inclusive) train_file test_file"
)
exit(0)
train_file = ""
test_file = ""
model_dir = sys.argv[1]
try:
start_index = int(sys.argv[2])
last_index = int(sys.argv[3])
train_file = sys.argv[4]
test_file = sys.argv[5]
except:
print(
"Usage: %s model_dir start_ipoch last_epoch(inclusive) train_file test_file"
)
exit(-1)
vocab_text, vocab_tag, train_reader, test_reader = utils.prepare_data(
train_file,
test_file,
batch_size=1,
buffer_size=1000,
word_freq_threshold=0)
for epoch in xrange(start_index, last_index + 1):
epoch_path = model_dir + "/epoch_" + str(epoch)
infer(test_reader=test_reader, vocab_tag=vocab_tag, use_cuda=False, model_path=epoch_path)
...@@ -24,7 +24,7 @@ def parse_args(): ...@@ -24,7 +24,7 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=5, margin=0.1): def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=5, margin=0.1, neg_size=5):
""" network definition """ """ network definition """
text = io.data(name="text", shape=[1], lod_level=1, dtype='int64') text = io.data(name="text", shape=[1], lod_level=1, dtype='int64')
pos_tag = io.data(name="pos_tag", shape=[1], lod_level=1, dtype='int64') pos_tag = io.data(name="pos_tag", shape=[1], lod_level=1, dtype='int64')
...@@ -44,12 +44,14 @@ def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size= ...@@ -44,12 +44,14 @@ def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=
act="tanh", act="tanh",
pool_type="max", pool_type="max",
param_attr="cnn") param_attr="cnn")
text_hid = fluid.layers.fc(input=conv_1d, size=emb_dim, param_attr="text_hid") text_hid = fluid.layers.fc(input=conv_1d, size=emb_dim, param_attr="text_hid")
cos_pos = nn.cos_sim(pos_tag_emb, text_hid) cos_pos = nn.cos_sim(pos_tag_emb, text_hid)
cos_neg = nn.cos_sim(neg_tag_emb, text_hid) mul_text_hid = fluid.layers.sequence_expand_as(x=text_hid, y=neg_tag_emb)
mul_cos_neg = nn.cos_sim(neg_tag_emb, mul_text_hid)
cos_neg_all = fluid.layers.sequence_reshape(input=mul_cos_neg, new_dim=neg_size)
#choose max negtive cosine
cos_neg = nn.reduce_max(cos_neg_all, dim=1, keep_dim=True)
#calculate hinge loss
loss_part1 = nn.elementwise_sub( loss_part1 = nn.elementwise_sub(
tensor.fill_constant_batch_size_like( tensor.fill_constant_batch_size_like(
input=cos_pos, input=cos_pos,
...@@ -63,22 +65,20 @@ def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size= ...@@ -63,22 +65,20 @@ def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=
input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'), input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
loss_part2) loss_part2)
avg_cost = nn.mean(loss_part3) avg_cost = nn.mean(loss_part3)
less = tensor.cast(cf.less_than(cos_neg, cos_pos), dtype='float32') less = tensor.cast(cf.less_than(cos_neg, cos_pos), dtype='float32')
correct = nn.reduce_sum(less) correct = nn.reduce_sum(less)
return text, pos_tag, neg_tag, avg_cost, correct, cos_pos return text, pos_tag, neg_tag, avg_cost, correct, cos_pos
def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size, def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size, neg_size,
pass_num, use_cuda, model_dir): pass_num, use_cuda, model_dir):
""" train network """ """ train network """
args = parse_args() args = parse_args()
vocab_text_size = len(vocab_text) vocab_text_size = len(vocab_text)
vocab_tag_size = len(vocab_tag) vocab_tag_size = len(vocab_tag)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# Train program # Train program
text, pos_tag, neg_tag, avg_cost, correct, pos_cos = network(vocab_text_size, vocab_tag_size) text, pos_tag, neg_tag, avg_cost, correct, cos_pos = network(vocab_text_size, vocab_tag_size, neg_size=neg_size)
# Optimization to minimize lost # Optimization to minimize lost
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=base_lr) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=base_lr)
...@@ -117,8 +117,8 @@ def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size, ...@@ -117,8 +117,8 @@ def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size,
(epoch_idx, batch_id, total_time / epoch_idx)) (epoch_idx, batch_id, total_time / epoch_idx))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx) save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
feed_var_names = ["text", "pos_tag"] feed_var_names = ["text", "pos_tag"]
fetch_vars = [pos_cos] fetch_vars = [cos_pos]
fluid.io.save_inference_model(save_dir ,feed_var_names, fetch_vars, exe) fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
print("finish training") print("finish training")
def train_net(): def train_net():
...@@ -128,17 +128,19 @@ def train_net(): ...@@ -128,17 +128,19 @@ def train_net():
test_file = args.test_file test_file = args.test_file
use_cuda = True if args.use_cuda else False use_cuda = True if args.use_cuda else False
batch_size = 100 batch_size = 100
neg_size = 3
vocab_text, vocab_tag, train_reader, test_reader = utils.prepare_data( vocab_text, vocab_tag, train_reader, test_reader = utils.prepare_data(
train_file, test_file, batch_size=batch_size, buffer_size=batch_size*100, word_freq_threshold=0) train_file, test_file, neg_size=neg_size, batch_size=batch_size, buffer_size=batch_size*100, word_freq_threshold=0)
train( train(
train_reader=train_reader, train_reader=train_reader,
vocab_text=vocab_text, vocab_text=vocab_text,
vocab_tag=vocab_tag, vocab_tag=vocab_tag,
base_lr=0.01, base_lr=0.01,
batch_size=batch_size, batch_size=batch_size,
neg_size=neg_size,
pass_num=10, pass_num=10,
use_cuda=use_cuda, use_cuda=use_cuda,
model_dir="model_dim10_2") model_dir="model")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -38,12 +38,13 @@ def prepare_data(train_filename, ...@@ -38,12 +38,13 @@ def prepare_data(train_filename,
train_reader = sort_batch( train_reader = sort_batch(
paddle.reader.shuffle( paddle.reader.shuffle(
train( train(
train_filename, vocab_text, vocab_tag, buffer_size, data_type=DataType.SEQ), train_filename, vocab_text, vocab_tag, neg_size,
buffer_size, data_type=DataType.SEQ),
buf_size=buffer_size), buf_size=buffer_size),
batch_size, batch_size * 20) batch_size, batch_size * 20)
test_reader = sort_batch( test_reader = sort_batch(
test( test(
test_filename, vocab_text, vocab_tag, buffer_size, data_type=DataType.SEQ), test_filename, vocab_text, vocab_tag, neg_size, buffer_size, data_type=DataType.SEQ),
batch_size, batch_size * 20) batch_size, batch_size * 20)
return vocab_text, vocab_tag, train_reader, test_reader return vocab_text, vocab_tag, train_reader, test_reader
...@@ -123,7 +124,7 @@ def build_dict(column_num=2, min_word_freq=50, train_filename="", test_filename= ...@@ -123,7 +124,7 @@ def build_dict(column_num=2, min_word_freq=50, train_filename="", test_filename=
word_idx = dict(list(zip(words, six.moves.range(len(words))))) word_idx = dict(list(zip(words, six.moves.range(len(words)))))
return word_idx return word_idx
def reader_creator(filename, text_idx, tag_idx, n, data_type): def train_reader_creator(filename, text_idx, tag_idx, neg_size, n, data_type):
def reader(): def reader():
with open(filename) as input_file: with open(filename) as input_file:
data_file = csv.reader(input_file) data_file = csv.reader(input_file)
...@@ -138,7 +139,7 @@ def reader_creator(filename, text_idx, tag_idx, n, data_type): ...@@ -138,7 +139,7 @@ def reader_creator(filename, text_idx, tag_idx, n, data_type):
max_iter = 100 max_iter = 100
now_iter = 0 now_iter = 0
sum_n = 0 sum_n = 0
while(sum_n < 1) : while(sum_n < neg_size) :
now_iter += 1 now_iter += 1
if now_iter > max_iter: if now_iter > max_iter:
print("error : only one class") print("error : only one class")
...@@ -152,8 +153,26 @@ def reader_creator(filename, text_idx, tag_idx, n, data_type): ...@@ -152,8 +153,26 @@ def reader_creator(filename, text_idx, tag_idx, n, data_type):
yield text, pos_tag, neg_tag yield text, pos_tag, neg_tag
return reader return reader
def train(filename, text_idx, tag_idx, n, data_type=DataType.SEQ): def test_reader_creator(filename, text_idx, tag_idx, n, data_type):
return reader_creator(filename, text_idx, tag_idx, n, data_type) def reader():
with open(filename) as input_file:
data_file = csv.reader(input_file)
for row in data_file:
text_raw = re.split(r'\W+', row[2].strip())
text = [text_idx.get(w) for w in text_raw]
tag_raw = re.split(r'\W+', row[0].strip())
pos_index = tag_idx.get(tag_raw[0])
pos_tag = []
pos_tag.append(pos_index)
for ii in range(len(tag_idx)):
tag = []
tag.append(ii)
yield text, tag, pos_tag
return reader
def train(filename, text_idx, tag_idx, neg_size, n, data_type=DataType.SEQ):
return train_reader_creator(filename, text_idx, tag_idx, neg_size, n, data_type)
def test(filename, text_idx, tag_idx, n, data_type=DataType.SEQ): def test(filename, text_idx, tag_idx, neg_size, n, data_type=DataType.SEQ):
return reader_creator(filename, text_idx, tag_idx, n, data_type) return test_reader_creator(filename, text_idx, tag_idx, n, data_type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册