infer.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
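"""Run inference with a trained PaddlePaddle v2 CTR model.

Loads model parameters from a gzipped tar file and writes one prediction
per line to the file given by --prediction_output_path.
"""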
import gzip
import argparse
import itertools

import paddle.v2 as paddle
import network_conf
from train import dnn_layer_dims
import reader
from utils import logger, ModelType

parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser.add_argument(
    '--model_gz_path',
    type=str,
    required=True,
    help="path of model parameters gz file")
parser.add_argument(
    '--data_path', type=str, required=True, help="path of the dataset to infer")
parser.add_argument(
    '--prediction_output_path',
    type=str,
    required=True,
    help="path to output the prediction")
parser.add_argument(
    '--data_meta_path',
    type=str,
    default="./data.meta",
    help="path of trainset's meta info, default is ./data.meta")
parser.add_argument(
    '--model_type',
    type=int,
    default=ModelType.CLASSIFICATION,
    help='model type, classification: %d, regression: %d (default classification)'
    % (ModelType.CLASSIFICATION, ModelType.REGRESSION))

args = parser.parse_args()
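# Example invocation (the paths and the model_type value below are only
# illustrative; run `python infer.py --help` for the exact
# classification/regression ids):
#   python infer.py \
#       --model_gz_path models/model-pass-0.tar.gz \
#       --data_path data/test.txt \
#       --prediction_output_path predictions.txt \
#       --data_meta_path data.meta \
#       --model_type 0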

paddle.init(use_gpu=False, trainer_count=1)


class CTRInferer(object):
    def __init__(self, param_path):
        logger.info("create CTR model")
        dnn_input_dim, lr_input_dim = reader.load_data_meta(args.data_meta_path)
        # create the model
        self.ctr_model = network_conf.CTRmodel(
            dnn_layer_dims,
            dnn_input_dim,
            lr_input_dim,
            model_type=args.model_type,
            is_infer=True)
        # load parameters
        logger.info("load model parameters from %s" % param_path)
        self.parameters = paddle.parameters.Parameters.from_tar(
            gzip.open(param_path, 'r'))
        self.inferer = paddle.inference.Inference(
            output_layer=self.ctr_model.model, parameters=self.parameters)

    def infer(self, data_path):
        logger.info("infer data...")
        dataset = reader.Dataset()
        infer_reader = paddle.batch(dataset.infer(data_path), batch_size=1000)
        logger.warning('write predictions to %s' % args.prediction_output_path)
        with open(args.prediction_output_path, 'w') as output_f:
            for batch in infer_reader():
                res = self.inferer.infer(input=batch)
                predictions = list(itertools.chain.from_iterable(res))
                assert len(batch) == len(
                    predictions), "predict error, %d inputs, but %d predictions" % (
                        len(batch), len(predictions))
                output_f.write('\n'.join(map(str, predictions)) + '\n')


if __name__ == '__main__':
    ctr_inferer = CTRInferer(args.model_gz_path)
    ctr_inferer.infer(args.data_path)
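
# A minimal sketch of reading the predictions back (the file name and the 0.5
# click threshold are assumptions; the script writes one score per line):
#   with open('predictions.txt') as f:
#       clicks = [float(line) > 0.5 for line in f]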