From 33e68ab406396706d61b52a361c92ebdf3264e21 Mon Sep 17 00:00:00 2001
From: frankwhzhang
Date: Mon, 5 Nov 2018 09:43:16 +0800
Subject: [PATCH] add multi-negative training and inference (#1422)

* fix readme2.0

* add tagspace infer
---
 fluid/PaddleRec/gru4rec/README.md             |  2 +-
 .../{TagSpace => tagspace}/README.md          | 11 ++-
 fluid/PaddleRec/tagspace/infer.py             | 77 +++++++++++++++++++
 .../{TagSpace => tagspace}/small_test.txt     |  0
 .../{TagSpace => tagspace}/small_train.txt    |  0
 .../PaddleRec/{TagSpace => tagspace}/train.py | 28 +++----
 .../PaddleRec/{TagSpace => tagspace}/utils.py | 35 +++++++--
 7 files changed, 125 insertions(+), 28 deletions(-)
 rename fluid/PaddleRec/{TagSpace => tagspace}/README.md (88%)
 create mode 100644 fluid/PaddleRec/tagspace/infer.py
 rename fluid/PaddleRec/{TagSpace => tagspace}/small_test.txt (100%)
 rename fluid/PaddleRec/{TagSpace => tagspace}/small_train.txt (100%)
 rename fluid/PaddleRec/{TagSpace => tagspace}/train.py (85%)
 rename fluid/PaddleRec/{TagSpace => tagspace}/utils.py (79%)

diff --git a/fluid/PaddleRec/gru4rec/README.md b/fluid/PaddleRec/gru4rec/README.md
index e41d13c1..6b3c9c66 100644
--- a/fluid/PaddleRec/gru4rec/README.md
+++ b/fluid/PaddleRec/gru4rec/README.md
@@ -21,7 +21,7 @@ The GRU4REC model is described in the paper [Session-based Recommendations with Recu
 The paper's contribution is the first application of an RNN (GRU) to session-based recommendation; compared with traditional KNN and matrix factorization, it brings a clear improvement in accuracy.
 
-The core idea history of the paper: within a session, the sequence of items a user clicks is treated as a sequence for training an RNN model. At prediction time, given a known click sequence as input, the model predicts the next item the user is likely to click.
+The core idea of the paper is: within a session, the sequence of items a user clicks is treated as a sequence for training an RNN model. At prediction time, given a known click sequence as input, the model predicts the next item the user is likely to click.
 
 Session-based recommendation has very broad application scenarios, such as sequential data from users' product browsing, news clicks, and location check-ins.
 
diff --git a/fluid/PaddleRec/TagSpace/README.md b/fluid/PaddleRec/tagspace/README.md
similarity index 88%
rename from fluid/PaddleRec/TagSpace/README.md
rename to fluid/PaddleRec/tagspace/README.md
index 922e59ae..29509d3f 100644
--- a/fluid/PaddleRec/TagSpace/README.md
+++ b/fluid/PaddleRec/tagspace/README.md
@@ -6,6 +6,7 @@
 .
 ├── README.md       # documentation
 ├── train.py        # training script
+├── infer.py        # inference script
 ├── utils.py        # utility functions
 ├── small_train.txt # small sample training set
 └── small_test.txt  # small sample test set
@@ -26,7 +27,6 @@ The TagSpace model is described in the paper [#TagSpace: Semantic Embeddings from Ha
 "3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
 ```
-
 ## Training
 '--use_cuda 1' means train on GPU; by default the CPU is used
 
@@ -41,10 +41,9 @@ CPU environment
 python train.py small_train.txt small_test.txt
 ```
 
-## Future work
-
-Add an inference step
-
-Add more negative-sampling strategies
+## Inference
+The arguments are: model_dir start_epoch last_epoch(inclusive) train_file test_file (inference runs on the CPU)
+```
+python infer.py model 1 10 small_train.txt small_test.txt
+```
 
diff --git a/fluid/PaddleRec/tagspace/infer.py b/fluid/PaddleRec/tagspace/infer.py
new file mode 100644
index 00000000..0ebc5c8d
--- /dev/null
+++ b/fluid/PaddleRec/tagspace/infer.py
@@ -0,0 +1,77 @@
+import sys
+import time
+
+import six
+import paddle.fluid as fluid
+
+import utils
+
+def infer(test_reader, vocab_tag, use_cuda, model_path):
+    """ inference function """
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    with fluid.scope_guard(fluid.core.Scope()):
+        infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
+            model_path, exe)
+        t0 = time.time()
+        step_id = 0
+        true_num = 0
+        all_num = 0
+        size = len(vocab_tag)
+        value = []
+        # The test reader yields one (text, candidate_tag, true_tag) triple
+        # per tag id, so every `size` consecutive steps score the whole tag
+        # vocabulary for a single example.
+        for data in test_reader():
+            step_id += 1
+            lod_text_seq = utils.to_lodtensor([dat[0] for dat in data], place)
+            lod_tag = utils.to_lodtensor([dat[1] for dat in data], place)
+            para = exe.run(
+                infer_program,
+                feed={
+                    "text": lod_text_seq,
+                    "pos_tag": lod_tag},
+                fetch_list=fetch_vars,
+                return_numpy=False)
+            value.append(para[0]._get_float_element(0))
+            if step_id % size == 0 and step_id > 1:
+                all_num += 1
+                true_pos = [dat[2] for dat in data][0][0]
+                # correct if the best-scoring candidate tag is the true tag
+                if value.index(max(value)) == int(true_pos):
+                    true_num += 1
+                value = []
+            if step_id % 1000 == 0 and all_num > 0:
+                print(step_id, 1.0 * true_num / all_num)
+        t1 = time.time()
+        print("model: %s, time cost: %.2fs" % (model_path, t1 - t0))
+
+if __name__ == "__main__":
+    if len(sys.argv) != 6:
+        print(
+            "Usage: %s model_dir start_epoch last_epoch(inclusive) train_file test_file"
+            % sys.argv[0])
+        exit(0)
+    model_dir = sys.argv[1]
+    try:
+        start_index = int(sys.argv[2])
+        last_index = int(sys.argv[3])
+        train_file = sys.argv[4]
+        test_file = sys.argv[5]
+    except ValueError:
+        print(
+            "Usage: %s model_dir start_epoch last_epoch(inclusive) train_file test_file"
+            % sys.argv[0])
+        exit(-1)
+    vocab_text, vocab_tag, train_reader, test_reader = utils.prepare_data(
+        train_file,
+        test_file,
+        neg_size=1,  # unused by the test reader, but prepare_data expects it
+        batch_size=1,
+        buffer_size=1000,
+        word_freq_threshold=0)
+
+    for epoch in six.moves.xrange(start_index, last_index + 1):
+        epoch_path = model_dir + "/epoch_" + str(epoch)
+        infer(test_reader=test_reader, vocab_tag=vocab_tag,
+              use_cuda=False, model_path=epoch_path)
diff --git a/fluid/PaddleRec/TagSpace/small_test.txt b/fluid/PaddleRec/tagspace/small_test.txt
similarity index 100%
rename from fluid/PaddleRec/TagSpace/small_test.txt
rename to fluid/PaddleRec/tagspace/small_test.txt
diff --git a/fluid/PaddleRec/TagSpace/small_train.txt b/fluid/PaddleRec/tagspace/small_train.txt
similarity index 100%
rename from fluid/PaddleRec/TagSpace/small_train.txt
rename to fluid/PaddleRec/tagspace/small_train.txt
diff --git a/fluid/PaddleRec/TagSpace/train.py b/fluid/PaddleRec/tagspace/train.py
similarity index 85%
rename from fluid/PaddleRec/TagSpace/train.py
rename to fluid/PaddleRec/tagspace/train.py
index 3abf222c..a79990d9 100644
--- a/fluid/PaddleRec/TagSpace/train.py
+++ b/fluid/PaddleRec/tagspace/train.py
@@ -24,7 +24,7 @@ def parse_args():
     args = parser.parse_args()
     return args
 
-def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=5, margin=0.1):
+def network(vocab_text_size, vocab_tag_size, emb_dim=10, hid_dim=1000, win_size=5, margin=0.1, neg_size=5):
     """ network definition """
     text = io.data(name="text", shape=[1], lod_level=1, dtype='int64')
     pos_tag = io.data(name="pos_tag", shape=[1], lod_level=1, dtype='int64')
@@ -44,12 +44,14 @@
         act="tanh",
         pool_type="max",
         param_attr="cnn")
-
     text_hid = fluid.layers.fc(input=conv_1d, size=emb_dim, param_attr="text_hid")
-
     cos_pos = nn.cos_sim(pos_tag_emb, text_hid)
-    cos_neg = nn.cos_sim(neg_tag_emb, text_hid)
-
+    mul_text_hid = fluid.layers.sequence_expand_as(x=text_hid, y=neg_tag_emb)
+    mul_cos_neg = nn.cos_sim(neg_tag_emb, mul_text_hid)
+    cos_neg_all = fluid.layers.sequence_reshape(input=mul_cos_neg, new_dim=neg_size)
+    # choose the maximum negative cosine (the hardest negative)
+    cos_neg = nn.reduce_max(cos_neg_all, dim=1, keep_dim=True)
+    # calculate the hinge loss: max(0, margin - cos_pos + cos_neg)
     loss_part1 = nn.elementwise_sub(
         tensor.fill_constant_batch_size_like(
             input=cos_pos,
@@ -63,22 +65,20 @@
             input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
         loss_part2)
     avg_cost = nn.mean(loss_part3)
-
     less = tensor.cast(cf.less_than(cos_neg, cos_pos), dtype='float32')
     correct = nn.reduce_sum(less)
     return text, pos_tag, neg_tag, avg_cost, correct, cos_pos
 
-def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size,
+def train(train_reader, vocab_text, vocab_tag, base_lr, batch_size, neg_size,
           pass_num, use_cuda, model_dir):
     """ train network """
-    args = parse_args()
     vocab_text_size = len(vocab_text)
     vocab_tag_size = len(vocab_tag)
     place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
 
     # Train program
-    text, pos_tag, neg_tag, avg_cost, correct, pos_cos = network(vocab_text_size, vocab_tag_size)
+    text, pos_tag, neg_tag, avg_cost, correct, cos_pos = network(vocab_text_size, vocab_tag_size, neg_size=neg_size)
 
     # Optimization to minimize loss
     sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=base_lr)
@@ -117,8 +117,8 @@
                   (epoch_idx, batch_id, total_time / epoch_idx))
             save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
             feed_var_names = ["text", "pos_tag"]
-            fetch_vars = [pos_cos]
-            fluid.io.save_inference_model(save_dir ,feed_var_names, fetch_vars, exe)
+            fetch_vars = [cos_pos]
+            fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
     print("finish training")
 
 def train_net():
@@ -128,17 +128,19 @@
     test_file = args.test_file
     use_cuda = True if args.use_cuda else False
     batch_size = 100
+    neg_size = 3
     vocab_text, vocab_tag, train_reader, test_reader = utils.prepare_data(
-        train_file, test_file, batch_size=batch_size, buffer_size=batch_size*100, word_freq_threshold=0)
+        train_file, test_file, neg_size=neg_size, batch_size=batch_size, buffer_size=batch_size*100, word_freq_threshold=0)
     train(
         train_reader=train_reader,
         vocab_text=vocab_text,
         vocab_tag=vocab_tag,
         base_lr=0.01,
         batch_size=batch_size,
+        neg_size=neg_size,
         pass_num=10,
         use_cuda=use_cuda,
-        model_dir="model_dim10_2")
+        model_dir="model")
 
 
 if __name__ == "__main__":
diff --git a/fluid/PaddleRec/TagSpace/utils.py b/fluid/PaddleRec/tagspace/utils.py
similarity index 79%
rename from fluid/PaddleRec/TagSpace/utils.py
rename to fluid/PaddleRec/tagspace/utils.py
index e1644c3a..bf483db6 100644
--- a/fluid/PaddleRec/TagSpace/utils.py
+++ b/fluid/PaddleRec/tagspace/utils.py
@@ -38,12 +38,13 @@ def prepare_data(train_filename,
     train_reader = sort_batch(
         paddle.reader.shuffle(
             train(
-                train_filename, vocab_text, vocab_tag, buffer_size, data_type=DataType.SEQ),
+                train_filename, vocab_text, vocab_tag, neg_size,
+                buffer_size, data_type=DataType.SEQ),
             buf_size=buffer_size),
         batch_size, batch_size * 20)
     test_reader = sort_batch(
         test(
-            test_filename, vocab_text, vocab_tag, buffer_size, data_type=DataType.SEQ),
+            test_filename, vocab_text, vocab_tag, neg_size, buffer_size, data_type=DataType.SEQ),
         batch_size, batch_size * 20)
     return vocab_text, vocab_tag, train_reader, test_reader
 
@@ -123,7 +124,7 @@ def build_dict(column_num=2, min_word_freq=50, train_filename="", test_filename=
     word_idx = dict(list(zip(words, six.moves.range(len(words)))))
     return word_idx
 
-def reader_creator(filename, text_idx, tag_idx, n, data_type):
+def train_reader_creator(filename, text_idx, tag_idx, neg_size, n, data_type):
     def reader():
         with open(filename) as input_file:
             data_file = csv.reader(input_file)
@@ -138,7 +139,7 @@ def reader_creator(filename, text_idx, tag_idx, n, data_type):
                 max_iter = 100
                 now_iter = 0
                 sum_n = 0
-                while(sum_n < 1) :
+                while sum_n < neg_size:
                     now_iter += 1
                     if now_iter > max_iter:
                         print("error : only one class")
@@ -152,8 +153,26 @@ def reader_creator(filename, text_idx, tag_idx, n, data_type):
                 yield text, pos_tag, neg_tag
     return reader
 
-def train(filename, text_idx, tag_idx, n, data_type=DataType.SEQ):
-    return reader_creator(filename, text_idx, tag_idx, n, data_type)
+def test_reader_creator(filename, text_idx, tag_idx, n, data_type):
+    def reader():
+        with open(filename) as input_file:
+            data_file = csv.reader(input_file)
+            for row in data_file:
+                text_raw = re.split(r'\W+', row[2].strip())
+                text = [text_idx.get(w) for w in text_raw]
+                tag_raw = re.split(r'\W+', row[0].strip())
+                pos_index = tag_idx.get(tag_raw[0])
+                pos_tag = []
+                pos_tag.append(pos_index)
+                # yield one (text, candidate_tag, true_tag) triple per tag id,
+                # so inference can score the whole tag vocabulary per example
+                for ii in range(len(tag_idx)):
+                    tag = []
+                    tag.append(ii)
+                    yield text, tag, pos_tag
+    return reader
+
+
+def train(filename, text_idx, tag_idx, neg_size, n, data_type=DataType.SEQ):
+    return train_reader_creator(filename, text_idx, tag_idx, neg_size, n, data_type)
 
-def test(filename, text_idx, tag_idx, n, data_type=DataType.SEQ):
-    return reader_creator(filename, text_idx, tag_idx, n, data_type)
+def test(filename, text_idx, tag_idx, neg_size, n, data_type=DataType.SEQ):
+    return test_reader_creator(filename, text_idx, tag_idx, n, data_type)
--
GitLab
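
For readers skimming the diff: the new `network()` trains against the hardest of the `neg_size` sampled negative tags using a margin hinge loss. Below is a minimal NumPy sketch of the same arithmetic, illustrative only and not part of the patch; the function name, inputs, and toy numbers are invented for the example.

```python
import numpy as np

def max_neg_hinge_loss(cos_pos, cos_neg_all, margin=0.1):
    """Mirror of the loss in network(): max(0, margin - cos_pos + cos_neg).

    cos_pos:     (batch, 1) cosine of text vs. the positive tag
    cos_neg_all: (batch, neg_size) cosines of text vs. sampled negative tags
    """
    # hardest negative: the sampled tag most similar to the text
    # (what reduce_max over cos_neg_all computes in the patch)
    cos_neg = cos_neg_all.max(axis=1, keepdims=True)
    # hinge: a penalty is paid only when the positive does not beat
    # the hardest negative by at least the margin
    return np.maximum(0.0, margin - cos_pos + cos_neg).mean()

# toy batch of 2 texts with neg_size = 3
cos_pos = np.array([[0.9], [0.2]])
cos_neg_all = np.array([[0.10, 0.30, 0.05],
                        [0.40, 0.10, 0.60]])
print(max_neg_hinge_loss(cos_pos, cos_neg_all))  # 0.25: only the second example is penalized
```

Note that the `correct` metric in `train()` counts how often `cos_pos` exceeds this same hardest-negative cosine, so training accuracy is measured against the same negatives the loss uses.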