# infer.py
import sys
import os
import gzip

import paddle.v2 as paddle

import reader
8
from network_conf import fc_net, convolution_net
9
from utils import logger, load_dict, load_reverse_dict
10 11 12 13


def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
          batch_size):
    """Run a trained text classifier over a test set and print predictions.

    For every test sample one tab-separated line is printed: the predicted
    label, the per-class probabilities (space separated, 4 decimals) and the
    sample text rebuilt from its word ids.

    Args:
        topology: callable invoked as
            ``topology(dict_dim, class_num, is_infer=True)``; its result is
            used as the output layer for inference.
        data_dir: location handed to ``reader.test_reader``. When ``None``,
            the IMDB test set and dictionary from ``paddle.dataset.imdb``
            are used and both dictionary paths are ignored.
        model_path: path of the gzip-compressed parameter tar to load.
        word_dict_path: word dictionary file (only read when ``data_dir``
            is not ``None``).
        label_dict_path: label dictionary file (only read when ``data_dir``
            is not ``None``).
        batch_size: number of samples sent to the inferer per call.

    Raises:
        AssertionError: if ``data_dir`` is given and either dictionary file
            does not exist, or if the inferer returns a different number of
            results than samples in the batch.
    """

    def _infer_a_batch(inferer, test_batch, ids_2_word, ids_2_label):
        """Infer one batch and print one result line per sample."""
        probs = inferer.infer(input=test_batch, field=["value"])
        assert len(probs) == len(test_batch)
        for word_ids, prob in zip(test_batch, probs):
            # Every batch entry is a one-element list wrapping the id
            # sequence (see the batching loop below), hence word_ids[0].
            word_text = " ".join(
                [ids_2_word[word_id] for word_id in word_ids[0]])
            print("%s\t%s\t%s" % (ids_2_label[prob.argmax()],
                                  " ".join(["{:0.4f}".format(p)
                                            for p in prob]), word_text))

    logger.info("begin to predict...")
    use_default_data = (data_dir is None)

    if use_default_data:
        word_dict = paddle.dataset.imdb.word_dict()
        # items() instead of iteritems() so this runs on Python 2 and 3.
        word_reverse_dict = {value: key for key, value in word_dict.items()}
        label_reverse_dict = {0: "positive", 1: "negative"}
        # imdb.test() returns a reader creator; call it to obtain the sample
        # iterator, exactly as the custom-data branch below does.
        test_reader = paddle.dataset.imdb.test(word_dict)()
    else:
        assert os.path.exists(
            word_dict_path), "the word dictionary file does not exist"
        assert os.path.exists(
            label_dict_path), "the label dictionary file does not exist"

        word_dict = load_dict(word_dict_path)
        word_reverse_dict = load_reverse_dict(word_dict_path)
        label_reverse_dict = load_reverse_dict(label_dict_path)

        test_reader = reader.test_reader(data_dir, word_dict)()

    dict_dim = len(word_dict)
    class_num = len(label_reverse_dict)
    prob_layer = topology(dict_dim, class_num, is_infer=True)

    # initialize PaddlePaddle
    paddle.init(use_gpu=False, trainer_count=1)

    # load the trained models; the context manager closes the archive handle
    # once the parameters have been read.
    with gzip.open(model_path, "r") as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)
    inferer = paddle.inference.Inference(
        output_layer=prob_layer, parameters=parameters)

    test_batch = []
    for item in test_reader:
        # Only the first field of a sample is fed to the network.
        test_batch.append([item[0]])
        if len(test_batch) == batch_size:
            _infer_a_batch(inferer, test_batch, word_reverse_dict,
                           label_reverse_dict)
            test_batch = []

    # Flush the final, partially filled batch.
    if test_batch:
        _infer_a_batch(inferer, test_batch, word_reverse_dict,
                       label_reverse_dict)
69 70


71 72
if __name__ == "__main__":
    # Script entry point: run inference with a hard-coded configuration.
    model_path = "models/dnn_params_pass_00000.tar.gz"
    assert os.path.exists(model_path), "the trained model does not exist."

    nn_type = "dnn"
    # Leaving these as None makes infer() fall back to the default IMDB
    # test data and dictionary.
    test_dir = None
    word_dict = None
    label_dict = None

    # Map the network type to its topology builder. An unknown type fails
    # here with a clear message instead of a NameError on `topology` later.
    topologies = {"dnn": fc_net, "cnn": convolution_net}
    if nn_type not in topologies:
        raise ValueError("unsupported nn_type: %r (expected one of %s)" %
                         (nn_type, ", ".join(sorted(topologies))))
    topology = topologies[nn_type]

    infer(
        topology=topology,
        data_dir=test_dir,
        word_dict_path=word_dict,
        label_dict_path=label_dict,
        model_path=model_path,
        batch_size=10)