提交 a4313de8 编写于 作者: D dzhwinter

"remove the pairwise other generate method"

上级 4ac5caaa
...@@ -89,8 +89,10 @@ class Query(object): ...@@ -89,8 +89,10 @@ class Query(object):
line = text[:comment_position].strip() line = text[:comment_position].strip()
self.description = text[comment_position + 1:].strip() self.description = text[comment_position + 1:].strip()
parts = line.split() parts = line.split()
assert (len(parts) == 48), "expect 48 space split parts, get %d" % ( if len(parts) != 48:
len(parts)) sys.stdout.write("expect 48 space split parts, get %d" %
(len(parts)))
return None
# format : 0 qid:10 1:0.000272 2:0.000000 .... # format : 0 qid:10 1:0.000272 2:0.000000 ....
self.relevance_score = int(parts[0]) self.relevance_score = int(parts[0])
self.query_id = int(parts[1].split(':')[1]) self.query_id = int(parts[1].split(':')[1])
...@@ -125,6 +127,9 @@ class QueryList(object): ...@@ -125,6 +127,9 @@ class QueryList(object):
def __len__(self): def __len__(self):
return len(self.querylist) return len(self.querylist)
def __getitem__(self, i):
return self.querylist[i]
def _correct_ranking_(self): def _correct_ranking_(self):
if self.querylist is None: if self.querylist is None:
return return
...@@ -139,6 +144,46 @@ class QueryList(object): ...@@ -139,6 +144,46 @@ class QueryList(object):
self.querylist.append(query) self.querylist.append(query)
def gen_plain_txt(querylist):
    """
    Yield one plain (query_id, label, features) record per query.

    Parameters:
    --------
    querylist : a QueryList; any other value is first passed to the
                QueryList constructor.

    Yields:
    ------
    query_id : the ``query_id`` attribute of the QueryList itself
    label : the ``relevance_score`` of the current query
    features : np.array built from the current query's ``feature_vector``
    """
    if isinstance(querylist, QueryList):
        queries = querylist
    else:
        queries = QueryList(querylist)
    queries._correct_ranking_()
    for item in queries:
        features = np.array(item.feature_vector)
        yield queries.query_id, item.relevance_score, features
def gen_point(querylist):
    """
    Yield one (label, features) sample per query, for point-wise
    learning-to-rank algorithms.

    Parameters:
    --------
    querylist : a QueryList; any other value is first passed to the
                QueryList constructor.

    Yields:
    ------
    label : the ``relevance_score`` of the current query
    features : np.array built from the current query's ``feature_vector``
    """
    if isinstance(querylist, QueryList):
        queries = querylist
    else:
        queries = QueryList(querylist)
    queries._correct_ranking_()
    for item in queries:
        features = np.array(item.feature_vector)
        yield item.relevance_score, features
def gen_pair(querylist, partial_order="full"): def gen_pair(querylist, partial_order="full"):
""" """
gen pair for pair-wise learning to rank algorithm gen pair for pair-wise learning to rank algorithm
...@@ -146,6 +191,7 @@ def gen_pair(querylist, partial_order="full"): ...@@ -146,6 +191,7 @@ def gen_pair(querylist, partial_order="full"):
-------- --------
querylist : querylist, one query match many docment pairs in list, see QueryList querylist : querylist, one query match many docment pairs in list, see QueryList
pairtial_order : "full" or "neighbour" pairtial_order : "full" or "neighbour"
there is redundancy in all possible pair combinations, which can be simplified
gen pairs for neighbour items or the full partial order pairs gen pairs for neighbour items or the full partial order pairs
return : return :
...@@ -157,34 +203,28 @@ def gen_pair(querylist, partial_order="full"): ...@@ -157,34 +203,28 @@ def gen_pair(querylist, partial_order="full"):
if not isinstance(querylist, QueryList): if not isinstance(querylist, QueryList):
querylist = QueryList(querylist) querylist = QueryList(querylist)
querylist._correct_ranking_() querylist._correct_ranking_()
labels = []
docpairs = []
# C(n,2) # C(n,2)
if partial_order == "full": for i in range(len(querylist)):
for i, query_left in enumerate(querylist): query_left = querylist[i]
for j, query_right in enumerate(querylist): for j in range(i + 1, len(querylist)):
if query_left.relevance_score > query_right.relevance_score: query_right = querylist[j]
yield 1, np.array(query_left.feature_vector), np.array(
query_right.feature_vector)
else:
yield 1, np.array(query_left.feature_vector), np.array(
query_right.feature_vector)
elif partial_order == "neighbour":
# C(n)
k = 0
while k < len(querylist) - 1:
query_left = querylist[k]
query_right = querylist[k + 1]
if query_left.relevance_score > query_right.relevance_score: if query_left.relevance_score > query_right.relevance_score:
yield 1, np.array(query_left.feature_vector), np.array( labels.append(1)
query_right.feature_vector) docpairs.append([
else: np.array(query_left.feature_vector),
yield 1, np.array(query_left.feature_vector), np.array( np.array(query_right.feature_vector)
query_right.feature_vector) ])
k += 1 elif query_left.relevance_score < query_right.relevance_score:
else: labels.append(1)
raise ValueError( docpairs.append([
"unsupport parameter of partial_order, Only can be neighbour or full" np.array(query_right.feature_vector),
) np.array(query_left.feature_vector)
])
for label, pair in zip(labels, docpairs):
yield label, pair[0], pair[1]
def gen_list(querylist): def gen_list(querylist):
...@@ -201,12 +241,30 @@ def gen_list(querylist): ...@@ -201,12 +241,30 @@ def gen_list(querylist):
""" """
if not isinstance(querylist, QueryList): if not isinstance(querylist, QueryList):
querylist = QueryList(querylist) querylist = QueryList(querylist)
# querylist._correct_ranking_() querylist._correct_ranking_()
relevance_score_list = [query.relevance_score for query in querylist] relevance_score_list = [query.relevance_score for query in querylist]
feature_vector_list = [query.feature_vector for query in querylist] feature_vector_list = [query.feature_vector for query in querylist]
yield np.array(relevance_score_list).T, np.array(feature_vector_list) yield np.array(relevance_score_list).T, np.array(feature_vector_list)
def query_filter(querylists):
    """
    Drop every query list whose relevance scores sum to zero.

    Each element of ``querylists`` is iterated and the ``relevance_score``
    of its items is summed; the element is kept only when that sum is
    non-zero. The relative order of the kept elements is unchanged.

    parameters :
        querylists : list of QueryList
    return :
        filter_query : list of QueryList, a new list holding the kept
                       elements
    """
    return [
        candidate for candidate in querylists
        if sum([doc.relevance_score for doc in candidate]) != .0
    ]
def load_from_text(filepath, shuffle=True, fill_missing=-1): def load_from_text(filepath, shuffle=True, fill_missing=-1):
""" """
parse data file into querys parse data file into querys
...@@ -219,12 +277,16 @@ def load_from_text(filepath, shuffle=True, fill_missing=-1): ...@@ -219,12 +277,16 @@ def load_from_text(filepath, shuffle=True, fill_missing=-1):
for line in f: for line in f:
query = Query() query = Query()
query = query._parse_(line) query = query._parse_(line)
if query == None:
continue
if query.query_id != prev_query_id: if query.query_id != prev_query_id:
if querylist is not None: if querylist is not None:
querylists.append(querylist) querylists.append(querylist)
querylist = QueryList() querylist = QueryList()
prev_query_id = query.query_id prev_query_id = query.query_id
querylist._add_query(query) querylist._add_query(query)
if querylist is not None:
querylists.append(querylist)
if shuffle == True: if shuffle == True:
random.shuffle(querylists) random.shuffle(querylists)
return querylists return querylists
...@@ -244,10 +306,15 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1): ...@@ -244,10 +306,15 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1):
label query_left, query_right # format = "pairwise" label query_left, query_right # format = "pairwise"
label querylist # format = "listwise" label querylist # format = "listwise"
""" """
querylists = load_from_text( querylists = query_filter(
filepath, shuffle=shuffle, fill_missing=fill_missing) load_from_text(
filepath, shuffle=shuffle, fill_missing=fill_missing))
for querylist in querylists: for querylist in querylists:
if format == "pairwise": if format == "plain_txt":
yield next(gen_plain_txt(querylist))
elif format == "pointwise":
yield next(gen_point(querylist))
elif format == "pairwise":
for pair in gen_pair(querylist): for pair in gen_pair(querylist):
yield pair yield pair
elif format == "listwise": elif format == "listwise":
...@@ -264,7 +331,7 @@ def fetch(): ...@@ -264,7 +331,7 @@ def fetch():
if __name__ == "__main__": if __name__ == "__main__":
fetch() fetch()
for i, (score, mytest = functools.partial(
samples) in enumerate(train( __reader__, filepath="MQ2007/MQ2007/Fold1/sample", format="listwise")
format="listwise", shuffle=False)): for label, query in mytest():
np.savetxt("query_%d" % (i), score, fmt="%.2f") print label, query
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册