infer.py 11.8 KB
Newer Older
1
import argparse
2
import ast
3
import multiprocessing
4
import numpy as np
5
import os
Y
Yibing Liu 已提交
6 7
import sys
sys.path.append("../../models/neural_machine_translation/transformer/")
8
from functools import partial
9

10
import paddle
11 12
import paddle.fluid as fluid

13 14
import reader
from config import *
Y
Yibing Liu 已提交
15
from desc import *
16
from model import fast_decode as fast_decoder
17
from train import pad_batch_data, prepare_data_generator
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52


def parse_args():
    """
    Parse the command-line options for inference and update the configs.

    Besides the declared flags, every trailing token is collected into
    ``args.opts``. Those options, together with settings derived from the
    vocabulary files (vocabulary sizes and the indices of the special
    tokens), are merged into ``InferTaskConfig`` and ``ModelHyperParams``
    through ``merge_cfg_from_list``.

    Returns:
        The ``argparse.Namespace`` holding the parsed arguments.
    """
    # The text must go to `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would replace the program name in the
    # usage line.
    parser = argparse.ArgumentParser(description="Inference for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--test_file_pattern",
        type=str,
        required=True,
        help="The pattern to match test data files.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=50,
        help="The number of examples in one run for sequence generation.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        # Interpret escape sequences such as "\t" typed on the command line.
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--use_parallel_exe",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use ParallelExecutor.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Append args related to dict. The list alternates option names and
    # stringified values; note that the special-token indices are all looked
    # up in the source dictionary.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    bos_token, eos_token, unk_token = args.special_token
    dict_args = [
        "src_vocab_size", str(len(src_dict)),
        "trg_vocab_size", str(len(trg_dict)),
        "bos_idx", str(src_dict[bos_token]),
        "eos_idx", str(src_dict[eos_token]),
        "unk_idx", str(src_dict[unk_token]),
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [InferTaskConfig, ModelHyperParams])
    return args
92 93


94 95 96 97 98 99 100 101 102 103 104 105 106 107
def post_process_seq(seq,
                     bos_idx=None,
                     eos_idx=None,
                     output_bos=None,
                     output_eos=None):
    """
    Post-process the beam-search decoded sequence. Truncate from the first
    <eos> and remove the <bos> and <eos> tokens currently.

    Args:
        seq: The decoded token ids.
        bos_idx: Id of the <bos> token; ``None`` means
            ``ModelHyperParams.bos_idx``.
        eos_idx: Id of the <eos> token; ``None`` means
            ``ModelHyperParams.eos_idx``.
        output_bos: Whether to keep <bos> tokens; ``None`` means
            ``InferTaskConfig.output_bos``.
        output_eos: Whether to keep <eos> tokens; ``None`` means
            ``InferTaskConfig.output_eos``.

    Returns:
        A list with the ids up to and including the first <eos> (or the
        whole sequence when there is none), minus the filtered tokens.
    """
    # Resolve the defaults at call time. Binding them in the signature would
    # freeze the values present at import time and ignore whatever is merged
    # into the configs afterwards (e.g. indices read from the vocabulary).
    if bos_idx is None:
        bos_idx = ModelHyperParams.bos_idx
    if eos_idx is None:
        eos_idx = ModelHyperParams.eos_idx
    if output_bos is None:
        output_bos = InferTaskConfig.output_bos
    if output_eos is None:
        output_eos = InferTaskConfig.output_eos

    # Position of the first <eos>; fall back to the last position.
    eos_pos = next(
        (pos for pos, idx in enumerate(seq) if idx == eos_idx), len(seq) - 1)
    return [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
113 114


Y
Yu Yang 已提交
115 116
def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
                        d_model, place):
    """
    Put all padded data needed by beam search decoder into a dict.

    The first element of every instance is padded with ``pad_batch_data``;
    one start token (``bos_idx``) per instance seeds the decoder. The values
    are paired with ``data_input_names`` in this order: source words, source
    positions, source self-attention bias, start tokens, initial scores,
    initial indices, target-source attention bias.

    ``d_model`` is accepted but not used here.
    """

    def as_lod_tensor(array, lod=None):
        # Wrap a numpy array into a LoDTensor on `place`, optionally with lod.
        tensor = fluid.LoDTensor()
        tensor.set(array, place)
        if lod is not None:
            tensor.set_lod(lod)
        return tensor

    src_seqs = [inst[0] for inst in insts]
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        src_seqs, src_pad_idx, n_head, is_target=False)

    # Stepping axis 2 by src_max_len keeps one row out of every src_max_len.
    strided_bias = src_slf_attn_bias[:, :, ::src_max_len, :]
    trg_src_attn_bias = np.tile(strided_bias, [1, 1, 1, 1]).astype("float32")

    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    # start tokens
    start_tokens = np.asarray([[bos_idx]] * len(insts), dtype="int64")
    start_tokens = start_tokens.reshape(-1, 1, 1)
    num_starts = start_tokens.shape[0]

    # beamsearch_op must use tensors with lod
    zero_scores = np.zeros_like(start_tokens, dtype="float32").reshape(-1, 1)
    init_score = as_lod_tensor(zero_scores, [range(num_starts + 1)] * 2)
    trg_word = as_lod_tensor(start_tokens, [range(num_starts + 1)] * 2)
    init_idx = np.asarray(range(len(insts)), dtype="int32")

    feed_values = [
        src_word, src_pos, src_slf_attn_bias, trg_word, init_score, init_idx,
        trg_src_attn_bias
    ]
    return dict(zip(data_input_names, feed_values))


def prepare_feed_dict_list(data_generator, count, place):
    """
    Prepare the list of feed dict for multi-devices.

    One item is pulled from ``data_generator`` and each of its elements is
    turned into a feed dict by ``prepare_batch_input``. When no generator is
    given (py_reader is in use) no feed dict is built.

    Returns:
        The list of feed dicts when it holds exactly ``count`` entries,
        otherwise ``None``.
    """
    feeds = []
    if data_generator is not None:  # use_py_reader == False
        field_names = encoder_data_input_fields + fast_decoder_data_input_fields
        feeds = [
            prepare_batch_input(device_batch, field_names,
                                ModelHyperParams.eos_idx,
                                ModelHyperParams.bos_idx,
                                ModelHyperParams.n_head,
                                ModelHyperParams.d_model, place)
            for device_batch in next(data_generator)
        ]
    if len(feeds) != count:
        return None
    return feeds


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.

    Returns a generator function. Each batch produced by ``data_reader()``
    is converted with ``prepare_batch_input`` and yielded as a list of values
    ordered like the encoder input fields followed by the fast-decoder input
    fields.
    """

    def py_reader_provider():
        field_names = encoder_data_input_fields + fast_decoder_data_input_fields
        for batch in data_reader():
            feed = prepare_batch_input(
                batch, field_names, ModelHyperParams.eos_idx,
                ModelHyperParams.bos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model, place)
            yield [feed[name] for name in field_names]

    return py_reader_provider
185

186 187

def fast_infer(args):
    """
    Inference by beam search decoder based solely on Fluid operators.

    Builds the decoding program with ``fast_decoder``, loads the parameters
    found under ``InferTaskConfig.model_path``, then runs the program batch
    by batch until the data is exhausted, printing up to
    ``InferTaskConfig.n_best`` hypotheses (at least one) per source sentence.

    Args:
        args: The namespace returned by ``parse_args``. Reader-related
            attributes (``train_file_pattern``, ``use_token_batch``,
            ``sort_type``, ``shuffle``, ``shuffle_batch``) are overwritten
            on it before the data generator is created.
    """
    out_ids, out_scores, pyreader = fast_decoder(
        ModelHyperParams.src_vocab_size,
        ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer,
        ModelHyperParams.n_head,
        ModelHyperParams.d_key,
        ModelHyperParams.d_value,
        ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid,
        ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout,
        ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd,
        ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing,
        InferTaskConfig.beam_size,
        InferTaskConfig.max_out_len,
        ModelHyperParams.bos_idx,
        ModelHyperParams.eos_idx,
        use_py_reader=args.use_py_reader)

    # This is used here to set dropout to the test mode.
    infer_program = fluid.default_main_program().clone(for_test=True)

    if args.use_mem_opt:
        fluid.memory_optimize(infer_program)

    if InferTaskConfig.use_gpu:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Only the parameters of the inference program are loaded.
    fluid.io.load_vars(
        exe,
        InferTaskConfig.model_path,
        vars=[
            var for var in infer_program.list_vars()
            if isinstance(var, fluid.framework.Parameter)
        ])

    exec_strategy = fluid.ExecutionStrategy()
    # For faster executor
    exec_strategy.use_experimental_executor = True
    exec_strategy.num_threads = 1
    build_strategy = fluid.BuildStrategy()
    infer_exe = fluid.ParallelExecutor(
        # Must agree with the config used to choose `place` and `dev_count`
        # above, hence InferTaskConfig rather than TrainTaskConfig.
        use_cuda=InferTaskConfig.use_gpu,
        main_program=infer_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # data reader settings for inference
    args.train_file_pattern = args.test_file_pattern
    args.use_token_batch = False
    args.sort_type = reader.SortType.NONE
    args.shuffle = False
    args.shuffle_batch = False
    test_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper,
        place=place)
    if args.use_py_reader:
        pyreader.start()
        data_generator = None
    else:
        data_generator = test_data()
    trg_idx2word = reader.DataReader.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    fetch_list = [out_ids.name, out_scores.name]
    while True:
        try:
            feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
                                                    place)
            if args.use_parallel_exe:
                seq_ids, seq_scores = infer_exe.run(
                    fetch_list=fetch_list,
                    feed=feed_dict_list,
                    return_numpy=False)
            else:
                seq_ids, seq_scores = exe.run(
                    program=infer_program,
                    fetch_list=fetch_list,
                    feed=feed_dict_list[0]
                    if feed_dict_list is not None else None,
                    return_numpy=False,
                    use_program_cache=False)

            # A single LoDTensor is wrapped into one-element lists; anything
            # else is iterated as is. This is spelled out as an if/else
            # because `a, b = [x], [y] if cond else (x, y)` applies the
            # conditional to the second element only.
            if isinstance(seq_ids, paddle.fluid.LoDTensor):
                seq_ids_list, seq_scores_list = [seq_ids], [seq_scores]
            else:
                seq_ids_list, seq_scores_list = seq_ids, seq_scores

            for ids_tensor, scores_tensor in zip(seq_ids_list,
                                                 seq_scores_list):
                # How to parse the results:
                #   Suppose the lod of seq_ids is:
                #     [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
                #   then from lod[0]:
                #     there are 2 source sentences, beam width is 3.
                #   from lod[1]:
                #     the first source sentence has 3 hyps; the lengths are 12, 12, 16
                #     the second source sentence has 3 hyps; the lengths are 14, 13, 15
                ids_lod = ids_tensor.lod()
                # Convert once per tensor instead of once per hypothesis.
                ids = np.array(ids_tensor)
                score_values = np.array(scores_tensor)
                num_sources = len(ids_lod[0]) - 1
                hyps = [[] for _ in range(num_sources)]
                scores = [
                    [] for _ in range(len(scores_tensor.lod()[0]) - 1)
                ]
                for i in range(num_sources):  # for each source sentence
                    start = ids_lod[0][i]
                    end = ids_lod[0][i + 1]
                    for j in range(end - start):  # for each candidate
                        sub_start = ids_lod[1][start + j]
                        sub_end = ids_lod[1][start + j + 1]
                        hyps[i].append(" ".join(
                            trg_idx2word[idx]
                            for idx in post_process_seq(
                                ids[sub_start:sub_end])))
                        # The score at the last position of the hypothesis.
                        scores[i].append(score_values[sub_end - 1])
                        print(hyps[i][-1])
                        if len(hyps[i]) >= InferTaskConfig.n_best:
                            break
        except (StopIteration, fluid.core.EOFException):
            # The data pass is over.
            if args.use_py_reader:
                pyreader.reset()
            break
321 322


323
if __name__ == "__main__":
    # Parse the command line, then run beam-search inference with it.
    fast_infer(parse_args())