#!/usr/bin/env python
# -*- coding: utf-8 -*-
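"""
Inference script for the PaddlePaddle Youtube recall example: it loads a
trained model and prints the top recommended videos for each input sample.

Example invocation (paths are illustrative, not part of the repo):
    python infer.py --infer_set_path ./data/infer.txt \
        --model_path ./model.tar.gz \
        --feature_dict ./feature_dict.pkl
"""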
import os
import gzip
import paddle.v2 as paddle
import argparse
import cPickle

from reader import Reader
from network_conf import DNNmodel
from utils import logger


def parse_args():
    """
    Parse command-line arguments.
    :return: parsed arguments
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Youtube Recall Model Example")
    parser.add_argument(
        '--infer_set_path',
        type=str,
        required=True,
        help="path of the infer set")
    parser.add_argument(
        '--model_path', type=str, required=True, help="path of the model")
    parser.add_argument(
        '--feature_dict',
        type=str,
        required=True,
        help="path of feature_dict.pkl")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=50,
        help="size of mini-batch (default:50)")
    return parser.parse_args()


def infer():
    """
    Load the trained model and run inference over the infer set.
    """
    args = parse_args()

    # check argument
    assert os.path.exists(
        args.infer_set_path), 'The infer_set_path does not exist.'
    assert os.path.exists(
        args.model_path), 'The model_path does not exist.'
    assert os.path.exists(
        args.feature_dict), 'The feature_dict path does not exist.'

    paddle.init(use_gpu=False, trainer_count=1)

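    # load the feature dictionary (raw feature value -> encoded id) built at preprocessing time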
    with open(args.feature_dict, 'rb') as f:
        feature_dict = cPickle.load(f)

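    # invert the item dictionary so encoded item ids can be mapped back to original video ids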
    nid_dict = feature_dict['history_clicked_items']
    nid_to_word = dict((v, k) for k, v in nid_dict.items())

    # load the trained model.
    with gzip.open(args.model_path) as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # build model
    prediction_layer, fc = DNNmodel(
        dnn_layer_dims=[256, 31], feature_dict=feature_dict,
        is_infer=True).model_cost
    inferer = paddle.inference.Inference(
        output_layer=[prediction_layer, fc], parameters=parameters)

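    # stream the infer set and run inference batch by batch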
    reader = Reader(feature_dict)
    test_batch = []
    for idx, item in enumerate(reader.infer(args.infer_set_path)):
        test_batch.append(item)
        if len(test_batch) == args.batch_size:
            infer_a_batch(inferer, test_batch, nid_to_word)
            test_batch = []
    if len(test_batch):
        infer_a_batch(inferer, test_batch, nid_to_word)


def infer_a_batch(inferer, test_batch, nid_to_word):
    """
    Run inference on one batch of samples and print the top recommendations.
    """
    feeding = {
        'user_id': 0,
        'province': 1,
        'city': 2,
        'history_clicked_items': 3,
        'history_clicked_categories': 4,
        'history_clicked_tags': 5,
        'phone': 6
    }
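    # probs[0] holds the softmax distribution over items (prediction_layer),
    # probs[1] holds the output of the last fc layer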
    probs = inferer.infer(
        input=test_batch,
        feeding=feeding,
        field=["value"],
        flatten_result=False)
    for i, res in enumerate(zip(test_batch, probs[0], probs[1])):
        print "Sample %s:" % str(i)
        softmax_output = res[1]
        sort_nid = res[1].argsort()

        # print the top-30 recommended videos (original id) and their softmax scores
        for j in range(1, 31):
            item_id = sort_nid[-1 * j]
            item_id_to_word = nid_to_word[item_id]
            print "%s\t%.6f" \
                    % (item_id_to_word, softmax_output[item_id])


if __name__ == "__main__":
    infer()