"""LambdaRank learning-to-rank demo on the MQ2007 dataset (PaddlePaddle v2 API)."""
import os
import sys
import gzip
import functools
import argparse
import numpy as np

import paddle.v2 as paddle


def lambda_rank(input_dim, is_infer):
    """Build the LambdaRank (listwise learning-to-rank) network.

    Every document of a query is scored by the same small feed-forward
    network (input_dim -> 128 -> 10 -> 1). When training, a LambdaRank cost
    and an AUC evaluator are attached on top of the scores. Both the
    document features and the relevance labels are fed as sequences.

    Reference:
    https://papers.nips.cc/paper/2971-learning-to-rank-with-nonsmooth-cost-functions.pdf

    Args:
        input_dim: dimension of one document's dense feature vector.
        is_infer: if True, build only the scoring network (no "label" input,
            no evaluator, no cost layer).

    Format of the "data" dense_vector_sequence:
        [[f, ...], [f, ...], ...], f is a float or an int number

    Returns:
        ``(cost, output)`` when ``is_infer`` is False, otherwise ``output``
        alone, where ``output`` is the size-1 linear score layer.
    """
    # NOTE(review): creation order below is deliberately unchanged. Unnamed
    # fc layers presumably get auto-generated names from creation order in
    # PaddlePaddle v2, and saved parameters are keyed by those names; verify
    # before reordering anything here.
    label = None
    if not is_infer:
        label = paddle.layer.data("label",
                                  paddle.data_type.dense_vector_sequence(1))
    data = paddle.layer.data("data",
                             paddle.data_type.dense_vector_sequence(input_dim))

    # Two Tanh hidden layers, then a linear single-unit scoring layer. All of
    # them use the same small-std weight initializer.
    hidden = data
    for width in (128, 10):
        hidden = paddle.layer.fc(
            input=hidden,
            size=width,
            act=paddle.activation.Tanh(),
            param_attr=paddle.attr.Param(initial_std=0.01))
    output = paddle.layer.fc(
        input=hidden,
        size=1,
        act=paddle.activation.Linear(),
        param_attr=paddle.attr.Param(initial_std=0.01))

    if is_infer:
        return output

    # The evaluator object is not used locally. Presumably constructing it is
    # what registers the AUC metric with the network (verify), so keep the call.
    paddle.evaluator.auc(input=output, label=label)
    # Listwise LambdaRank cost over the predicted scores and the labels.
    cost = paddle.layer.lambda_cost(
        input=output, score=label, NDCG_num=6, max_sort_size=-1)
    return cost, output


def train_lambda_rank(num_passes):
    """Train the LambdaRank model on MQ2007 in listwise format.

    The training cost is printed after every batch. At the end of every
    pass the model is evaluated on the test split, and its parameters are
    written to ``lambda_rank_params_<pass_id>.tar.gz``.

    Args:
        num_passes: number of passes over the training set.
    """
    # MQ2007 provides 46 dense features per document.
    input_dim = 46
    # Position of each data layer's value inside a reader sample.
    feeding = {"label": 0, "data": 1}

    # Listwise readers: one sample per query (labels, document features).
    train_source = functools.partial(
        paddle.dataset.mq2007.train, format="listwise")
    test_source = functools.partial(
        paddle.dataset.mq2007.test, format="listwise")
    train_reader = paddle.batch(
        paddle.reader.shuffle(train_source, buf_size=100), batch_size=32)
    test_reader = paddle.batch(test_source, batch_size=32)

    cost, _ = lambda_rank(input_dim, is_infer=False)
    parameters = paddle.parameters.create(cost)
    trainer = paddle.trainer.SGD(
        cost=cost,
        parameters=parameters,
        update_equation=paddle.optimizer.Adam(learning_rate=1e-4))

    def event_handler(event):
        """Log per-batch cost; test and checkpoint at the end of each pass."""
        if isinstance(event, paddle.event.EndIteration):
            print("Pass %d Batch %d Cost %.9f" %
                  (event.pass_id, event.batch_id, event.cost))
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_reader, feeding=feeding)
            print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
            checkpoint = "lambda_rank_params_%d.tar.gz" % (event.pass_id)
            with gzip.open(checkpoint, "w") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_passes)


def lambda_rank_infer(pass_id, infer_data_num=1):
    """Lambda rank model inference interface.

    Loads the parameters saved after training pass ``pass_id`` and scores
    the documents of the first ``infer_data_num`` queries of the MQ2007 test
    split. Each row of the prediction result is printed with its index.

    Parameters:
        pass_id : index of the training pass whose model is used, i.e. the
            file ``lambda_rank_params_<pass_id>.tar.gz`` written by
            train_lambda_rank.
        infer_data_num : number of test queries to score (positive int,
            default 1).
    """
    print("Begin to Infer...")
    # Must match the feature dimension used in training (MQ2007: 46).
    input_dim = 46
    output = lambda_rank(input_dim, is_infer=True)

    # Use the pass index as-is; training names checkpoints by the same index.
    # The `with` block closes the gzip handle once the parameters are loaded.
    with gzip.open("lambda_rank_params_%d.tar.gz" % pass_id) as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # Collect the document lists of the first `infer_data_num` queries; the
    # labels are not needed for inference.
    fill_default_test = functools.partial(
        paddle.dataset.mq2007.test, format="listwise")
    infer_data = []
    for _, querylist in fill_default_test():
        infer_data.append([querylist])
        if len(infer_data) >= infer_data_num:
            break

    # Predict a score for every document. The scores are printed in input
    # order; sorting them in descending order would give the final ranking.
    predictions = paddle.infer(
        output_layer=output, parameters=parameters, input=infer_data)
    for i, score in enumerate(predictions):
        # Same text as Python 2's `print i, score`, in a form that also works on Python 3.
        print("%d %s" % (i, score))


if __name__ == '__main__':
    # Command-line entry point: train the model, or run inference with the
    # checkpoint of the last training pass.
    parser = argparse.ArgumentParser(description='LambdaRank demo')
    # Required and restricted: an unknown run type used to fall through and
    # silently do nothing.
    parser.add_argument(
        "--run_type",
        type=str,
        required=True,
        choices=["train", "infer"],
        help="run type is train|infer")
    # Required: without it `args.num_passes` was None, which crashed with a
    # TypeError deep inside training/inference.
    parser.add_argument(
        "--num_passes",
        type=int,
        required=True,
        help="The Num of passes in train| infer pass number of model.")
    args = parser.parse_args()
    if args.num_passes < 1:
        parser.error("--num_passes must be a positive integer")

    paddle.init(use_gpu=False, trainer_count=1)
    if args.run_type == "train":
        train_lambda_rank(args.num_passes)
    elif args.run_type == "infer":
        # Training with --num_passes N saves checkpoints for passes 0..N-1;
        # request the last one.
        lambda_rank_infer(pass_id=args.num_passes - 1)