import argparse
import ast
import numpy as np
from functools import partial

import paddle
import paddle.fluid as fluid

import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
import util


def parse_args():
    parser = argparse.ArgumentParser("Inference for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--test_file_pattern",
        type=str,
        required=True,
        help="The pattern to match test data files.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=50,
        help="The number of examples in one run for sequence generation.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--use_wordpiece",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether the data is in wordpiece. The EN-FR "
        "data we provide is wordpiece data. For wordpiece data, converting "
        "ids to original words is a little different, and some special codes "
        "are provided in util.py to do this.")
    parser.add_argument(
        "--token_delimiter",
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter.; "
        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [InferTaskConfig, ModelHyperParams])
    return args


def post_process_seq(seq,
                     bos_idx=ModelHyperParams.bos_idx,
                     eos_idx=ModelHyperParams.eos_idx,
                     output_bos=InferTaskConfig.output_bos,
                     output_eos=InferTaskConfig.output_eos):
    """
    Post-process the beam-search decoded sequence. Truncate from the first
    <eos> and remove the <bos> and <eos> tokens currently.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq


def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
                        d_model, place):
    """
    Put all the padded data needed by the beam-search decoder into a dict.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    # Initialize each target sequence with the start token <s>.
    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
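    # Slicing with a stride of src_max_len keeps only the first row of the
    # source self-attention bias, giving the target-to-source attention bias
    # of shape [batch_size, n_head, 1, src_max_len], shared by all decoding
    # steps.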
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, 1, 1]).astype("float32")
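    # Append a trailing dimension of size 1, matching the
    # [batch_size, seq_len, 1] layout the model's data layers expect.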
    trg_word = trg_word.reshape(-1, 1, 1)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    def to_lodtensor(data, place, lod=None):
        data_tensor = fluid.LoDTensor()
        data_tensor.set(data, place)
        if lod is not None:
            data_tensor.set_lod(lod)
        return data_tensor

    # The beam-search decoder requires LoDTensor inputs; the two-level LoD
    # below gives each source sentence a single initial candidate.
    init_score = to_lodtensor(
        np.zeros_like(
            trg_word, dtype="float32").reshape(-1, 1),
        place, [list(range(trg_word.shape[0] + 1))] * 2)
    trg_word = to_lodtensor(trg_word, place,
                            [list(range(trg_word.shape[0] + 1))] * 2)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
            trg_src_attn_bias
        ]))

    return data_input_dict


def fast_infer(test_data, trg_idx2word, use_wordpiece):
    """
    Inference with a beam-search decoder built solely from Fluid operators.
    """
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

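    # Build the inference graph: fast_decode composes the encoder with a
    # beam-search decoder and returns the predicted token ids and scores.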
    out_ids, out_scores = fast_decoder(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
        InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)

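    # Restore only the Parameter variables of the default main program from
    # the saved model directory.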
    fluid.io.load_vars(
        exe,
        InferTaskConfig.model_path,
        vars=[
            var for var in fluid.default_main_program().list_vars()
            if isinstance(var, fluid.framework.Parameter)
        ])

    # Clone the main program with for_test=True to put dropout into test mode.
    infer_program = fluid.default_main_program().clone(for_test=True)

    for batch_id, data in enumerate(test_data.batch_generator()):
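        # Note: eos_idx is reused as the source padding index here.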
        data_input = prepare_batch_input(
            data, encoder_data_input_fields + fast_decoder_data_input_fields,
            ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
            ModelHyperParams.n_head, ModelHyperParams.d_model, place)
        seq_ids, seq_scores = exe.run(infer_program,
                                      feed=data_input,
                                      fetch_list=[out_ids, out_scores],
                                      return_numpy=False)
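        # return_numpy=False keeps seq_ids and seq_scores as LoDTensors so
        # that their LoD can be parsed below.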
        # How to parse the results:
        #   Suppose the lod of seq_ids is:
        #     [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
        #   then from lod[0]:
        #     there are 2 source sentences, beam width is 3.
        #   from lod[1]:
        #     the first source sentence has 3 hyps; the lengths are 12, 12, 16
        #     the second source sentence has 3 hyps; the lengths are 14, 13, 15
        hyps = [[] for _ in range(len(data))]
        scores = [[] for _ in range(len(data))]
        for i in range(len(seq_ids.lod()[0]) - 1):  # for each source sentence
            start = seq_ids.lod()[0][i]
            end = seq_ids.lod()[0][i + 1]
            for j in range(end - start):  # for each candidate
                sub_start = seq_ids.lod()[1][start + j]
                sub_end = seq_ids.lod()[1][start + j + 1]
                hyps[i].append(" ".join([
                    trg_idx2word[idx]
                    for idx in post_process_seq(
                        np.array(seq_ids)[sub_start:sub_end])
                ]) if not use_wordpiece else util.subtoken_ids_to_str(
                    post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
                    trg_idx2word))
                scores[i].append(np.array(seq_scores)[sub_end - 1])
                print(hyps[i][-1])
                if len(hyps[i]) >= InferTaskConfig.n_best:
                    break


def infer(args, inferencer=fast_infer):
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    test_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.test_file_pattern,
        token_delimiter=args.token_delimiter,
        use_token_batch=False,
        batch_size=args.batch_size,
        pool_size=args.pool_size,
        sort_type=reader.SortType.NONE,
        shuffle=False,
        shuffle_batch=False,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # Reserve two positions for the start and end tokens.
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False)
    trg_idx2word = test_data.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)
    inferencer(test_data, trg_idx2word, args.use_wordpiece)


if __name__ == "__main__":
    args = parse_args()
    infer(args)