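"""Infer with a trained DSSM model and write the predictions to a file.

A typical invocation looks like the sketch below (all paths and the integer
codes for --model_type / --model_arch are placeholders; run
`python infer.py --help` for the exact values):

    python infer.py \
        --model_path <trained_parameters.tar.gz> \
        -i <data_to_infer> \
        -o <prediction_output_file> \
        -y <model_type> \
        -a <model_arch> \
        -s <source_word_dict>
"""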
import argparse
import distutils.util

import reader
import paddle.v2 as paddle
from network_conf import DSSM
from utils import logger, ModelType, ModelArch, load_dic

parser = argparse.ArgumentParser(description="PaddlePaddle DSSM infer")
parser.add_argument(
    '--model_path',
    type=str,
    required=True,
    help="path of model parameters file")
parser.add_argument(
    '-i',
    '--data_path',
    type=str,
    required=True,
    help="path of the dataset to infer")
parser.add_argument(
    '-o',
    '--prediction_output_path',
    type=str,
    required=True,
    help="path to output the prediction")
parser.add_argument(
    '-y',
    '--model_type',
    type=int,
    default=ModelType.CLASSIFICATION_MODE,
    help=("model type, %d for classification, %d for pairwise rank, "
          "%d for regression (default: classification)") %
    (ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
     ModelType.REGRESSION_MODE))
parser.add_argument(
    '-s',
    '--source_dic_path',
    type=str,
    required=True,
    help="path of the source's word dictionary")
parser.add_argument(
    '--target_dic_path',
    type=str,
    required=False,
    help=("path of the target's word dictionary, "
          "if not set, the `source_dic_path` will be used"))
parser.add_argument(
    '-a',
    '--model_arch',
    type=int,
    default=ModelArch.CNN_MODE,
    help=("model architecture, %d for CNN, %d for FC, %d for RNN "
          "(default: CNN)") %
    (ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE))
parser.add_argument(
    '--share_network_between_source_target',
    type=distutils.util.strtobool,
    default=False,
    help="whether to share network parameters between source and target")
parser.add_argument(
    '--share_embed',
    type=distutils.util.strtobool,
    default=False,
    help="whether to share word embedding between source and target")
parser.add_argument(
    '--dnn_dims',
    type=str,
    default='256,128,64,32',
    help=("dimensions of dnn layers, default is '256,128,64,32', "
          "which means a 4-layer dnn whose layer dimensions are "
          "256, 128, 64 and 32"))
parser.add_argument(
    '-c',
    '--class_num',
    type=int,
    default=0,
    help="number of categories for the classification task.")

args = parser.parse_args()
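# wrap the raw integer flags into the ModelType / ModelArch helper classes
# defined in utils, so that checks like `is_classification()` are available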
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)
if args.model_type.is_classification():
    assert args.class_num > 1, ("--class_num (greater than 1) must be set "
                                "for the classification task.")

# parse the hidden layer dimensions, e.g. '256,128,64,32' -> [256, 128, 64, 32]
layer_dims = [int(i) for i in args.dnn_dims.split(',')]
args.target_dic_path = args.source_dic_path if not args.target_dic_path \
        else args.target_dic_path

# initialize PaddlePaddle: run on CPU with a single trainer thread
paddle.init(use_gpu=False, trainer_count=1)


class Inferer(object):
    def __init__(self, param_path):
        logger.info("create DSSM model")

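        # build the DSSM network in inference mode; calling the constructed
        # DSSM object returns the prediction output layer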
        prediction = DSSM(
            dnn_dims=layer_dims,
            vocab_sizes=[
                len(load_dic(path))
                for path in [args.source_dic_path, args.target_dic_path]
            ],
            model_type=args.model_type,
            model_arch=args.model_arch,
            share_semantic_generator=args.share_network_between_source_target,
            class_num=args.class_num,
            share_embed=args.share_embed,
            is_infer=True)()

        # load the trained parameters from the tar archive
        logger.info("load model parameters from %s" % param_path)
        with open(param_path, 'rb') as f:
            self.parameters = paddle.parameters.Parameters.from_tar(f)
        self.inferer = paddle.inference.Inference(
            output_layer=prediction, parameters=self.parameters)

    def infer(self, data_path):
        logger.info("infer data...")
        dataset = reader.Dataset(
            train_path=data_path,
            test_path=None,
            source_dic_path=args.source_dic_path,
            target_dic_path=args.target_dic_path,
            model_type=args.model_type, )
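        # feed the inference data in batches of 1000 samples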
        infer_reader = paddle.batch(dataset.infer, batch_size=1000)
        logger.warning('write predictions to %s' % args.prediction_output_path)

        with open(args.prediction_output_path, 'w') as output_f:
            for batch in infer_reader():
                res = self.inferer.infer(input=batch)
                predictions = [' '.join(map(str, x)) for x in res]
                assert len(batch) == len(predictions), (
                    "predict error, %d inputs, "
                    "but %d predictions") % (len(batch), len(predictions))
                output_f.write('\n'.join(predictions) + '\n')


if __name__ == '__main__':
    inferer = Inferer(args.model_path)
    inferer.infer(args.data_path)