ranknet.py 6.6 KB
Newer Older
D
dzhwinter 已提交
1 2
import os
import sys
D
dong zhihong 已提交
3
import gzip
4
import functools
D
dongzhihong 已提交
5
import argparse
C
caoying03 已提交
6 7 8 9 10 11 12
import logging
import numpy as np

import paddle.v2 as paddle

# Module-wide logger: the logger registered under the name "paddle",
# configured to emit INFO-level messages and above.
logger = logging.getLogger(name="paddle")
logger.setLevel(level=logging.INFO)
D
dong zhihong 已提交
13 14 15 16

# RankNet is the classic pairwise learning-to-rank algorithm. For background, see:
# http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf

17

C
caoying03 已提交
18 19 20 21
def score_diff(right_score, left_score):
    """Return the mean absolute difference between two sets of scores.

    Args:
        right_score: array-like of scores.
        left_score: array-like of scores, broadcast-compatible with
            ``right_score``.

    Returns:
        The average of ``|right_score - left_score|`` as a NumPy float.
    """
    # np.asarray lets callers pass plain lists/tuples as well as ndarrays
    # (ndarrays are used as-is, without copying); plain lists would
    # otherwise fail on the subtraction.
    right = np.asarray(right_score)
    left = np.asarray(left_score)
    return np.average(np.abs(right - left))


D
dong zhihong 已提交
22
def half_ranknet(name_prefix, input_dim):
    """Build one tower of the siamese RankNet and return its score layer.

    The tower is ``data -> fc(10, tanh) -> fc(1, linear)``. Layer names carry
    ``name_prefix`` (``<prefix>_data``, ``<prefix>_hidden``,
    ``<prefix>_score``) so that two towers can coexist in one topology.
    Every parameter, weights and biases alike, gets an explicit,
    prefix-independent name. Parameters with the same name are shared in
    the paddle framework, so towers built with "left" and "right" compute
    the same scoring function. See
    https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/api.md

    NOTE: parameter archives written before the biases were named explicitly
    store them under different names and may not load into this topology.

    Args:
        name_prefix: prefix for the layer names, e.g. "left" or "right".
        input_dim: dimension of the dense input feature vector.

    Returns:
        The ``<name_prefix>_score`` fc layer (size 1, linear activation).
    """
    # data layer: one dense feature vector of size input_dim
    data = paddle.layer.data(name_prefix + "_data",
                             paddle.data_type.dense_vector(input_dim))

    # hidden layer. The bias is named explicitly as well: an unnamed bias is
    # presumably named after its (prefixed) layer, which would give the two
    # towers different biases. Mean 0 / std 0 keeps it zero-initialised.
    hd1 = paddle.layer.fc(
        input=data,
        name=name_prefix + "_hidden",
        size=10,
        act=paddle.activation.Tanh(),
        param_attr=paddle.attr.Param(initial_std=0.01, name="hidden_w1"),
        bias_attr=paddle.attr.Param(
            initial_mean=0.0, initial_std=0.0, name="hidden_b1"))

    # output layer: a single linear unit producing the ranking score;
    # its bias is shared across towers for the same reason as above.
    output = paddle.layer.fc(
        input=hd1,
        name=name_prefix + "_score",
        size=1,
        act=paddle.activation.Linear(),
        param_attr=paddle.attr.Param(initial_std=0.01, name="output"),
        bias_attr=paddle.attr.Param(
            initial_mean=0.0, initial_std=0.0, name="output_b"))
    return output
D
dong zhihong 已提交
49 50 51


def ranknet(input_dim):
    """Assemble the pairwise RankNet training topology.

    Two towers built by ``half_ranknet`` (prefixes "left" and "right") score
    the two documents of a pair, and a ``rank_cost`` layer compares their
    scores against the pairwise ``label`` input.

    Args:
        input_dim: dimension of each document's dense feature vector.

    Returns:
        The ``rank_cost`` layer named "cost".
    """
    # pairwise target, fed as a 1-dimensional dense vector
    pair_label = paddle.layer.data("label", paddle.data_type.dense_vector(1))

    # half_ranknet gives its weight parameters prefix-independent names,
    # so both towers reuse the same weights
    left_tower = half_ranknet("left", input_dim)
    right_tower = half_ranknet("right", input_dim)

    return paddle.layer.rank_cost(
        name="cost", left=left_tower, right=right_tower, label=pair_label)
D
dong zhihong 已提交
63 64


C
caoying03 已提交
65
def ranknet_train(num_passes, model_save_dir):
    """Train RankNet on the MQ2007 dataset and checkpoint after every pass.

    Args:
        num_passes: number of passes over the training data.
        model_save_dir: existing directory that receives one
            ``ranknet_params_<pass_id>.tar.gz`` file per pass.
    """
    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.dataset.mq2007.train, buf_size=100),
        batch_size=100)
    test_reader = paddle.batch(paddle.dataset.mq2007.test, batch_size=100)

    # mq2007 feature_dim = 46, dense format (the hidden layer size is set
    # inside half_ranknet, not here)
    feature_dim = 46
    cost = ranknet(feature_dim)
    parameters = paddle.parameters.create(cost)

    trainer = paddle.trainer.SGD(
        cost=cost,
        parameters=parameters,
        update_equation=paddle.optimizer.Adam(learning_rate=2e-4))

    # Map each data layer to its position in the samples yielded by the readers
    feeding = {"label": 0, "left_data": 1, "right_data": 2}

    def event_handler(event):
        """Log every 25 batches; after each pass, test and save parameters."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 25 == 0:
                left = event.gm.getLayerOutputs("left_score")["left_score"]
                right = event.gm.getLayerOutputs("right_score")["right_score"]
                # keyword arguments so each tower's scores bind to the
                # matching parameter (the metric itself is symmetric)
                diff = score_diff(
                    right_score=right["value"], left_score=left["value"])
                logger.info(
                    "Pass %d Batch %d : Cost %.6f, "
                    "average absolute diff scores: %.6f", event.pass_id,
                    event.batch_id, event.cost, diff)
        elif isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_reader, feeding=feeding)
            logger.info("\nTest with Pass %d, %s", event.pass_id,
                        result.metrics)
            save_path = os.path.join(
                model_save_dir, "ranknet_params_%d.tar.gz" % event.pass_id)
            with gzip.open(save_path, "w") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_passes)

D
dong zhihong 已提交
113

C
caoying03 已提交
114
def ranknet_infer(model_path):
    """Load a trained model and print a score for every MQ2007 test document.

    Only one tower is needed to predict a document's rank score. Scores of
    documents that belong to the same query can then be sorted to rank them.

    Args:
        model_path: path to a gzipped parameter tar (as written by
            ``ranknet_train``).
    """
    logger.info("Begin to Infer...")
    feature_dim = 46

    # we just need half_ranknet to predict a rank score for one document
    output = half_ranknet("right", feature_dim)
    # Parameters.from_tar reads the archive during the call, so the file
    # can be closed as soon as it returns.
    with gzip.open(model_path) as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)

    # Load every test document in the mq2007 plain-text format:
    # <query_id> <relevance_score> <feature_vector>
    # Each inference sample is a one-element list feeding the tower's
    # data layer.
    plain_txt_test = functools.partial(
        paddle.dataset.mq2007.test, format="plain_txt")
    infer_query_id = []
    infer_data = []
    for query_id, _relevance_score, feature_vector in plain_txt_test():
        infer_query_id.append(query_id)
        infer_data.append([feature_vector])

    # Predict the score of each document. Results are printed in dataset
    # order; to rank the documents of one query, sort them by score in
    # descending order.
    scores = paddle.infer(
        output_layer=output, parameters=parameters, input=infer_data)
    for query_id, score in zip(infer_query_id, scores):
        # Same text as the Python 2 statement
        # `print "query_id : ", query_id, " score : ", score`
        # (items separated by single spaces), but portable to Python 3.
        sys.stdout.write("query_id :  %s  score :  %s\n" % (query_id, score))
148

D
dong zhihong 已提交
149

C
caoying03 已提交
150 151 152 153 154 155 156 157 158
def str_to_bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False" and "0") as True, so boolean flags are parsed explicitly.

    Args:
        value: the raw command-line string; a bool is returned unchanged.

    Returns:
        True for "true"/"t"/"yes"/"y"/"1" and False for
        "false"/"f"/"no"/"n"/"0" (case-insensitive, surrounding whitespace
        ignored).

    Raises:
        argparse.ArgumentTypeError: for any other value; argparse reports it
            as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got %r." % value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="PaddlePaddle RankNet example.")
    parser.add_argument(
        "--run_type",
        type=str,
        choices=["train", "infer"],
        help=("A flag indicating to run the training or the inferring task. "
              "Available options are: train or infer."),
        default="train")
    parser.add_argument(
        "--num_passes",
        type=int,
        help="The number of passes to train the model.",
        default=10)
    parser.add_argument(
        "--use_gpu",
        type=str_to_bool,
        help="A flag indicating whether to use the GPU device in training.",
        default=False)
    parser.add_argument(
        "--trainer_count",
        type=int,
        help="The thread number used in training.",
        default=1)
    parser.add_argument(
        "--model_save_dir",
        type=str,
        required=False,
        help=("The path to save the trained models."),
        default="models")
    parser.add_argument(
        "--test_model_path",
        type=str,
        required=False,
        help=("This parameter works only in inferring task to "
              "specify path of a trained model."),
        default="")

    args = parser.parse_args()

    # Validate before doing any work; unlike `assert`, this check is not
    # stripped when Python runs with -O.
    if args.run_type == "infer" and not os.path.exists(args.test_model_path):
        parser.error("The trained model does not exist: %r" %
                     args.test_model_path)

    # makedirs also creates any missing parent directories
    if not os.path.exists(args.model_save_dir):
        os.makedirs(args.model_save_dir)

    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)

    if args.run_type == "train":
        ranknet_train(args.num_passes, args.model_save_dir)
    else:  # "infer": argparse `choices` rules out any other value
        ranknet_infer(args.test_model_path)