import argparse
import ast
import numpy as np
from functools import partial

import paddle
import paddle.fluid as fluid

import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
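
# Example invocation (file paths and pattern below are illustrative):
#   python infer.py \
#       --src_vocab_fpath data/vocab.src \
#       --trg_vocab_fpath data/vocab.trg \
#       --test_file_pattern 'data/test.de-en.*' \
#       --batch_size 32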


def parse_args():
    parser = argparse.ArgumentParser("Inference for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--test_file_pattern",
        type=str,
        required=True,
        help="The pattern to match test data files.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=50,
        help="The number of examples in one run for sequence generation.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
56
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # Append settings derived from the vocabulary files.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
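    # merge_cfg_from_list applies both the command-line `opts` and the
    # vocabulary-derived settings to the global config objects.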
    merge_cfg_from_list(args.opts + dict_args,
                        [InferTaskConfig, ModelHyperParams])
    return args


def post_process_seq(seq,
                     bos_idx=ModelHyperParams.bos_idx,
                     eos_idx=ModelHyperParams.eos_idx,
                     output_bos=InferTaskConfig.output_bos,
                     output_eos=InferTaskConfig.output_eos):
    """
    Post-process the beam-search decoded sequence: truncate everything after
    the first <eos> and, depending on the output_bos/output_eos flags, strip
    the <bos> and <eos> tokens.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq


def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
                        d_model, place):
    """
    Put all the padded data needed by the beam-search decoder into a dict.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    # Initialize the decoder input with the <bos> token for every instance.
    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
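    # Each source self-attention bias is a [src_max_len, src_max_len] matrix;
    # a slice step of src_max_len keeps only its first row, giving the
    # decoder-to-encoder attention bias shape [batch, n_head, 1, src_max_len].
    # The tile with [1, 1, 1, 1] merely copies the array.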
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, 1, 1]).astype("float32")
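    # Fluid feeds word ids and positions as int64 tensors with a trailing
    # dimension of 1.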
    trg_word = trg_word.reshape(-1, 1, 1)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    def to_lodtensor(data, place, lod=None):
        data_tensor = fluid.LoDTensor()
        data_tensor.set(data, place)
        if lod is not None:
            data_tensor.set_lod(lod)
        return data_tensor

    # The beam_search op requires LoDTensor inputs.
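    # A two-level LoD of [[0, 1, ..., batch], [0, 1, ..., batch]] marks one
    # candidate sequence of length one per source sentence at the start of
    # decoding.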
    init_score = to_lodtensor(
        np.zeros_like(
            trg_word, dtype="float32").reshape(-1, 1),
        place, [list(range(trg_word.shape[0] + 1))] * 2)
    trg_word = to_lodtensor(trg_word, place,
                            [list(range(trg_word.shape[0] + 1))] * 2)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
            trg_src_attn_bias
        ]))

    return data_input_dict


def fast_infer(test_data, trg_idx2word):
    """
    Inference with a beam-search decoder built entirely from Fluid operators.
    """
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

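    # Build the decoding graph once; out_ids and out_scores are fetched below
    # as LoD tensors holding every finished hypothesis and its score.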
    out_ids, out_scores = fast_decoder(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
        InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)

    fluid.io.load_vars(
        exe,
        InferTaskConfig.model_path,
        vars=[
            var for var in fluid.default_main_program().list_vars()
            if isinstance(var, fluid.framework.Parameter)
        ])
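    # Only Parameter variables are restored above; optimizer state from
    # training is not needed for inference.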

    # Clone the main program in test mode so that dropout is disabled at
    # inference time.
    infer_program = fluid.default_main_program().clone(for_test=True)

    for batch_id, data in enumerate(test_data.batch_generator()):
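        # Note: the <eos> id also serves as the padding id for source words.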
        data_input = prepare_batch_input(
            data, encoder_data_input_fields + fast_decoder_data_input_fields,
            ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
            ModelHyperParams.n_head, ModelHyperParams.d_model, place)
        seq_ids, seq_scores = exe.run(infer_program,
                                      feed=data_input,
                                      fetch_list=[out_ids, out_scores],
                                      return_numpy=False)
        # How to parse the results:
        #   Suppose the lod of seq_ids is:
        #     [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
        #   then from lod[0]:
        #     there are 2 source sentences, beam width is 3.
        #   from lod[1]:
        #     the first source sentence has 3 hyps; the lengths are 12, 12, 16
        #     the second source sentence has 3 hyps; the lengths are 14, 13, 15
        hyps = [[] for _ in range(len(data))]
        scores = [[] for _ in range(len(data))]
        for i in range(len(seq_ids.lod()[0]) - 1):  # for each source sentence
            start = seq_ids.lod()[0][i]
            end = seq_ids.lod()[0][i + 1]
            for j in range(end - start):  # for each candidate
                sub_start = seq_ids.lod()[1][start + j]
                sub_end = seq_ids.lod()[1][start + j + 1]
                hyps[i].append(" ".join([
                    trg_idx2word[idx]
                    for idx in post_process_seq(
                        np.array(seq_ids)[sub_start:sub_end])
                ]))
                scores[i].append(np.array(seq_scores)[sub_end - 1])
                print(hyps[i][-1])
                if len(hyps[i]) >= InferTaskConfig.n_best:
                    break


def infer(args, inferencer=fast_infer):
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
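    # Sorting and shuffling are disabled below so that translations come out
    # in the same order as the input file.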
    test_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.test_file_pattern,
        token_delimiter=args.token_delimiter,
        use_token_batch=False,
        batch_size=args.batch_size,
        pool_size=args.pool_size,
        sort_type=reader.SortType.NONE,
        shuffle=False,
        shuffle_batch=False,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # Reserve two positions for the start and end tokens.
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False)
    trg_idx2word = test_data.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)
    inferencer(test_data, trg_idx2word)


if __name__ == "__main__":
    args = parse_args()
    infer(args)