infer.py 2.3 KB
Newer Older
P
Peng Li 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
import os
import sys
import argparse

import paddle.v2 as paddle

import reader
import utils
import network
import config

from utils import logger


class Infer(object):
    def __init__(self, conf):
        self.conf = conf

        self.settings = reader.Settings(
            vocab=conf.vocab, is_training=False, label_schema=conf.label_schema)

        # init paddle
        # TODO(lipeng17) v2 API does not support parallel_nn yet. Therefore, we
        # can only use CPU currently
        paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)

        # define network
        self.tags_layer = network.inference_net(conf)

    def infer(self, model_path, data_path, output):
        test_reader = paddle.batch(
            paddle.reader.buffered(
                reader.create_reader(data_path, self.settings),
                size=self.conf.batch_size * 1000),
            batch_size=self.conf.batch_size)

        # load the trained models
        parameters = paddle.parameters.Parameters.from_tar(
            utils.open_file(model_path, "r"))
        inferer = paddle.inference.Inference(
            output_layer=self.tags_layer, parameters=parameters)

        def count_evi_ids(test_batch):
            num = 0
            for sample in test_batch:
                num += len(sample[reader.E_IDS])
            return num

        for test_batch in test_reader():
            tags = inferer.infer(
                input=test_batch, field=["id"], feeding=network.feeding)
            evi_ids_num = count_evi_ids(test_batch)
            assert len(tags) == evi_ids_num
            print >> output, ";\n".join(str(tag) for tag in tags) + ";"


def parse_cmd():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path")
    parser.add_argument("data_path")
    parser.add_argument("output", help="'-' for stdout")
    return parser.parse_args()


def main(args):
    conf = config.InferConfig()
    conf.vocab = utils.load_dict(conf.word_dict_path)
    logger.info("length of word dictionary is : %d." % len(conf.vocab))

    if args.output == "-":
        output = sys.stdout
    else:
        output = utils.open_file(args.output, "w")

    infer = Infer(conf)
    infer.infer(args.model_path, args.data_path, output)

    output.close()


if __name__ == "__main__":
    main(parse_cmd())