提交 82eb0fe4 编写于 作者: D dzhwinter

"fix len type error of QueryList"

上级 16d6bd7c
......@@ -116,6 +116,9 @@ class QueryList(object):
for query in self.querylist:
yield query
def __len__(self):
return len(self.querylist)
def _correct_ranking_(self):
if self.querylist is None:
return
......@@ -175,7 +178,7 @@ def gen_pair(querylist, partial_order="full"):
def gen_list(querylist):
"""
gen pair for pair-wise learning to rank algorithm
gen item in list for list-wise learning to rank algorithm
Paramters:
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
......@@ -190,7 +193,9 @@ def gen_list(querylist):
querylist._correct_ranking_()
relevance_score_list = [query.relevance_score for query in querylist]
feature_vector_list = [query.feature_vector for query in querylist]
yield np.array(relevance_score_list).T, np.array(feature_vector_list)
# yield np.array(relevance_score_list).T, np.array(feature_vector_list)
for i in range(len(querylist)):
yield relevance_score_list[i], np.array(feature_vector_list[i])
def load_from_text(filepath, shuffle=True, fill_missing=-1):
......@@ -236,7 +241,9 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1):
for pair in gen_pair(querylist):
yield pair
elif format == "listwise":
yield next(gen_list(querylist))
# yield next(gen_list(querylist))
for instance in gen_list(querylist):
yield instance
train = functools.partial(__reader__,filepath="MQ2007/MQ2007/Fold1/train.txt")
test = functools.partial(__reader__, filepath="MQ2007/MQ2007/Fold1/test.txt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册