Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a4313de8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a4313de8
编写于
5月 15, 2017
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"remove the pairwise other genereate method"
上级
4ac5caaa
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
103 addition
and
36 deletion
+103
-36
python/paddle/v2/dataset/mq2007.py
python/paddle/v2/dataset/mq2007.py
+103
-36
未找到文件。
python/paddle/v2/dataset/mq2007.py
浏览文件 @
a4313de8
...
@@ -89,8 +89,10 @@ class Query(object):
...
@@ -89,8 +89,10 @@ class Query(object):
line
=
text
[:
comment_position
].
strip
()
line
=
text
[:
comment_position
].
strip
()
self
.
description
=
text
[
comment_position
+
1
:].
strip
()
self
.
description
=
text
[
comment_position
+
1
:].
strip
()
parts
=
line
.
split
()
parts
=
line
.
split
()
assert
(
len
(
parts
)
==
48
),
"expect 48 space split parts, get %d"
%
(
if
len
(
parts
)
!=
48
:
len
(
parts
))
sys
.
stdout
.
write
(
"expect 48 space split parts, get %d"
%
(
len
(
parts
)))
return
None
# format : 0 qid:10 1:0.000272 2:0.000000 ....
# format : 0 qid:10 1:0.000272 2:0.000000 ....
self
.
relevance_score
=
int
(
parts
[
0
])
self
.
relevance_score
=
int
(
parts
[
0
])
self
.
query_id
=
int
(
parts
[
1
].
split
(
':'
)[
1
])
self
.
query_id
=
int
(
parts
[
1
].
split
(
':'
)[
1
])
...
@@ -125,6 +127,9 @@ class QueryList(object):
...
@@ -125,6 +127,9 @@ class QueryList(object):
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
querylist
)
return
len
(
self
.
querylist
)
def
__getitem__
(
self
,
i
):
return
self
.
querylist
[
i
]
def
_correct_ranking_
(
self
):
def
_correct_ranking_
(
self
):
if
self
.
querylist
is
None
:
if
self
.
querylist
is
None
:
return
return
...
@@ -139,6 +144,46 @@ class QueryList(object):
...
@@ -139,6 +144,46 @@ class QueryList(object):
self
.
querylist
.
append
(
query
)
self
.
querylist
.
append
(
query
)
def
gen_plain_txt
(
querylist
):
"""
gen plain text in list for other usage
Paramters:
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
return :
------
query_id : np.array, shape=(samples_num, )
label : np.array, shape=(samples_num, )
querylist : np.array, shape=(samples_num, feature_dimension)
"""
if
not
isinstance
(
querylist
,
QueryList
):
querylist
=
QueryList
(
querylist
)
querylist
.
_correct_ranking_
()
for
query
in
querylist
:
yield
querylist
.
query_id
,
query
.
relevance_score
,
np
.
array
(
query
.
feature_vector
)
def
gen_point
(
querylist
):
"""
gen item in list for point-wise learning to rank algorithm
Paramters:
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
return :
------
label : np.array, shape=(samples_num, )
querylist : np.array, shape=(samples_num, feature_dimension)
"""
if
not
isinstance
(
querylist
,
QueryList
):
querylist
=
QueryList
(
querylist
)
querylist
.
_correct_ranking_
()
for
query
in
querylist
:
yield
query
.
relevance_score
,
np
.
array
(
query
.
feature_vector
)
def
gen_pair
(
querylist
,
partial_order
=
"full"
):
def
gen_pair
(
querylist
,
partial_order
=
"full"
):
"""
"""
gen pair for pair-wise learning to rank algorithm
gen pair for pair-wise learning to rank algorithm
...
@@ -146,6 +191,7 @@ def gen_pair(querylist, partial_order="full"):
...
@@ -146,6 +191,7 @@ def gen_pair(querylist, partial_order="full"):
--------
--------
querylist : querylist, one query match many docment pairs in list, see QueryList
querylist : querylist, one query match many docment pairs in list, see QueryList
pairtial_order : "full" or "neighbour"
pairtial_order : "full" or "neighbour"
there is redudant in all possiable pair combinations, which can be simplifed
gen pairs for neighbour items or the full partial order pairs
gen pairs for neighbour items or the full partial order pairs
return :
return :
...
@@ -157,34 +203,28 @@ def gen_pair(querylist, partial_order="full"):
...
@@ -157,34 +203,28 @@ def gen_pair(querylist, partial_order="full"):
if
not
isinstance
(
querylist
,
QueryList
):
if
not
isinstance
(
querylist
,
QueryList
):
querylist
=
QueryList
(
querylist
)
querylist
=
QueryList
(
querylist
)
querylist
.
_correct_ranking_
()
querylist
.
_correct_ranking_
()
labels
=
[]
docpairs
=
[]
# C(n,2)
# C(n,2)
if
partial_order
==
"full"
:
for
i
in
range
(
len
(
querylist
)):
for
i
,
query_left
in
enumerate
(
querylist
):
query_left
=
querylist
[
i
]
for
j
,
query_right
in
enumerate
(
querylist
):
for
j
in
range
(
i
+
1
,
len
(
querylist
)):
if
query_left
.
relevance_score
>
query_right
.
relevance_score
:
query_right
=
querylist
[
j
]
yield
1
,
np
.
array
(
query_left
.
feature_vector
),
np
.
array
(
query_right
.
feature_vector
)
else
:
yield
1
,
np
.
array
(
query_left
.
feature_vector
),
np
.
array
(
query_right
.
feature_vector
)
elif
partial_order
==
"neighbour"
:
# C(n)
k
=
0
while
k
<
len
(
querylist
)
-
1
:
query_left
=
querylist
[
k
]
query_right
=
querylist
[
k
+
1
]
if
query_left
.
relevance_score
>
query_right
.
relevance_score
:
if
query_left
.
relevance_score
>
query_right
.
relevance_score
:
yield
1
,
np
.
array
(
query_left
.
feature_vector
),
np
.
array
(
labels
.
append
(
1
)
query_right
.
feature_vector
)
docpairs
.
append
([
else
:
np
.
array
(
query_left
.
feature_vector
),
yield
1
,
np
.
array
(
query_left
.
feature_vector
),
np
.
array
(
np
.
array
(
query_right
.
feature_vector
)
query_right
.
feature_vector
)
])
k
+=
1
elif
query_left
.
relevance_score
<
query_right
.
relevance_score
:
else
:
labels
.
append
(
1
)
raise
ValueError
(
docpairs
.
append
([
"unsupport parameter of partial_order, Only can be neighbour or full"
np
.
array
(
query_right
.
feature_vector
),
)
np
.
array
(
query_left
.
feature_vector
)
])
for
label
,
pair
in
zip
(
labels
,
docpairs
):
yield
label
,
pair
[
0
],
pair
[
1
]
def
gen_list
(
querylist
):
def
gen_list
(
querylist
):
...
@@ -201,12 +241,30 @@ def gen_list(querylist):
...
@@ -201,12 +241,30 @@ def gen_list(querylist):
"""
"""
if
not
isinstance
(
querylist
,
QueryList
):
if
not
isinstance
(
querylist
,
QueryList
):
querylist
=
QueryList
(
querylist
)
querylist
=
QueryList
(
querylist
)
#
querylist._correct_ranking_()
querylist
.
_correct_ranking_
()
relevance_score_list
=
[
query
.
relevance_score
for
query
in
querylist
]
relevance_score_list
=
[
query
.
relevance_score
for
query
in
querylist
]
feature_vector_list
=
[
query
.
feature_vector
for
query
in
querylist
]
feature_vector_list
=
[
query
.
feature_vector
for
query
in
querylist
]
yield
np
.
array
(
relevance_score_list
).
T
,
np
.
array
(
feature_vector_list
)
yield
np
.
array
(
relevance_score_list
).
T
,
np
.
array
(
feature_vector_list
)
def
query_filter
(
querylists
):
"""
filter query get only document with label 0.
label 0, 1, 2 means the relevance score document with query
parameters :
querylist : QueyList list
return :
querylist : QueyList list
"""
filter_query
=
[]
for
querylist
in
querylists
:
relevance_score_list
=
[
query
.
relevance_score
for
query
in
querylist
]
if
sum
(
relevance_score_list
)
!=
.
0
:
filter_query
.
append
(
querylist
)
return
filter_query
def
load_from_text
(
filepath
,
shuffle
=
True
,
fill_missing
=-
1
):
def
load_from_text
(
filepath
,
shuffle
=
True
,
fill_missing
=-
1
):
"""
"""
parse data file into querys
parse data file into querys
...
@@ -219,12 +277,16 @@ def load_from_text(filepath, shuffle=True, fill_missing=-1):
...
@@ -219,12 +277,16 @@ def load_from_text(filepath, shuffle=True, fill_missing=-1):
for
line
in
f
:
for
line
in
f
:
query
=
Query
()
query
=
Query
()
query
=
query
.
_parse_
(
line
)
query
=
query
.
_parse_
(
line
)
if
query
==
None
:
continue
if
query
.
query_id
!=
prev_query_id
:
if
query
.
query_id
!=
prev_query_id
:
if
querylist
is
not
None
:
if
querylist
is
not
None
:
querylists
.
append
(
querylist
)
querylists
.
append
(
querylist
)
querylist
=
QueryList
()
querylist
=
QueryList
()
prev_query_id
=
query
.
query_id
prev_query_id
=
query
.
query_id
querylist
.
_add_query
(
query
)
querylist
.
_add_query
(
query
)
if
querylist
is
not
None
:
querylists
.
append
(
querylist
)
if
shuffle
==
True
:
if
shuffle
==
True
:
random
.
shuffle
(
querylists
)
random
.
shuffle
(
querylists
)
return
querylists
return
querylists
...
@@ -244,10 +306,15 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1):
...
@@ -244,10 +306,15 @@ def __reader__(filepath, format="pairwise", shuffle=True, fill_missing=-1):
label query_left, query_right # format = "pairwise"
label query_left, query_right # format = "pairwise"
label querylist # format = "listwise"
label querylist # format = "listwise"
"""
"""
querylists
=
load_from_text
(
querylists
=
query_filter
(
filepath
,
shuffle
=
shuffle
,
fill_missing
=
fill_missing
)
load_from_text
(
filepath
,
shuffle
=
shuffle
,
fill_missing
=
fill_missing
))
for
querylist
in
querylists
:
for
querylist
in
querylists
:
if
format
==
"pairwise"
:
if
format
==
"plain_txt"
:
yield
next
(
gen_plain_txt
(
querylist
))
elif
format
==
"pointwise"
:
yield
next
(
gen_point
(
querylist
))
elif
format
==
"pairwise"
:
for
pair
in
gen_pair
(
querylist
):
for
pair
in
gen_pair
(
querylist
):
yield
pair
yield
pair
elif
format
==
"listwise"
:
elif
format
==
"listwise"
:
...
@@ -264,7 +331,7 @@ def fetch():
...
@@ -264,7 +331,7 @@ def fetch():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
fetch
()
fetch
()
for
i
,
(
score
,
mytest
=
functools
.
partial
(
samples
)
in
enumerate
(
train
(
__reader__
,
filepath
=
"MQ2007/MQ2007/Fold1/sample"
,
format
=
"listwise"
)
format
=
"listwise"
,
shuffle
=
False
)
):
for
label
,
query
in
mytest
(
):
np
.
savetxt
(
"query_%d"
%
(
i
),
score
,
fmt
=
"%.2f"
)
print
label
,
query
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录