diff --git a/python/paddle/v2/dataset/mq2007.py b/python/paddle/v2/dataset/mq2007.py index 9c6f8927b769a1943060de9cb431afd3e2c5d363..fd71b341662ca6f540ce44a86348e782561a97d7 100644 --- a/python/paddle/v2/dataset/mq2007.py +++ b/python/paddle/v2/dataset/mq2007.py @@ -89,8 +89,10 @@ class Query(object): line = text[:comment_position].strip() self.description = text[comment_position + 1:].strip() parts = line.split() - assert (len(parts) == 48), "expect 48 space split parts, get %d" % ( - len(parts)) + if len(parts) != 48: + sys.stdout.write("expect 48 space split parts, get %d" % + (len(parts))) + return None # format : 0 qid:10 1:0.000272 2:0.000000 .... self.relevance_score = int(parts[0]) self.query_id = int(parts[1].split(':')[1]) @@ -125,6 +127,9 @@ class QueryList(object): def __len__(self): return len(self.querylist) + def __getitem__(self, i): + return self.querylist[i] + def _correct_ranking_(self): if self.querylist is None: return @@ -139,6 +144,46 @@ class QueryList(object): self.querylist.append(query) +def gen_plain_txt(querylist): + """ + gen plain text in list for other usage + Paramters: + -------- + querylist : querylist, one query match many docment pairs in list, see QueryList + + return : + ------ + query_id : np.array, shape=(samples_num, ) + label : np.array, shape=(samples_num, ) + querylist : np.array, shape=(samples_num, feature_dimension) + """ + if not isinstance(querylist, QueryList): + querylist = QueryList(querylist) + querylist._correct_ranking_() + for query in querylist: + yield querylist.query_id, query.relevance_score, np.array( + query.feature_vector) + + +def gen_point(querylist): + """ + gen item in list for point-wise learning to rank algorithm + Paramters: + -------- + querylist : querylist, one query match many docment pairs in list, see QueryList + + return : + ------ + label : np.array, shape=(samples_num, ) + querylist : np.array, shape=(samples_num, feature_dimension) + """ + if not isinstance(querylist, QueryList): + querylist = QueryList(querylist) + querylist._correct_ranking_() + for query in querylist: + yield query.relevance_score, np.array(query.feature_vector) + + def gen_pair(querylist, partial_order="full"): """ gen pair for pair-wise learning to rank algorithm @@ -146,6 +191,7 @@ def gen_pair(querylist, partial_order="full"): -------- querylist : querylist, one query match many docment pairs in list, see QueryList pairtial_order : "full" or "neighbour" + there is redudant in all possiable pair combinations, which can be simplifed gen pairs for neighbour items or the full partial order pairs return : @@ -157,34 +203,28 @@ def gen_pair(querylist, partial_order="full"): if not isinstance(querylist, QueryList): querylist = QueryList(querylist) querylist._correct_ranking_() + labels = [] + docpairs = [] + # C(n,2) - if partial_order == "full": - for i, query_left in enumerate(querylist): - for j, query_right in enumerate(querylist): - if query_left.relevance_score > query_right.relevance_score: - yield 1, np.array(query_left.feature_vector), np.array( - query_right.feature_vector) - else: - yield 1, np.array(query_left.feature_vector), np.array( - query_right.feature_vector) - - elif partial_order == "neighbour": - # C(n) - k = 0 - while k < len(querylist) - 1: - query_left = querylist[k] - query_right = querylist[k + 1] + for i in range(len(querylist)): + query_left = querylist[i] + for j in range(i + 1, len(querylist)): + query_right = querylist[j] if query_left.relevance_score > query_right.relevance_score: - yield 1, np.array(query_left.feature_vector), np.array( - query_right.feature_vector) - else: - yield 1, np.array(query_left.feature_vector), np.array( - query_right.feature_vector) - k += 1 - else: - raise ValueError( - "unsupport parameter of partial_order, Only can be neighbour or full" - ) + labels.append(1) + docpairs.append([ + np.array(query_left.feature_vector), + np.array(query_right.feature_vector) + ]) + elif query_left.relevance_score < query_right.relevance_score: + labels.append(1) + docpairs.append([ + np.array(query_right.feature_vector), + np.array(query_left.feature_vector) + ]) + for label, pair in zip(labels, docpairs): + yield label, pair[0], pair[1] def gen_list(querylist): @@ -201,12 +241,30 @@ def gen_list(querylist): """ if not isinstance(querylist, QueryList): querylist = QueryList(querylist) - # querylist._correct_ranking_() + querylist._correct_ranking_() relevance_score_list = [query.relevance_score for query in querylist] feature_vector_list = [query.feature_vector for query in querylist] yield np.array(relevance_score_list).T, np.array(feature_vector_list) +def query_filter(querylists): + """ + filter query get only document with label 0. + label 0, 1, 2 means the relevance score document with query + parameters : + querylist : QueyList list + + return : + querylist : QueyList list + """ + filter_query = [] + for querylist in querylists: + relevance_score_list = [query.relevance_score for query in querylist] + if sum(relevance_score_list) != .0: + filter_query.append(querylist) + return filter_query + + def load_from_text(filepath, shuffle=True, fill_missing=-1): """ parse data file into querys @@ -219,12 +277,16 @@ def load_from_text(filepath, shuffle=True, fill_missing=-1): for line in f: query = Query() query = query._parse_(line) + if query == None: + continue if query.query_id != prev_query_id: if querylist is not None: querylists.append(querylist) querylist = QueryList() prev_query_id = query.query_id querylist._add_query(query) + if querylist is not None: + querylists.append(querylist) if shuffle == True: random.shuffle(querylists) return querylists @@ -244,10 +306,15 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1): label query_left, query_right # format = "pairwise" label querylist # format = "listwise" """ - querylists = load_from_text( - filepath, shuffle=shuffle, fill_missing=fill_missing) + querylists = query_filter( + load_from_text( + filepath, shuffle=shuffle, fill_missing=fill_missing)) for querylist in querylists: - if format == "pairwise": + if format == "plain_txt": + yield next(gen_plain_txt(querylist)) + elif format == "pointwise": + yield next(gen_point(querylist)) + elif format == "pairwise": for pair in gen_pair(querylist): yield pair elif format == "listwise": @@ -264,7 +331,7 @@ def fetch(): if __name__ == "__main__": fetch() - for i, (score, - samples) in enumerate(train( - format="listwise", shuffle=False)): - np.savetxt("query_%d" % (i), score, fmt="%.2f") + mytest = functools.partial( + __reader__, filepath="MQ2007/MQ2007/Fold1/sample", format="listwise") + for label, query in mytest(): + print label, query