Commit b2eb5149 authored by guosheng

Add beam search for Transformer.

Parent 2a0991e1
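For context: beam search keeps the beam_size highest-scoring partial translations at each decoding step instead of committing to a single greedy token. The actual decoder lives in the new transformer.py (its diff is collapsed below); what follows is only a minimal NumPy sketch of one expansion step, with all names invented for illustration.

import numpy as np

def beam_step(log_probs, seq_scores, beam_size):
    # log_probs:  [beam_size, vocab_size] next-token log-probs per beam
    # seq_scores: [beam_size] cumulative log-prob of each partial sequence
    total = seq_scores[:, None] + log_probs        # score of every extension
    flat = total.reshape(-1)
    top = np.argpartition(flat, -beam_size)[-beam_size:]
    top = top[np.argsort(flat[top])[::-1]]         # best first
    beam_ids = top // log_probs.shape[1]           # which beam to extend
    token_ids = top % log_probs.shape[1]           # with which token
    return beam_ids, token_ids, flat[top]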
@@ -16,7 +16,9 @@ import logging
import os
import six
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import contextlib
import numpy as np
import paddle
@@ -27,10 +29,11 @@ from utils.check import check_gpu, check_version
# include task-specific libs
import reader
-from model import Transformer, position_encoding_init
+from transformer import InferTransformer, position_encoding_init
-def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
+def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
+                     output_eos=False):
"""
Post-process the decoded sequence.
"""
@@ -47,10 +50,13 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
def do_predict(args):
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
+device_ids = list(range(args.num_devices))
+@contextlib.contextmanager
+def null_guard():
+    yield
+guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator
processor = reader.DataProcessor(fpattern=args.predict_file,
@@ -69,68 +75,61 @@ def do_predict(args):
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
-batch_generator = processor.data_generator(phase="predict", place=place)
+batch_generator = processor.data_generator(phase="predict")
+args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+    args.unk_idx = processor.get_vocab_summary()
trg_idx2word = reader.DataProcessor.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
-args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-    args.unk_idx = processor.get_vocab_summary()
-with fluid.dygraph.guard(place):
+with guard:
# define data loader
-test_loader = fluid.io.DataLoader.from_generator(capacity=10)
-test_loader.set_batch_generator(batch_generator, places=place)
+test_loader = batch_generator
# define model
-transformer = Transformer(
-    args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
-    args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
-    args.d_inner_hid, args.prepostprocess_dropout,
-    args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
-    args.postprocess_cmd, args.weight_sharing, args.bos_idx,
-    args.eos_idx)
+transformer = InferTransformer(args.src_vocab_size,
+                               args.trg_vocab_size,
+                               args.max_length + 1,
+                               args.n_layer,
+                               args.n_head,
+                               args.d_key,
+                               args.d_value,
+                               args.d_model,
+                               args.d_inner_hid,
+                               args.prepostprocess_dropout,
+                               args.attention_dropout,
+                               args.relu_dropout,
+                               args.preprocess_cmd,
+                               args.postprocess_cmd,
+                               args.weight_sharing,
+                               args.bos_idx,
+                               args.eos_idx,
+                               beam_size=args.beam_size,
+                               max_out_len=args.max_out_len)
# load the trained model
assert args.init_from_params, (
"Please set init_from_params to load the infer model.")
-model_dict, _ = fluid.load_dygraph(
-    os.path.join(args.init_from_params, "transformer"))
-# to avoid a longer length than training, reset the size of position
-# encoding to max_length
-model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
-    args.max_length + 1, args.d_model)
-model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
-    args.max_length + 1, args.d_model)
-transformer.load_dict(model_dict)
-# set evaluate mode
-transformer.eval()
+transformer.load(os.path.join(args.init_from_params, "transformer"))
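The deleted lines above rebuilt the position-encoding tables so inference can run up to max_length. Presumably position_encoding_init produces the standard sinusoidal table from "Attention Is All You Need"; the committed version sits in the collapsed transformer.py, so this self-contained NumPy sketch is an assumption, not the repo's code.

import numpy as np

def sinusoid_position_encoding(max_len, d_model):
    # even channels get sin, odd channels get cos; wavelengths grow
    # geometrically with the channel index
    pos = np.arange(max_len, dtype="float64")[:, None]
    div = np.power(10000.0, 2 * (np.arange(d_model) // 2) / d_model)
    table = pos / div
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])
    return table.astype("float32")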
f = open(args.output_file, "wb")
for input_data in test_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_src_attn_bias) = input_data
-finished_seq, finished_scores = transformer.beam_search(
-    src_word,
-    src_pos,
+finished_seq = transformer.test(inputs=(src_word, src_pos,
     src_slf_attn_bias,
     trg_word,
-    trg_src_attn_bias,
-    bos_id=args.bos_idx,
-    eos_id=args.eos_idx,
-    beam_size=args.beam_size,
-    max_len=args.max_out_len)
-finished_seq = finished_seq.numpy()
-finished_scores = finished_scores.numpy()
+    trg_src_attn_bias),
+                                device='gpu',
+                                device_ids=device_ids)[0]
finished_seq = np.transpose(finished_seq, [0, 2, 1])
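# NOTE (assumption): the decoder appears to return beams time-major,
# i.e. [batch, max_len, beam_size]; the transpose above rearranges this
# to [batch, beam_size, max_len] so the loop below can walk beam by beam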
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
if beam_idx >= args.n_best: break
-id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
+id_list = post_process_seq(beam, args.bos_idx,
+                           args.eos_idx)
word_list = [trg_idx2word[id] for id in id_list]
sequence = b" ".join(word_list) + b"\n"
f.write(sequence)
break
if __name__ == "__main__":
......
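The body of post_process_seq is elided above. From its signature and the call site, it truncates a decoded beam at the first eos token and strips the bos/eos markers unless output_bos/output_eos are set. A sketch consistent with that reading (not necessarily the committed code):

def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
                     output_eos=False):
    # cut everything after the first eos
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    # drop bos/eos themselves unless explicitly requested
    return [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]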
@@ -114,7 +114,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
return data_inputs
-def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place):
+def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
"""
Put all padded data needed by the beam-search decoder into a list.
"""
@@ -517,7 +517,7 @@ class DataProcessor(object):
return __impl__
-def data_generator(self, phase, place=None):
+def data_generator(self, phase):
# Any token included in the dict can be used for padding, since the paddings'
# loss will be masked out by weights and thus has no effect on parameter gradients.
src_pad_idx = trg_pad_idx = self._eos_idx
@@ -540,7 +540,7 @@ class DataProcessor(object):
def __for_predict__():
for data in data_reader():
data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
-                                  n_head, place)
+                                  n_head)
yield data_inputs
return __for_train__ if phase == "train" else __for_predict__
......
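prepare_infer_input itself is collapsed; its docstring says it pads a batch for the beam-search decoder. Below is a minimal sketch of that kind of padding (the function name and return layout are assumptions for illustration, not the repo's code):

import numpy as np

def pad_batch_ids(insts, pad_idx):
    # pad every instance to the batch max length and also return the
    # position indices the Transformer needs for position encoding
    max_len = max(len(inst) for inst in insts)
    ids = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64")
    pos = np.array(
        [list(range(len(inst))) + [0] * (max_len - len(inst))
         for inst in insts],
        dtype="int64")
    return ids, pos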
This diff is collapsed.
python -u train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
--validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 1 \
--use_cuda True \
--random_seed 1000 \
--save_step 10 \
--eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit
echo `date`
python -u predict.py \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 64 \
--init_from_params base_model_dygraph/step_100000/ \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
echo `date`
\ No newline at end of file
This diff is collapsed.