提交 a4313de8 编写于 作者: D dzhwinter

"remove the pairwise other genereate method"

上级 4ac5caaa
......@@ -89,8 +89,10 @@ class Query(object):
line = text[:comment_position].strip()
self.description = text[comment_position + 1:].strip()
parts = line.split()
assert (len(parts) == 48), "expect 48 space split parts, get %d" % (
len(parts))
if len(parts) != 48:
sys.stdout.write("expect 48 space split parts, get %d" %
(len(parts)))
return None
# format : 0 qid:10 1:0.000272 2:0.000000 ....
self.relevance_score = int(parts[0])
self.query_id = int(parts[1].split(':')[1])
......@@ -125,6 +127,9 @@ class QueryList(object):
def __len__(self):
return len(self.querylist)
def __getitem__(self, i):
return self.querylist[i]
def _correct_ranking_(self):
if self.querylist is None:
return
......@@ -139,6 +144,46 @@ class QueryList(object):
self.querylist.append(query)
def gen_plain_txt(querylist):
"""
gen plain text in list for other usage
Paramters:
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
return :
------
query_id : np.array, shape=(samples_num, )
label : np.array, shape=(samples_num, )
querylist : np.array, shape=(samples_num, feature_dimension)
"""
if not isinstance(querylist, QueryList):
querylist = QueryList(querylist)
querylist._correct_ranking_()
for query in querylist:
yield querylist.query_id, query.relevance_score, np.array(
query.feature_vector)
def gen_point(querylist):
"""
gen item in list for point-wise learning to rank algorithm
Paramters:
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
return :
------
label : np.array, shape=(samples_num, )
querylist : np.array, shape=(samples_num, feature_dimension)
"""
if not isinstance(querylist, QueryList):
querylist = QueryList(querylist)
querylist._correct_ranking_()
for query in querylist:
yield query.relevance_score, np.array(query.feature_vector)
def gen_pair(querylist, partial_order="full"):
"""
gen pair for pair-wise learning to rank algorithm
......@@ -146,6 +191,7 @@ def gen_pair(querylist, partial_order="full"):
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
pairtial_order : "full" or "neighbour"
there is redudant in all possiable pair combinations, which can be simplifed
gen pairs for neighbour items or the full partial order pairs
return :
......@@ -157,34 +203,28 @@ def gen_pair(querylist, partial_order="full"):
if not isinstance(querylist, QueryList):
querylist = QueryList(querylist)
querylist._correct_ranking_()
labels = []
docpairs = []
# C(n,2)
if partial_order == "full":
for i, query_left in enumerate(querylist):
for j, query_right in enumerate(querylist):
if query_left.relevance_score > query_right.relevance_score:
yield 1, np.array(query_left.feature_vector), np.array(
query_right.feature_vector)
else:
yield 1, np.array(query_left.feature_vector), np.array(
query_right.feature_vector)
elif partial_order == "neighbour":
# C(n)
k = 0
while k < len(querylist) - 1:
query_left = querylist[k]
query_right = querylist[k + 1]
for i in range(len(querylist)):
query_left = querylist[i]
for j in range(i + 1, len(querylist)):
query_right = querylist[j]
if query_left.relevance_score > query_right.relevance_score:
yield 1, np.array(query_left.feature_vector), np.array(
query_right.feature_vector)
else:
yield 1, np.array(query_left.feature_vector), np.array(
query_right.feature_vector)
k += 1
else:
raise ValueError(
"unsupport parameter of partial_order, Only can be neighbour or full"
)
labels.append(1)
docpairs.append([
np.array(query_left.feature_vector),
np.array(query_right.feature_vector)
])
elif query_left.relevance_score < query_right.relevance_score:
labels.append(1)
docpairs.append([
np.array(query_right.feature_vector),
np.array(query_left.feature_vector)
])
for label, pair in zip(labels, docpairs):
yield label, pair[0], pair[1]
def gen_list(querylist):
......@@ -201,12 +241,30 @@ def gen_list(querylist):
"""
if not isinstance(querylist, QueryList):
querylist = QueryList(querylist)
# querylist._correct_ranking_()
querylist._correct_ranking_()
relevance_score_list = [query.relevance_score for query in querylist]
feature_vector_list = [query.feature_vector for query in querylist]
yield np.array(relevance_score_list).T, np.array(feature_vector_list)
def query_filter(querylists):
"""
filter query get only document with label 0.
label 0, 1, 2 means the relevance score document with query
parameters :
querylist : QueyList list
return :
querylist : QueyList list
"""
filter_query = []
for querylist in querylists:
relevance_score_list = [query.relevance_score for query in querylist]
if sum(relevance_score_list) != .0:
filter_query.append(querylist)
return filter_query
def load_from_text(filepath, shuffle=True, fill_missing=-1):
"""
parse data file into querys
......@@ -219,12 +277,16 @@ def load_from_text(filepath, shuffle=True, fill_missing=-1):
for line in f:
query = Query()
query = query._parse_(line)
if query == None:
continue
if query.query_id != prev_query_id:
if querylist is not None:
querylists.append(querylist)
querylist = QueryList()
prev_query_id = query.query_id
querylist._add_query(query)
if querylist is not None:
querylists.append(querylist)
if shuffle == True:
random.shuffle(querylists)
return querylists
......@@ -244,10 +306,15 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1):
label query_left, query_right # format = "pairwise"
label querylist # format = "listwise"
"""
querylists = load_from_text(
filepath, shuffle=shuffle, fill_missing=fill_missing)
querylists = query_filter(
load_from_text(
filepath, shuffle=shuffle, fill_missing=fill_missing))
for querylist in querylists:
if format == "pairwise":
if format == "plain_txt":
yield next(gen_plain_txt(querylist))
elif format == "pointwise":
yield next(gen_point(querylist))
elif format == "pairwise":
for pair in gen_pair(querylist):
yield pair
elif format == "listwise":
......@@ -264,7 +331,7 @@ def fetch():
if __name__ == "__main__":
fetch()
for i, (score,
samples) in enumerate(train(
format="listwise", shuffle=False)):
np.savetxt("query_%d" % (i), score, fmt="%.2f")
mytest = functools.partial(
__reader__, filepath="MQ2007/MQ2007/Fold1/sample", format="listwise")
for label, query in mytest():
print label, query
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册