#coding=utf-8

import os
import argparse
import gzip
import logging

import paddle.v2 as paddle
import reader

from model import GNR
from train import choose_samples
from config import ModelConfig
from beam_decoding import BeamDecoding

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def parse_cmd():
    """
    Build the command-line argument parser for the inference task.
    """
    parser = argparse.ArgumentParser(
        description="Globally Normalized Reader in PaddlePaddle.")
    parser.add_argument(
        "--model_path",
        required=True,
        type=str,
        help="Path of the trained model to evaluate.")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Path of the training and testing data.")
    parser.add_argument(
        "--batch_size",
        type=int,
        required=False,
        help="The batch size for inference.",
        default=1)
    parser.add_argument(
        "--use_gpu",
        type=int,
        required=False,
        help="Whether to run inference on GPU (0: CPU, 1: GPU).",
        default=0)
    parser.add_argument(
        "--trainer_count",
        type=int,
        required=False,
        help=("The number of threads used for inference. When "
              "use_gpu=True, trainer_count cannot exceed the "
              "number of GPU devices on your machine."),
        default=1)
    return parser.parse_args()
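
# Example invocation (the paths below are hypothetical):
#   python infer.py --model_path models/params_pass_00000.tar.gz \
#       --data_dir data --batch_size 2 --use_gpu 0 --trainer_count 1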


def load_reverse_dict(dict_file):
    """ Build the dict which is used to map the word index to word string.

    The keys are word index and the values are word strings.

    Arguments:
        - dict_file:    The path of a word dictionary.
    """
    word_dict = {}
    with open(dict_file, "r") as fin:
        for idx, line in enumerate(fin):
            word_dict[idx] = line.strip()
    return word_dict


def print_result(test_batch, predicted_ans, ids_2_word, print_top_k=1):
    """ Print the readable predicted answers.

    Format of the output:
        query:\tthe input query.
        documents:\n
        0\tthe first sentence in the document.
        1\tthe second sentence in the document.
        ...
        gold:\t[i j k] the answer words.
            (i: the sentence index;
             j: the start span index;
             k: the end span index)
        top answers:
        score0\t[i j k] the answer with the highest score.
        score1\t[i j k] the answer with the second highest score.
            (i, j, k has a same meaning as in gold.)
        ...

        By default, top 10 answers will be printed.

    Arguments:
        - test_batch:     A test batch returned by reader.
        - predicted_ans:  The beam decoding results.
        - ids_2_word:     The dict whose key is word index and the values are
                          word strings.
        - print_top_k:    Indicating how many answers will be printed.
    """

    for i, sample in enumerate(test_batch):
        # Each sample is laid out as produced by reader.data_reader:
        # sample[0] holds the query word ids, sample[1] the document
        # sentences (each a list of word ids), sample[3] the gold sentence
        # index, sample[4] the gold start index, and sample[5] the gold
        # end offset relative to the start.
        query_words = [ids_2_word[ids] for ids in sample[0]]
        print("query:\t%s" % (" ".join(query_words)))

        print("documents:")
        for j, sen in enumerate(sample[1]):
            sen_words = [ids_2_word[ids] for ids in sen]
            print("%d\t%s" % (j, " ".join(sen_words)))

        # Compute the gold answer span once, outside the sentence loop.
        start = sample[4]
        end = sample[4] + sample[5] + 1
        print("gold:\t[%d %d %d] %s" % (
            sample[3], sample[4], sample[5], " ".join(
                [ids_2_word[ids] for ids in sample[1][sample[3]][start:end]])))

        print("top answers:")
        for k in range(print_top_k):
            # label = [sentence index, start index, end offset] of the
            # k-th best answer.
            label = predicted_ans[i][k]["label"]
            start = label[1]
            end = label[1] + label[2] + 1
            ans_words = [
                ids_2_word[ids] for ids in sample[1][label[0]][start:end]
            ]
            print("%.4f\t[%d %d %d] %s" %
                  (predicted_ans[i][k]["score"], label[0], label[1], label[2],
                   " ".join(ans_words)))
        print("\n")


def infer_a_batch(inferer, test_batch, ids_2_word, out_layer_count):
    """ Call the PaddlePaddle's infer interface to infer by batch.

    Arguments:
        - inferer:         The PaddlePaddle Inference object.
        - test_batch:      A test batch returned by reader.
        - ids_2_word:      The dict whose key is word index and the values are
                           word strings.
        - out_layer_count: The number of output layers in the inferring process.
    """

    outs = inferer.infer(input=test_batch, flatten_result=False, field="value")
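    # `outs` holds the value arrays of every output layer; BeamDecoding
    # searches them for the highest-scoring answer spans.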
    decoder = BeamDecoding([sample[1] for sample in test_batch], *outs)
    print_result(test_batch, decoder.decoding(), ids_2_word, print_top_k=10)


def infer(model_path,
          data_dir,
          batch_size,
          config,
          use_gpu=False,
          trainer_count=1):
    """ Run the inference process.

    Arguments:
        - model_path:      The path of the trained model.
        - data_dir:        The directory path of the test data.
        - batch_size:      The number of samples in one inference batch.
        - config:          The model configuration.
        - use_gpu:         Whether to run inference on GPU.
        - trainer_count:   The number of threads used for inference. When
                           use_gpu=True, trainer_count cannot exceed the
                           number of GPU devices on your machine.
    """

    assert os.path.exists(model_path), "The model does not exist."
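    # PaddlePaddle must be initialized before the network is built and the
    # trained parameters are loaded.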
    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)

    ids_2_word = load_reverse_dict(config.dict_path)

    outputs = GNR(config, is_infer=True)
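    # `outputs` holds the output layers of the GNR network; they are passed
    # to the Inference object below.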

    # Load the trained model parameters.
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(model_path, "r"))
    logger.info("loading parameter is done.")

    inferer = paddle.inference.Inference(
        output_layer=outputs, parameters=parameters)

    # Only the validation split returned by choose_samples is used for
    # inference; the training split is ignored.
    _, valid_samples = choose_samples(data_dir)
    test_reader = reader.data_reader(valid_samples, is_train=False)

    # Group samples into batches and run inference batch by batch.
    test_batch = []
    for item in test_reader():
        test_batch.append(item)
        if len(test_batch) == batch_size:
            infer_a_batch(inferer, test_batch, ids_2_word, len(outputs))
            test_batch = []

    # Infer any remaining samples that do not fill a full batch.
    if test_batch:
        infer_a_batch(inferer, test_batch, ids_2_word, len(outputs))


def main(args):
    infer(
        model_path=args.model_path,
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        config=ModelConfig,
        use_gpu=bool(args.use_gpu),
        trainer_count=args.trainer_count)


if __name__ == "__main__":
    args = parse_cmd()
    main(args)