#coding=utf-8

import sys
import argparse
import distutils.util
import gzip

import paddle.v2 as paddle
from model import conv_seq2seq
from beamsearch import BeamSearch
import reader


def _str_to_bool(value):
    """
    Convert a textual truth value to 1 or 0.

    Mirrors ``distutils.util.strtobool`` (same accepted spellings, same
    integer results, same ``ValueError`` on anything else) so that the flags
    below do not depend on the deprecated ``distutils`` package.

    :param value: The text given on the command line.
    :type value: str
    :return: 1 for a true spelling, 0 for a false spelling.
    :rtype: int
    :raises ValueError: If the text is not a recognised truth value.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError("invalid truth value %r" % (lowered, ))


def parse_args():
    """
    Build the command line parser of the inference script and parse sys.argv.

    The data path, both dictionary paths, the model path and the two
    convolution block specifications are mandatory; every other option has a
    default value.

    :return: The parsed command line options.
    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Convolutional Seq2Seq")

    # Paths of the inputs.
    parser.add_argument(
        '--infer_data_path',
        type=str,
        required=True,
        help="Path of the dataset for inference")
    parser.add_argument(
        '--src_dict_path',
        type=str,
        required=True,
        help='Path of the source dictionary')
    parser.add_argument(
        '--trg_dict_path',
        type=str,
        required=True,
        help='path of the target dictionary')

    # Network structure. The block specifications are evaluated
    # unconditionally by the caller, so they are required here to fail early
    # with a readable usage message instead of a TypeError later on.
    parser.add_argument(
        '--enc_blocks',
        type=str,
        required=True,
        help='Convolution blocks of the encoder')
    parser.add_argument(
        '--dec_blocks',
        type=str,
        required=True,
        help='Convolution blocks of the decoder')
    parser.add_argument(
        '--emb_size',
        type=int,
        default=256,
        help='Dimension of word embedding. (default: %(default)s)')
    parser.add_argument(
        '--pos_size',
        type=int,
        default=200,
        help='Total number of the position indexes. (default: %(default)s)')
    parser.add_argument(
        '--drop_rate',
        type=float,
        default=0.,
        help='Dropout rate. (default: %(default)s)')
    parser.add_argument(
        "--use_bn",
        default=False,
        type=_str_to_bool,
        help="Use batch normalization or not. (default: %(default)s)")

    # Runtime configuration.
    parser.add_argument(
        "--use_gpu",
        default=False,
        type=_str_to_bool,
        help="Use gpu or not. (default: %(default)s)")
    parser.add_argument(
        "--trainer_count",
        default=1,
        type=int,
        help="Trainer number. (default: %(default)s)")

    # Generation settings.
    parser.add_argument(
        '--max_len',
        type=int,
        default=100,
        help="The maximum length of the sentence to be generated. (default: %(default)s)"
    )
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="Size of a mini-batch. (default: %(default)s)")
    parser.add_argument(
        "--beam_size",
        default=1,
        type=int,
        help="The width of beam expansion. (default: %(default)s)")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="The path of trained model. (default: %(default)s)")
    parser.add_argument(
        "--is_show_attention",
        default=False,
        type=_str_to_bool,
        help="Whether to show attention weight or not. (default: %(default)s)")
    return parser.parse_args()


def infer(infer_data_path,
          src_dict_path,
          trg_dict_path,
          model_path,
          enc_conv_blocks,
          dec_conv_blocks,
          emb_dim=256,
          pos_size=200,
          drop_rate=0.,
          use_bn=False,
          max_len=100,
          batch_size=1,
          beam_size=1,
          is_show_attention=False):
    """
    Inference.

    Builds the network in inference mode, loads the trained parameters and
    then either prints the attention weights of the first sample or runs
    beam search over the whole inference set.

    :param infer_data_path: The path of the data for inference.
    :type infer_data_path: str
    :param src_dict_path: The path of the source dictionary.
    :type src_dict_path: str
    :param trg_dict_path: The path of the target dictionary.
    :type trg_dict_path: str
    :param model_path: The path of a trained model (a gzip-compressed file).
    :type model_path: str
    :param enc_conv_blocks: The scale list of the encoder's convolution blocks. And each element of
                            the list contains output dimension and context length of the corresponding
                            convolution block.
    :type enc_conv_blocks: list of tuple
    :param dec_conv_blocks: The scale list of the decoder's convolution blocks. And each element of
                            the list contains output dimension and context length of the corresponding
                            convolution block.
    :type dec_conv_blocks: list of tuple
    :param emb_dim: The dimension of the embedding vector.
    :type emb_dim: int
    :param pos_size: The total number of the position indexes, which means
                     the maximum value of the index is pos_size - 1.
    :type pos_size: int
    :param drop_rate: Dropout rate.
    :type drop_rate: float
    :param use_bn: Whether to use batch normalization or not. False is the default value.
    :type use_bn: bool
    :param max_len: The maximum length of the sentence to be generated.
    :type max_len: int
    :param batch_size: The size of a mini-batch, forwarded to the beam searcher.
    :type batch_size: int
    :param beam_size: The width of beam expansion.
    :type beam_size: int
    :param is_show_attention: Whether to show attention weight or not. False is the default value.
    :type is_show_attention: bool
    """
    # load dict
    src_dict = reader.load_dict(src_dict_path)
    trg_dict = reader.load_dict(trg_dict_path)

    prob, weight = conv_seq2seq(
        src_dict_size=len(src_dict),
        trg_dict_size=len(trg_dict),
        pos_size=pos_size,
        emb_dim=emb_dim,
        enc_conv_blocks=enc_conv_blocks,
        dec_conv_blocks=dec_conv_blocks,
        drop_rate=drop_rate,
        with_bn=use_bn,
        is_infer=True)

    # load parameters; the context manager closes the gzip handle once the
    # archive has been read.
    with gzip.open(model_path) as model_file:
        parameters = paddle.parameters.Parameters.from_tar(model_file)

    # Every decoder block contributes (context length - 1) padding positions.
    padding_num = sum(context_len - 1 for _, context_len in dec_conv_blocks)
    infer_reader = reader.data_reader(
        data_file=infer_data_path,
        src_dict=src_dict,
        trg_dict=trg_dict,
        pos_size=pos_size,
        padding_num=padding_num)

    if is_show_attention:
        attention_inferer = paddle.inference.Inference(
            output_layer=weight, parameters=parameters)
        # Only the first sample of the reader is inspected (see the break).
        for data in infer_reader():
            src_len = len(data[0])
            trg_len = len(data[2])
            attention_weight = attention_inferer.infer(
                [data], field='value', flatten_result=False)
            # A distinct loop name keeps the `weight` layer above from being
            # overwritten by the comprehension variable under Python 2.
            attention_weight = [
                layer_weight.reshape((trg_len, src_len))
                for layer_weight in attention_weight
            ]
            print(attention_weight)
            break
        return

    # Beam search only needs the first two fields of every sample.
    infer_data = [[raw_data[0], raw_data[1]] for raw_data in infer_reader()]

    inferer = paddle.inference.Inference(
        output_layer=prob, parameters=parameters)

    searcher = BeamSearch(
        inferer=inferer,
        trg_dict=trg_dict,
        pos_size=pos_size,
        padding_num=padding_num,
        max_len=max_len,
        batch_size=batch_size,
        beam_size=beam_size)

    searcher.search(infer_data)


def main():
    """
    Entry point of the script: parse the options, initialize PaddlePaddle
    and run the inference with the requested configuration.
    """
    args = parse_args()

    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. It is kept because block specifications may be written as
    # expressions (for example a multiplied list), which a literal-only
    # parser would reject. Never feed these options from untrusted input.
    enc_conv_blocks = eval(args.enc_blocks)
    dec_conv_blocks = eval(args.dec_blocks)

    # Raise the interpreter's recursion limit before the network is built.
    sys.setrecursionlimit(10000)

    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)

    infer(
        infer_data_path=args.infer_data_path,
        src_dict_path=args.src_dict_path,
        trg_dict_path=args.trg_dict_path,
        model_path=args.model_path,
        enc_conv_blocks=enc_conv_blocks,
        dec_conv_blocks=dec_conv_blocks,
        emb_dim=args.emb_size,
        pos_size=args.pos_size,
        drop_rate=args.drop_rate,
        use_bn=args.use_bn,
        max_len=args.max_len,
        batch_size=args.batch_size,
        beam_size=args.beam_size,
        is_show_attention=args.is_show_attention)


# Run the inference only when this file is executed as a script.
if __name__ == '__main__':
    main()