提交 18e6fa95 编写于 作者: D dzhwinter 提交者: GitHub

Merge pull request #173 from dzhwinter/fix_ltr2

"add start script"
......@@ -3,6 +3,7 @@ import gzip
import paddle.v2 as paddle
import numpy as np
import functools
import argparse
def lambda_rank(input_dim):
......@@ -117,6 +118,15 @@ def lambda_rank_infer(pass_id):
if __name__ == '__main__':
    # Command-line entry point for the LambdaRank demo: --run_type selects
    # whether to train a model or to run inference with a saved one.
    parser = argparse.ArgumentParser(description='LambdaRank demo')
    parser.add_argument(
        "--run_type",
        type=str,
        choices=["train", "infer"],
        required=True,
        help="run type is train|infer")
    parser.add_argument(
        "--num_passes",
        type=int,
        # 2 is the pass count this script used before it took arguments.
        default=2,
        help="num of passes in train| infer pass number of model")
    args = parser.parse_args()
    paddle.init(use_gpu=False, trainer_count=1)
    if args.run_type == "train":
        train_lambda_rank(args.num_passes)
    elif args.run_type == "infer":
        # argparse stores "--num_passes" as args.num_passes; there is no
        # args.pass_num attribute. The last pass index is num_passes - 1.
        lambda_rank_infer(pass_id=args.num_passes - 1)
......@@ -5,6 +5,7 @@ import functools
import paddle.v2 as paddle
import numpy as np
from metrics import ndcg
import argparse
# ranknet is the classic pairwise learning to rank algorithm
# http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf
......@@ -104,7 +105,7 @@ def ranknet_infer(pass_id):
# we just need half_ranknet to predict a rank score, which can be used in sort documents
output = half_ranknet("infer", feature_dim)
parameters = paddle.parameters.Parameters.from_tar(
gzip.open("ranknet_params_%d.tar.gz" % (pass_id - 1)))
gzip.open("ranknet_params_%d.tar.gz" % (pass_id)))
# load data of same query and relevance documents, need ranknet to rank these candidates
infer_query_id = []
......@@ -118,18 +119,27 @@ def ranknet_infer(pass_id):
for query_id, relevance_score, feature_vector in plain_txt_test():
infer_query_id.append(query_id)
infer_data.append(feature_vector)
infer_data.append([feature_vector])
# predict score of infer_data document. Re-sort the document base on predict score
# in descending order. then we build the ranking documents
scores = paddle.infer(
output_layer=output, parameters=parameters, input=infer_data)
print scores
for query_id, score in zip(infer_query_id, scores):
print "query_id : ", query_id, " ranknet rank document order : ", score
if __name__ == '__main__':
    # Command-line entry point for the RankNet demo: --run_type selects
    # whether to train a model or to run inference with a saved one.
    parser = argparse.ArgumentParser(description='Ranknet demo')
    parser.add_argument(
        "--run_type",
        type=str,
        choices=["train", "infer"],
        required=True,
        help="run type is train|infer")
    parser.add_argument(
        "--num_passes",
        type=int,
        # 2 is the pass count this script used before it took arguments.
        default=2,
        help="num of passes in train| infer pass number of model")
    args = parser.parse_args()
    paddle.init(use_gpu=False, trainer_count=4)
    if args.run_type == "train":
        train_ranknet(args.num_passes)
    elif args.run_type == "infer":
        # argparse stores "--num_passes" as args.num_passes; there is no
        # args.pass_num attribute. The last pass index is num_passes - 1.
        ranknet_infer(pass_id=args.num_passes - 1)
#!/bin/sh
# Train the LambdaRank demo for 10 passes. stderr is merged into stdout and
# the combined stream is shown on the terminal and saved by tee.
python lambda_rank.py --run_type="train" --num_passes=10 2>&1 | tee lambdarank_train.log

# Run inference with the same pass count, logging the output the same way.
python lambda_rank.py --run_type="infer" --num_passes=10 2>&1 | tee lambdarank_infer.log
#!/bin/sh
# Train the RankNet demo for 10 passes. stderr is merged into stdout and the
# combined stream is shown on the terminal and saved by tee.
# The log name is spelled "ranknet" to match ranknet.py and the infer log.
python ranknet.py \
--run_type="train" \
--num_passes=10 \
2>&1 | tee ranknet_train.log

# Run inference with the same pass count, logging the output the same way.
python ranknet.py \
--run_type="infer" \
--num_passes=10 \
2>&1 | tee ranknet_infer.log
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册