infer.py 3.6 KB
Newer Older
P
peterzhang2029 已提交
1 2 3
import sys
import os
import gzip
4
import click
P
peterzhang2029 已提交
5 6 7 8

import paddle.v2 as paddle

import reader
P
peterzhang2029 已提交
9 10
from network_conf import nested_net
from utils import logger, load_dict, load_reverse_dict
11 12 13 14 15 16


@click.command('infer')
@click.option(
    "--data_path",
    default=None,
    help=("The path of data for inference (default: None). "
          "If this parameter is not set, "
          "imdb test dataset will be used."))
@click.option(
    "--model_path", type=str, required=True, help="The path of saved model.")
@click.option(
    "--word_dict_path",
    type=str,
    default=None,
    help=("The path of word dictionary (default: None). "
          "If this parameter is not set, imdb dataset will be used."))
@click.option(
    "--label_dict_path",
    type=str,
    default=None,
    help=("The path of label dictionary (default: None). "
          "If this parameter is not set, imdb dataset will be used."))
@click.option(
    "--batch_size",
    type=int,
    default=32,
    help="The number of examples in one batch (default: 32).")
def infer(data_path, model_path, word_dict_path, batch_size, label_dict_path):
    """Run a trained model over the inference data and print predictions."""

    # Arguments (all supplied by the click options above):
    #   data_path:       inference data; when None the imdb test set and the
    #                    imdb dictionaries are used instead.
    #   model_path:      gzip-compressed tar of trained parameters (required).
    #   word_dict_path:  word dictionary file; must exist when data_path is
    #                    given.
    #   label_dict_path: label dictionary file; must exist when data_path is
    #                    given.
    #   batch_size:      number of samples handed to the inferer at once.
    #
    # Raises AssertionError when the model file or a required dictionary
    # file is missing.

    def _infer_a_batch(inferer, test_batch, ids_2_word, ids_2_label):
        """Predict one batch and print one tab-separated line per sample.

        Each line holds the predicted label, the space-separated
        probabilities formatted to four decimals, and the sample's text.
        """
        probs = inferer.infer(input=test_batch, field=["value"])
        assert len(probs) == len(test_batch)
        for sample, prob in zip(test_batch, probs):
            # sample[0] is walked as a sequence of sentences, each a sequence
            # of word ids; flatten it so the text prints on a single line.
            flat_ids = [wid for sent in sample[0] for wid in sent]
            word_text = " ".join(ids_2_word[wid] for wid in flat_ids)
            prob_text = " ".join("{:0.4f}".format(p) for p in prob)
            print("%s\t%s\t%s" % (ids_2_label[prob.argmax()], prob_text,
                                  word_text))

    assert os.path.exists(model_path), "The trained model does not exist."
    logger.info("Begin to predict...")

    if data_path is None:
        word_dict = reader.imdb_word_dict()

        # The reversed label dict of the imdb dataset.
        label_reverse_dict = {0: "positive", 1: "negative"}
        test_reader = reader.imdb_test(word_dict)
        class_num = 2
    else:
        # Check for None first: os.path.exists(None) raises TypeError, which
        # would hide the intended message when an option is simply omitted.
        assert word_dict_path is not None and os.path.exists(
            word_dict_path), "The word dictionary file does not exist"
        assert label_dict_path is not None and os.path.exists(
            label_dict_path), "The label dictionary file does not exist"

        word_dict = load_dict(word_dict_path)
        label_reverse_dict = load_reverse_dict(label_dict_path)
        class_num = len(label_reverse_dict)
        test_reader = reader.infer_reader(data_path, word_dict)()

    # Map ids back to words for printing. items() works on Python 2 and 3.
    word_reverse_dict = {idx: word for word, idx in word_dict.items()}
    dict_dim = len(word_dict)

    # initialize PaddlePaddle.
    paddle.init(use_gpu=False, trainer_count=1)

    prob_layer = nested_net(dict_dim, class_num, is_infer=True)

    # load the trained models; the context manager closes the archive
    # once the parameters have been read.
    with gzip.open(model_path, "r") as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)
    inferer = paddle.inference.Inference(
        output_layer=prob_layer, parameters=parameters)

    test_batch = []
    for item in test_reader:
        # Only the first field of each item is fed to the network.
        test_batch.append([item[0]])
        if len(test_batch) == batch_size:
            _infer_a_batch(inferer, test_batch, word_reverse_dict,
                           label_reverse_dict)
            test_batch = []

    # Flush the final, partially filled batch.
    if test_batch:
        _infer_a_batch(inferer, test_batch, word_reverse_dict,
                       label_reverse_dict)


if __name__ == "__main__":
    # infer is a click command: it reads its options from the command line.
    infer()