import logging
import os
import six
import sys
import time

import numpy as np
import paddle
import yaml
from attrdict import AttrDict
from pprint import pprint

# Include task-specific libs
import reader
from transformer import InferTransformerModel, position_encoding_init


def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process a decoded token-id sequence.

    Truncates the sequence at the first EOS token (inclusive), then drops
    BOS/EOS tokens unless the corresponding flag asks to keep them.

    Args:
        seq: Sequence of token ids produced by the decoder.
        bos_idx: Id of the begin-of-sequence token.
        eos_idx: Id of the end-of-sequence token.
        output_bos: If True, keep BOS tokens in the result.
        output_eos: If True, keep EOS tokens in the result.

    Returns:
        list: Truncated and filtered list of token ids.
    """
    # Default to the last position so the whole sequence is kept when no
    # EOS token occurs.
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq


# NOTE(review): paddle.fluid.dygraph.no_grad is the legacy API; recent Paddle
# releases expose paddle.no_grad — confirm against the pinned Paddle version
# before migrating.
@paddle.fluid.dygraph.no_grad()
def do_predict(args):
    """
    Run beam-search inference with a trained Transformer and write the
    n-best detokenized translations to ``args.output_file``.

    Args:
        args: AttrDict of configuration values loaded from transformer.yaml
            (vocab sizes, model dims, beam settings, paths, etc.).

    Raises:
        AssertionError: If ``args.init_from_params`` is not set.
    """
    if args.use_gpu:
        place = "gpu:0"
    else:
        place = "cpu"
    paddle.set_device(place)

    # Define data loader
    (test_loader, test_steps_fn), trg_idx2word = reader.create_infer_loader(args)

    # Define model
    transformer = InferTransformerModel(
        src_vocab_size=args.src_vocab_size,
        trg_vocab_size=args.trg_vocab_size,
        max_length=args.max_length + 1,
        n_layer=args.n_layer,
        n_head=args.n_head,
        d_model=args.d_model,
        d_inner_hid=args.d_inner_hid,
        dropout=args.dropout,
        weight_sharing=args.weight_sharing,
        bos_id=args.bos_idx,
        eos_id=args.eos_idx,
        beam_size=args.beam_size,
        max_out_len=args.max_out_len)

    # Load the trained model
    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")

    model_dict = paddle.load(
        os.path.join(args.init_from_params, "transformer.pdparams"))

    # To avoid a longer length than training, reset the size of position
    # encoding to max_length
    model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
        args.max_length + 1, args.d_model)
    model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
        args.max_length + 1, args.d_model)
    transformer.load_dict(model_dict)

    # Set evaluate mode
    transformer.eval()

    # Use a context manager so the output file is flushed and closed even if
    # inference raises part-way through (the original leaked the handle).
    with open(args.output_file, "w") as f:
        for input_data in test_loader:
            (src_word, src_pos, src_slf_attn_bias, trg_word,
             trg_src_attn_bias) = input_data
            finished_seq = transformer(
                src_word=src_word,
                src_pos=src_pos,
                src_slf_attn_bias=src_slf_attn_bias,
                trg_word=trg_word,
                trg_src_attn_bias=trg_src_attn_bias)
            # Reorder to (batch, beam, seq_len) so each instance's beams are
            # contiguous.
            finished_seq = finished_seq.numpy().transpose([0, 2, 1])
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    # Emit only the n-best hypotheses per instance.
                    if beam_idx >= args.n_best:
                        break
                    id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
                    # Renamed loop variable: the original shadowed builtin `id`.
                    word_list = [trg_idx2word[word_id] for word_id in id_list]
                    sequence = " ".join(word_list) + "\n"
                    f.write(sequence)


if __name__ == "__main__":
    yaml_file = "./transformer.yaml"
    with open(yaml_file, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
        pprint(args)

    do_predict(args)