# predict.py
# Last committed by: liu zhengxi
import logging
import os
import six
import sys
import time

import numpy as np
import paddle

import yaml
from attrdict import AttrDict
from pprint import pprint

# Include task-specific libs
import reader
from transformer import InferTransformerModel, position_encoding_init


def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Trim a decoded id sequence at its first end-of-sequence id and strip
    the special ids.

    Everything after the first occurrence of ``eos_idx`` is discarded. From
    what remains, ids equal to ``bos_idx`` are dropped unless ``output_bos``
    is true, and the terminating ``eos_idx`` is dropped unless ``output_eos``
    is true. If ``eos_idx`` never occurs, the whole sequence is filtered.

    Returns a new list; ``seq`` itself is not modified.
    """
    kept = []
    for token in seq:
        hit_eos = token == eos_idx
        drop_as_bos = not output_bos and token == bos_idx
        drop_as_eos = not output_eos and hit_eos
        if not (drop_as_bos or drop_as_eos):
            kept.append(token)
        if hit_eos:
            # Nothing past the first end-of-sequence id is ever emitted.
            break
    return kept


@paddle.fluid.dygraph.no_grad()
def do_predict(args):
    """
    Decode the test set with a trained Transformer and write the results to
    ``args.output_file``, one space-joined sentence per line.

    For every input instance at most ``args.n_best`` decoded candidates are
    written, in the order the model returns them.

    Args:
        args: Attribute-style configuration. Fields read directly here:
            ``use_gpu``, ``src_vocab_size``, ``trg_vocab_size``,
            ``max_length``, ``n_layer``, ``n_head``, ``d_model``,
            ``d_inner_hid``, ``dropout``, ``weight_sharing``, ``bos_idx``,
            ``eos_idx``, ``beam_size``, ``max_out_len``,
            ``init_from_params``, ``output_file`` and ``n_best``. The object
            is also forwarded to ``reader.create_infer_loader``.

    Raises:
        AssertionError: If ``args.init_from_params`` is falsy.
    """
    place = "gpu:0" if args.use_gpu else "cpu"
    paddle.set_device(place)

    # Define data loader. The value returned alongside the loader is not
    # used during prediction.
    (test_loader, _), trg_idx2word = reader.create_infer_loader(args)

    # Define model
    transformer = InferTransformerModel(
        src_vocab_size=args.src_vocab_size,
        trg_vocab_size=args.trg_vocab_size,
        max_length=args.max_length + 1,
        n_layer=args.n_layer,
        n_head=args.n_head,
        d_model=args.d_model,
        d_inner_hid=args.d_inner_hid,
        dropout=args.dropout,
        weight_sharing=args.weight_sharing,
        bos_id=args.bos_idx,
        eos_id=args.eos_idx,
        beam_size=args.beam_size,
        max_out_len=args.max_out_len)

    # Load the trained model
    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "transformer.pdparams"))

    # Overwrite the saved position-encoding tables with freshly generated
    # ones of size max_length + 1, so they match the model built above even
    # if a different length was used at training time.
    model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
        args.max_length + 1, args.d_model)
    model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
        args.max_length + 1, args.d_model)
    transformer.load_dict(model_dict)

    # Set evaluate mode
    transformer.eval()

    # The context manager guarantees the output file is flushed and closed,
    # including when decoding raises part-way through the test set.
    with open(args.output_file, "w") as f:
        for input_data in test_loader:
            (src_word, src_pos, src_slf_attn_bias, trg_word,
             trg_src_attn_bias) = input_data
            finished_seq = transformer(
                src_word=src_word,
                src_pos=src_pos,
                src_slf_attn_bias=src_slf_attn_bias,
                trg_word=trg_word,
                trg_src_attn_bias=trg_src_attn_bias)
            # Swap the last two axes so that iterating one instance yields
            # one candidate sequence at a time (presumably
            # [batch, seq_len, beam] -> [batch, beam, seq_len]; confirm
            # against InferTransformerModel).
            finished_seq = finished_seq.numpy().transpose([0, 2, 1])
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    # Only the first n_best candidates are written.
                    if beam_idx >= args.n_best:
                        break
                    id_list = post_process_seq(beam, args.bos_idx,
                                               args.eos_idx)
                    word_list = [
                        trg_idx2word[token_id] for token_id in id_list
                    ]
                    sequence = " ".join(word_list) + "\n"
                    f.write(sequence)


if __name__ == "__main__":
    # Read the run settings from the YAML file in the current working
    # directory and expose them through attribute access.
    yaml_file = "./transformer.yaml"
    with open(yaml_file, 'rt') as f:
        config = yaml.safe_load(f)
    args = AttrDict(config)

    # Echo the effective configuration before starting prediction.
    pprint(args)
    do_predict(args)