From 82eb0fe45b293a884cd6e8be805a60366be88d6b Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 8 May 2017 22:26:46 +0800
Subject: [PATCH] "fix len type error of QueryList"

---
Notes (commentary only; not part of the commit message, ignored by `git am`):
  * Adds QueryList.__len__ so that len() can be called on a QueryList
    (the updated gen_list does this).
  * gen_list now yields one (relevance_score, feature_vector) item per
    entry instead of a single pair of arrays covering the whole list.
  * The "listwise" branch of __reader__ forwards every item produced by
    gen_list instead of only the first one.
  * NOTE(review): this patch was recovered from a copy whose line breaks
    and leading whitespace had been collapsed. Indentation below was
    reconstructed from the Python structure; verify it against the target
    file (or apply with --ignore-whitespace) before relying on it.

 python/paddle/v2/dataset/mq2007.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/python/paddle/v2/dataset/mq2007.py b/python/paddle/v2/dataset/mq2007.py
index 9dd33a22bac..2e8a2f9b089 100644
--- a/python/paddle/v2/dataset/mq2007.py
+++ b/python/paddle/v2/dataset/mq2007.py
@@ -116,6 +116,9 @@ class QueryList(object):
         for query in self.querylist:
             yield query
 
+    def __len__(self):
+        return len(self.querylist)
+
     def _correct_ranking_(self):
         if self.querylist is None:
             return
@@ -175,7 +178,7 @@ def gen_pair(querylist, partial_order="full"):
 
 def gen_list(querylist):
     """
-    gen pair for pair-wise learning to rank algorithm
+    gen item in list for list-wise learning to rank algorithm
     Paramters:
     --------
     querylist : querylist, one query match many docment pairs in list, see QueryList
@@ -190,7 +193,9 @@ def gen_list(querylist):
     querylist._correct_ranking_()
     relevance_score_list = [query.relevance_score for query in querylist]
     feature_vector_list = [query.feature_vector for query in querylist]
-    yield np.array(relevance_score_list).T, np.array(feature_vector_list)
+    # yield np.array(relevance_score_list).T, np.array(feature_vector_list)
+    for i in range(len(querylist)):
+        yield relevance_score_list[i], np.array(feature_vector_list[i])
 
 
 def load_from_text(filepath, shuffle=True, fill_missing=-1):
@@ -236,7 +241,9 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1):
             for pair in gen_pair(querylist):
                 yield pair
         elif format == "listwise":
-            yield next(gen_list(querylist))
+            # yield next(gen_list(querylist))
+            for instance in gen_list(querylist):
+                yield instance
 
 train = functools.partial(__reader__,filepath="MQ2007/MQ2007/Fold1/train.txt")
 test = functools.partial(__reader__, filepath="MQ2007/MQ2007/Fold1/test.txt")
-- 
GitLab