From 96ca1e966aaf75ab7644e594603e77f71ed14ca7 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 8 May 2017 16:30:27 +0800 Subject: [PATCH] "add mq2007 dataset for learning to rank task" --- python/paddle/v2/dataset/mq2007.py | 293 ++++++++++++++++++ python/paddle/v2/dataset/tests/mq2007_test.py | 31 ++ 2 files changed, 324 insertions(+) create mode 100644 python/paddle/v2/dataset/mq2007.py create mode 100644 python/paddle/v2/dataset/tests/mq2007_test.py diff --git a/python/paddle/v2/dataset/mq2007.py b/python/paddle/v2/dataset/mq2007.py new file mode 100644 index 000000000..8884dfd5b --- /dev/null +++ b/python/paddle/v2/dataset/mq2007.py @@ -0,0 +1,293 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +MQ2007 dataset + +MQ2007 is a query set from Million Query track of TREC 2007. There are about 1700 queries in it with labeled documents. In MQ2007, the 5-fold cross +validation strategy is adopted and the 5-fold partitions are included in the package. In each fold, there are three subsets for learning: training set, +validation set and testing set. + +MQ2007 dataset from +http://research.microsoft.com/en-us/um/beijing/projects/letor/LETOR4.0/Data/MQ2007.rar and parse training set and test set into paddle reader creators + +""" + + +import os +import random +import functools +import rarfile +from common import download +import numpy as np + + +# URL = "http://research.microsoft.com/en-us/um/beijing/projects/letor/LETOR4.0/Data/MQ2007.rar" +URL = "http://www.bigdatalab.ac.cn/benchmark/upload/download_source/7b6dbbe2-842c-11e4-a536-bcaec51b9163_MQ2007.rar" +MD5 = "7be1640ae95c6408dab0ae7207bdc706" + + +def __initialize_meta_info__(): + """ + download and extract the MQ2007 dataset + """ + fn = fetch() + rar = rarfile.RarFile(fn) + dirpath = os.path.dirname(fn) + rar.extractall(path=dirpath) + return dirpath + + +class Query(object): + """ + queries used for learning to rank algorithms. It is created from relevance scores, query-document feature vectors + + Parameters: + ---------- + query_id : int + query_id in dataset, mapping from query to relevance documents + relevance_score : int + relevance score of query and document pair + feature_vector : array, dense feature + feature in vector format + description : string + comment section in query doc pair data + """ + def __init__(self, query_id=-1, relevance_score=-1, + feature_vector=None, description=""): + self.query_id = query_id + self.relevance_score = relevance_score + if feature_vector is None: + self.feature_vector = [] + else: + self.feature_vector = feature_vector + self.description = description + + def __str__(self): + string = "%s %s %s" %(str(self.relevance_score), str(self.query_id), " ".join(str(f) for f in self.feature_vector)) + return string + + # @classmethod + def _parse_(self, text): + """ + parse line into Query + """ + comment_position = text.find('#') + line = text[:comment_position].strip() + self.description = text[comment_position+1:].strip() + parts = line.split() + assert(len(parts) == 48), "expect 48 space split parts, get %d" %(len(parts)) + # format : 0 qid:10 1:0.000272 2:0.000000 .... + self.relevance_score = int(parts[0]) + self.query_id = int(parts[1].split(':')[1]) + for p in parts[2:]: + pair = p.split(':') + self.feature_vector.append(float(pair[1])) + return self + +class QueryList(object): + """ + group query into list, every item in list is a Query + """ + def __init__(self, querylist=None): + self.query_id = -1 + if querylist is None: + self.querylist = [] + else: + self.querylist = querylist + for query in self.querylist: + if self.query_id == -1: + self.query_id = query.query_id + else: + if self.query_id != query.query_id: + raise ValueError("query in list must be same query_id") + + def __iter__(self): + for query in self.querylist: + yield query + + def _correct_ranking_(self): + if self.querylist is None: + return + self.querylist.sort(key=lambda x:x.relevance_score, reverse=True) + + def _add_query(self, query): + if self.query_id == -1: + self.query_id = query.query_id + else: + if self.query_id != query.query_id: + raise ValueError("query in list must be same query_id") + self.querylist.append(query) + + + +def gen_pair(querylist, partial_order="full"): + """ + gen pair for pair-wise learning to rank algorithm + Paramters: + -------- + querylist : querylist, one query match many docment pairs in list, see QueryList + pairtial_order : "full" or "neighbour" + gen pairs for neighbour items or the full partial order pairs + + return : + ------ + label : np.array, shape=(1) + query_left : np.array, shape=(1, feature_dimension) + query_right : same as left + """ + if not isinstance(querylist, QueryList): + querylist = QueryList(querylist) + querylist._correct_ranking_() + # C(n,2) + if partial_order == "full": + for i, query_left in enumerate(querylist): + for j, query_right in enumerate(querylist): + if query_left.relevance_score > query_right.relevance_score: + yield np.ones(1), np.array(query_left.feature_vector), np.array(query_right.feature_vector) + else: + yield np.ones(1), np.array(query_left.feature_vector), np.array(query_right.feature_vector) + + elif partial_order == "neighbour": + # C(n) + k = 0 + while k < len(querylist)-1: + query_left = querylist[k] + query_right = querylist[k+1] + if query_left.relevance_score > query_right.relevance_score: + yield np.ones(1), np.array(query_left.feature_vector), np.array(query_right.feature_vector) + else: + yield np.ones(1), np.array(query_left.feature_vector), np.array(query_right.feature_vector) + k += 1 + else: + raise ValueError("unsupport parameter of partial_order, Only can be neighbour or full") + + +def gen_list(querylist): + """ + gen pair for pair-wise learning to rank algorithm + Paramters: + -------- + querylist : querylist, one query match many docment pairs in list, see QueryList + + return : + ------ + label : np.array, shape=(samples_num, ) + querylist : np.array, shape=(samples_num, feature_dimension) + """ + if not isinstance(querylist, QueryList): + querylist = QueryList(querylist) + querylist._correct_ranking_() + relevance_score_list = [query.relevance_score for query in querylist] + feature_vector_list = [query.feature_vector for query in querylist] + yield np.array(relevance_score_list).T, np.array(feature_vector_list) + + +def load_from_text(filepath, shuffle=True, fill_missing=-1): + """ + parse data file into querys + """ + prev_query_id = -1; + querylists = [] + querylist = None + fn = __initialize_meta_info__() + with open(os.path.join(fn, filepath)) as f: + for line in f: + query = Query() + query = query._parse_(line) + if query.query_id != prev_query_id: + if querylist is not None: + querylists.append(querylist) + querylist = QueryList() + prev_query_id = query.query_id + querylist._add_query(query) + if shuffle == True: + random.shuffle(querylists) + return querylists + + +def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1): + """ + Parameters + -------- + filename : string + shuffle : shuffle query-doc pair under the same query + fill_missing : fill the missing value. default in MQ2007 is -1 + + Returns + ------ + yield + label query_left, query_right # format = "pairwise" + label querylist # format = "listwise" + """ + querylists = load_from_text(filepath, shuffle=shuffle, fill_missing=fill_missing) + for querylist in querylists: + if format == "pairwise": + for pair in gen_pair(querylist): + yield pair + elif format == "listwise": + yield next(gen_list(querylist)) + +train = functools.partial(__reader__,filepath="MQ2007/MQ2007/Fold1/train.txt") +test = functools.partial(__reader__, filepath="MQ2007/MQ2007/Fold1/test.txt") +# def __parse_line__(line_stream): +# """ +# return : score, qid, 46-dim feature vector +# parse line of file +# """ +# score = -1, qid = -1, features = [] +# line = line_stream[:line_stream.find('#')].strip() +# parts = line.split() +# assert(len(parts) == 48), "expect 48 space split parts, get ", len(parts) +# # format : 0 qid:10 1:0.000272 2:0.000000 .... +# score = int(parts[0]) +# qid = int(parts[1].split(':')[1]) +# for p in parts[2:]: +# pair = p.split(':') +# features.append(float(part[1])) +# return score, qid, features + + +# def __reader__(filename, rand_seed=0, is_test=False, test_rate=0.0): +# """ +# create a line reader Generator + +# Parameters +# -------- +# filename : string +# rand_seed : sample instance from dataset, set the sample random seed +# is_test : sample test set or generate train set +# test_rate : sample test set rate + +# Returns +# ------ +# yield +# int int lists +# score query_id, features +# """ +# rand = random.Random(x=rand_seed) +# with open(file_name, 'r') as f: +# for line in f: +# if (rand.random() < test_rate) == is_test: +# yield __parse_line__(line) + + +# def __pair_reader__(filename, shuffle=True): + + +def fetch(): + return download(URL, "MQ2007", MD5) + +if __name__ == "__main__": + fetch() + diff --git a/python/paddle/v2/dataset/tests/mq2007_test.py b/python/paddle/v2/dataset/tests/mq2007_test.py new file mode 100644 index 000000000..c9bddddeb --- /dev/null +++ b/python/paddle/v2/dataset/tests/mq2007_test.py @@ -0,0 +1,31 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2.dataset.mq2007 +import unittest + + +class TestMQ2007(unittest.TestCase): + def test_pairwise(self): + for label, query_left, query_right in paddle.v2.dataset.mq2007.test(format="pairwise"): + self.assertEqual(query_left.shape(), (46, )) + self.assertEqual(query_right.shape(), (46, )) + + def test_listwise(self): + for label_array, query_array in paddle.v2.dataset.mq2007.test(format="listwise"): + self.assertEqual(len(label_array), len(query_array)) + + +if __name__ == "__main__": + unittest.main() -- GitLab