infer.py 4.6 KB
Newer Older
S
Superjom 已提交
1 2
import argparse
import itertools
R
ranqiu 已提交
3
import distutils.util
S
Superjom 已提交
4 5 6 7 8 9 10 11

import reader
import paddle.v2 as paddle
from network_conf import DSSM
from utils import logger, ModelType, ModelArch, load_dic

def _build_arg_parser():
    """Create the command-line parser for the DSSM inference script.

    Options are registered in this order: model path, data path, prediction
    output path, model type, source/target dictionary paths, model
    architecture, the two parameter-sharing switches, the DNN layer sizes
    and the number of classes.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="PaddlePaddle DSSM infer")

    # Help texts that need the numeric codes of the model type / architecture.
    model_type_help = (
        "The model type: %d for classification, %d for pairwise rank, "
        "%d for regression (default: classification).") % (
            ModelType.CLASSIFICATION_MODE, ModelType.RANK_MODE,
            ModelType.REGRESSION_MODE)
    model_arch_help = "model architecture, %d for CNN, %d for FC, %d for RNN" % (
        ModelArch.CNN_MODE, ModelArch.FC_MODE, ModelArch.RNN_MODE)
    target_dic_help = (
        "The path of the target's word dictionary, "
        "if this parameter is not set, the `source_dic_path` will be used.")
    dnn_dims_help = (
        "The dimentions of dnn layers, default is `256,128,64,32`, "
        "which means a dnn with 4 layers with "
        "dmentions 256, 128, 64 and 32 will be created.")

    # Locations of the trained model, the input data and the output file.
    cli.add_argument(
        "--model_path",
        required=True,
        type=str,
        help="The path of trained model.")
    cli.add_argument(
        "-i",
        "--data_path",
        required=True,
        type=str,
        help="The path of the data for inferring.")
    cli.add_argument(
        "-o",
        "--prediction_output_path",
        required=True,
        type=str,
        help="The path to save the predictions.")

    # NOTE(review): this option is required, so argparse never falls back
    # to the `default` given here.
    cli.add_argument(
        "-y",
        "--model_type",
        required=True,
        type=int,
        default=ModelType.CLASSIFICATION_MODE,
        help=model_type_help)

    # Word dictionaries for the source and the target side.
    cli.add_argument(
        "-s",
        "--source_dic_path",
        required=False,
        type=str,
        help="The path of the source's word dictionary.")
    cli.add_argument(
        "--target_dic_path",
        required=False,
        type=str,
        help=target_dic_help)

    # NOTE(review): required as well, so this `default` is never used either.
    cli.add_argument(
        "-a",
        "--model_arch",
        required=True,
        type=int,
        default=ModelArch.CNN_MODE,
        help=model_arch_help)

    # Boolean switches are parsed from strings with distutils' strtobool.
    cli.add_argument(
        "--share_network_between_source_target",
        type=distutils.util.strtobool,
        default=False,
        help="whether to share network parameters between source and target")
    cli.add_argument(
        "--share_embed",
        type=distutils.util.strtobool,
        default=False,
        help="whether to share word embedding between source and target")

    # Network shape.
    cli.add_argument(
        "--dnn_dims",
        type=str,
        default="256,128,64,32",
        help=dnn_dims_help)
    cli.add_argument(
        "-c",
        "--class_num",
        type=int,
        default=0,
        help="The number of categories for classification task.")
    return cli


parser = _build_arg_parser()
S
Superjom 已提交
78 79 80 81 82

# Parse the command line and normalise the values used by the code below.
args = parser.parse_args()

# Convert the integer codes into ModelType / ModelArch values (from utils).
args.model_type = ModelType(args.model_type)
args.model_arch = ModelArch(args.model_arch)

# A classification model needs more than one class. Report the problem
# through the parser instead of `assert`, so the check is not stripped under
# `python -O` and the user sees a usage message rather than a traceback.
if args.model_type.is_classification() and args.class_num <= 1:
    parser.error("The parameter class_num should be set "
                 "in classification task.")

# Materialise the layer sizes as a list so they can be indexed and iterated
# more than once regardless of whether `map` is lazy on this interpreter.
layer_dims = [int(dim) for dim in args.dnn_dims.split(",")]

# Fall back to the source dictionary when no target dictionary is given.
args.target_dic_path = args.target_dic_path or args.source_dic_path

# Initialise PaddlePaddle with the GPU disabled and a single trainer.
paddle.init(use_gpu=False, trainer_count=1)


class Inferer(object):
    """Run a trained DSSM model over a data file and save its predictions.

    The network is configured from the module-level ``args`` and
    ``layer_dims`` values.
    """

    def __init__(self, param_path):
        """Build the inference network and load the trained parameters.

        Args:
            param_path: Path of the parameter archive handed to
                ``paddle.parameters.Parameters.from_tar``.
        """
        dic_paths = [args.source_dic_path, args.target_dic_path]
        prediction = DSSM(
            dnn_dims=layer_dims,
            vocab_sizes=[len(load_dic(path)) for path in dic_paths],
            model_type=args.model_type,
            model_arch=args.model_arch,
            share_semantic_generator=args.share_network_between_source_target,
            class_num=args.class_num,
            share_embed=args.share_embed,
            is_infer=True)()

        # load parameter
        logger.info("Load the trained model from %s." % param_path)
        # The archive is binary data, so open it in binary mode, and release
        # the file handle as soon as the parameters have been read.
        with open(param_path, "rb") as param_file:
            self.parameters = paddle.parameters.Parameters.from_tar(param_file)
        self.inferer = paddle.inference.Inference(
            output_layer=prediction, parameters=self.parameters)

    def infer(self, data_path):
        """Predict every sample read from ``data_path`` and write the results.

        Predictions are written to ``args.prediction_output_path``, one line
        per input sample, with the values of a prediction separated by
        single spaces.

        Args:
            data_path: Path of the data file to run inference on.

        Raises:
            AssertionError: If a batch yields a different number of
                predictions than it has inputs.
        """
        dataset = reader.Dataset(
            train_path=data_path,
            test_path=None,
            source_dic_path=args.source_dic_path,
            target_dic_path=args.target_dic_path,
            model_type=args.model_type)
        infer_reader = paddle.batch(dataset.infer, batch_size=1000)
        logger.warning("Write predictions to %s." % args.prediction_output_path)

        # The context manager guarantees the output is flushed and closed
        # even if inference fails part-way through.
        with open(args.prediction_output_path, "w") as output_f:
            for batch in infer_reader():
                res = self.inferer.infer(input=batch)
                predictions = [" ".join(map(str, x)) for x in res]
                assert len(batch) == len(predictions), (
                    "Error! %d inputs are given, "
                    "but only %d predictions are returned.") % (
                        len(batch), len(predictions))
                output_f.write("\n".join(predictions) + "\n")
S
Superjom 已提交
135 136


C
caoying03 已提交
137
if __name__ == "__main__":
    # Script entry point: load the model given on the command line, then run
    # inference over the requested data file.
    Inferer(args.model_path).infer(args.data_path)