infer.py 11.9 KB
Newer Older
1
import argparse
2
import ast
3
import multiprocessing
4
import numpy as np
5
import os
Y
Yibing Liu 已提交
6
import sys
7
sys.path.append("../../")
Y
Yibing Liu 已提交
8
sys.path.append("../../models/neural_machine_translation/transformer/")
9
from functools import partial
10

11
import paddle
12 13
import paddle.fluid as fluid

14
from models.model_check import check_cuda
15 16
import reader
from config import *
Y
Yibing Liu 已提交
17
from desc import *
18
from model import fast_decode as fast_decoder
19
from train import pad_batch_data, prepare_data_generator
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54


def parse_args():
    """
    Parse the command-line options for inference and merge the
    vocabulary-derived settings into the inference/model configs.

    Besides the named options, all remaining command-line tokens are
    collected into ``args.opts``. They are handed to ``merge_cfg_from_list``
    together with the vocabulary sizes and the indices of the special tokens
    (``bos_idx``, ``eos_idx``, ``unk_idx``), which are looked up in the
    source dictionary.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    # This script runs inference, so the parser title must not say "Training".
    parser = argparse.ArgumentParser("Inference for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--test_file_pattern",
        type=str,
        required=True,
        help="The pattern to match test data files.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=50,
        help="The number of examples in one run for sequence generation.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        # Interpret escape sequences such as "\t" or "\x01" typed on the
        # command line as the characters they denote.
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--use_parallel_exe",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use ParallelExecutor.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    bos_token, eos_token, unk_token = args.special_token
    dict_args = [
        "src_vocab_size", str(len(src_dict)),
        "trg_vocab_size", str(len(trg_dict)),
        "bos_idx", str(src_dict[bos_token]),
        "eos_idx", str(src_dict[eos_token]),
        "unk_idx", str(src_dict[unk_token]),
    ]
    # Tolerate a missing/None ``opts`` so the merge never fails on list
    # concatenation when no extra options were given.
    merge_cfg_from_list(
        list(args.opts or []) + dict_args,
        [InferTaskConfig, ModelHyperParams])
    return args
94 95


96 97 98 99 100 101 102 103 104 105 106 107 108 109
def post_process_seq(seq,
                     bos_idx=None,
                     eos_idx=None,
                     output_bos=None,
                     output_eos=None):
    """
    Post-process the beam-search decoded sequence. Truncate from the first
    <eos> and remove the <bos> and <eos> tokens currently.

    Args:
        seq: Sequence of token ids produced by the decoder.
        bos_idx: Id of the <bos> token. Defaults to the current value of
            ``ModelHyperParams.bos_idx``.
        eos_idx: Id of the <eos> token. Defaults to the current value of
            ``ModelHyperParams.eos_idx``.
        output_bos: Whether to keep <bos> tokens in the result. Defaults to
            the current value of ``InferTaskConfig.output_bos``.
        output_eos: Whether to keep <eos> tokens in the result. Defaults to
            the current value of ``InferTaskConfig.output_eos``.

    Returns:
        A list with the ids up to and including the first <eos> (or the whole
        sequence when there is none), minus the filtered special tokens.
    """
    # Resolve the defaults at call time. parse_args() passes bos_idx/eos_idx
    # (and any user overrides) to merge_cfg_from_list for these config
    # classes after this module has been loaded; defaults bound in the
    # signature would be evaluated once at definition time and could
    # therefore hold pre-override values.
    if bos_idx is None:
        bos_idx = ModelHyperParams.bos_idx
    if eos_idx is None:
        eos_idx = ModelHyperParams.eos_idx
    if output_bos is None:
        output_bos = InferTaskConfig.output_bos
    if output_eos is None:
        output_eos = InferTaskConfig.output_eos

    # Position of the first <eos>; fall back to the last position.
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break

    return [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
115 116


Y
Yu Yang 已提交
117 118
def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
                        d_model, place):
    """
    Put all padded data needed by beam search decoder into a dict.

    The source side (first element of every instance) is padded through
    ``pad_batch_data``; the target side starts with a single ``bos_idx``
    token per instance. The values are paired with ``data_input_names`` in
    this order: source words, source positions, source self-attention bias,
    target start words, initial scores, initial indices and target-to-source
    attention bias. ``d_model`` is accepted but not used here.
    """

    def _as_lod_tensor(array, lod):
        # Wrap a numpy array into a LoDTensor on `place`, attaching the lod
        # when one is given.
        tensor = fluid.LoDTensor()
        tensor.set(array, place)
        if lod is not None:
            tensor.set_lod(lod)
        return tensor

    batch_size = len(insts)
    source_seqs = [inst[0] for inst in insts]
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        source_seqs, src_pad_idx, n_head, is_target=False)

    # Take every src_max_len-th row of the self-attention bias along axis 2.
    strided_bias = src_slf_attn_bias[:, :, ::src_max_len, :]
    trg_src_attn_bias = np.tile(strided_bias, [1, 1, 1, 1]).astype("float32")

    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    # start tokens
    start_tokens = np.asarray([[bos_idx]] * batch_size, dtype="int64")
    start_tokens = start_tokens.reshape(-1, 1, 1)

    # beamsearch_op must use tensors with lod
    two_level_lod = [range(start_tokens.shape[0] + 1)] * 2
    zero_scores = np.zeros_like(start_tokens, dtype="float32").reshape(-1, 1)
    init_score = _as_lod_tensor(zero_scores, two_level_lod)
    trg_word = _as_lod_tensor(start_tokens, two_level_lod)
    init_idx = np.asarray(range(batch_size), dtype="int32")

    feed_values = [
        src_word, src_pos, src_slf_attn_bias, trg_word, init_score, init_idx,
        trg_src_attn_bias
    ]
    return dict(zip(data_input_names, feed_values))


def prepare_feed_dict_list(data_generator, count, place):
    """
    Prepare the list of feed dict for multi-devices.

    Pulls one item from ``data_generator`` (when it is not None) and turns
    each of its per-device batches into a feed dict. The list is returned
    only when it holds exactly ``count`` feed dicts; otherwise None is
    returned. A ``StopIteration`` raised by the generator is not caught here.
    """
    feed_dicts = []
    if data_generator is not None:  # use_py_reader == False
        field_names = encoder_data_input_fields + fast_decoder_data_input_fields
        feed_dicts = [
            prepare_batch_input(device_batch, field_names,
                                ModelHyperParams.eos_idx,
                                ModelHyperParams.bos_idx,
                                ModelHyperParams.n_head,
                                ModelHyperParams.d_model, place)
            for device_batch in next(data_generator)
        ]
    if len(feed_dicts) == count:
        return feed_dicts
    return None


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.

    Returns a generator function that iterates over ``data_reader()`` and
    yields, for every batch, the prepared inputs as a list ordered like the
    encoder and fast-decoder input field names.
    """

    def py_reader_provider():
        field_names = encoder_data_input_fields + fast_decoder_data_input_fields
        for batch in data_reader():
            feed = prepare_batch_input(
                batch, field_names, ModelHyperParams.eos_idx,
                ModelHyperParams.bos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model, place)
            yield [feed[name] for name in field_names]

    return py_reader_provider
187

188 189

def fast_infer(args):
    """
    Inference by beam search decoder based solely on Fluid operators.

    Builds the decoding program, loads the parameters found under
    ``InferTaskConfig.model_path``, then repeatedly runs the program on the
    test data and prints the decoded hypotheses (at most
    ``InferTaskConfig.n_best`` per source sentence, but at least one) until
    the data is exhausted.

    Args:
        args: The namespace returned by ``parse_args``. The data reader
            related attributes (``train_file_pattern``, ``use_token_batch``,
            ``sort_type``, ``shuffle``, ``shuffle_batch``) are set on it here.
    """
    out_ids, out_scores, pyreader = fast_decoder(
        ModelHyperParams.src_vocab_size,
        ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer,
        ModelHyperParams.n_head,
        ModelHyperParams.d_key,
        ModelHyperParams.d_value,
        ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid,
        ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout,
        ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd,
        ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing,
        InferTaskConfig.beam_size,
        InferTaskConfig.max_out_len,
        ModelHyperParams.bos_idx,
        ModelHyperParams.eos_idx,
        use_py_reader=args.use_py_reader)

    # This is used here to set dropout to the test mode.
    infer_program = fluid.default_main_program().clone(for_test=True)

    if args.use_mem_opt:
        fluid.memory_optimize(infer_program)

    if InferTaskConfig.use_gpu:
        check_cuda(InferTaskConfig.use_gpu)
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    fluid.io.load_vars(
        exe,
        InferTaskConfig.model_path,
        vars=[
            var for var in infer_program.list_vars()
            if isinstance(var, fluid.framework.Parameter)
        ])

    exec_strategy = fluid.ExecutionStrategy()
    # For faster executor
    exec_strategy.use_experimental_executor = True
    exec_strategy.num_threads = 1
    build_strategy = fluid.BuildStrategy()
    infer_exe = fluid.ParallelExecutor(
        # Must agree with the device chosen for `place` above, which is
        # driven by the inference config rather than the training config.
        use_cuda=InferTaskConfig.use_gpu,
        main_program=infer_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # data reader settings for inference
    args.train_file_pattern = args.test_file_pattern
    args.use_token_batch = False
    args.sort_type = reader.SortType.NONE
    args.shuffle = False
    args.shuffle_batch = False
    test_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper,
        place=place)
    if args.use_py_reader:
        pyreader.start()
        data_generator = None
    else:
        data_generator = test_data()
    trg_idx2word = reader.DataReader.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    fetch_names = [out_ids.name, out_scores.name]
    while True:
        # Only fetching the feed data and running the program can signal the
        # end of the data, so only these calls are guarded.
        try:
            feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
                                                    place)
            if args.use_parallel_exe:
                seq_ids, seq_scores = infer_exe.run(
                    fetch_list=fetch_names,
                    feed=feed_dict_list,
                    return_numpy=False)
            else:
                seq_ids, seq_scores = exe.run(
                    program=infer_program,
                    fetch_list=fetch_names,
                    feed=feed_dict_list[0]
                    if feed_dict_list is not None else None,
                    return_numpy=False,
                    use_program_cache=False)
        except (StopIteration, fluid.core.EOFException):
            # The data pass is over.
            if args.use_py_reader:
                pyreader.reset()
            break

        # Normalize both results to sequences of tensors. The branch has to
        # cover BOTH names: written as a single conditional expression after
        # the comma, the condition would apply to the scores only.
        if isinstance(seq_ids, paddle.fluid.LoDTensor):
            seq_ids_list, seq_scores_list = [seq_ids], [seq_scores]
        else:
            seq_ids_list, seq_scores_list = seq_ids, seq_scores

        for ids_tensor, scores_tensor in zip(seq_ids_list, seq_scores_list):
            # How to parse the results:
            #   Suppose the lod of seq_ids is:
            #     [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
            #   then from lod[0]:
            #     there are 2 source sentences, beam width is 3.
            #   from lod[1]:
            #     the first source sentence has 3 hyps; the lengths are 12, 12, 16
            #     the second source sentence has 3 hyps; the lengths are 14, 13, 15
            # Convert the tensors and read the lod once per result instead of
            # once per candidate.
            ids_lod = ids_tensor.lod()
            ids_array = np.array(ids_tensor)
            scores_array = np.array(scores_tensor)
            num_sources = len(ids_lod[0]) - 1
            hyps = [[] for _ in range(num_sources)]
            # The scores are collected next to the hypotheses but are not
            # printed.
            scores = [[] for _ in range(len(scores_tensor.lod()[0]) - 1)]
            for i in range(num_sources):  # for each source sentence
                start = ids_lod[0][i]
                end = ids_lod[0][i + 1]
                for j in range(end - start):  # for each candidate
                    sub_start = ids_lod[1][start + j]
                    sub_end = ids_lod[1][start + j + 1]
                    hyps[i].append(" ".join([
                        trg_idx2word[idx]
                        for idx in post_process_seq(
                            ids_array[sub_start:sub_end])
                    ]))
                    scores[i].append(scores_array[sub_end - 1])
                    print(hyps[i][-1])
                    if len(hyps[i]) >= InferTaskConfig.n_best:
                        break
324 325


326
if __name__ == "__main__":
    # Parse the command-line options, then run beam search inference.
    fast_infer(parse_args())