Commit b2eb5149, authored by guosheng

Add beam search for Transformer.

Parent: 2a0991e1
@@ -16,7 +16,9 @@ import logging
 import os
 import six
 import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import time
+import contextlib
 import numpy as np
 import paddle
@@ -27,10 +29,11 @@ from utils.check import check_gpu, check_version
 # include task-specific libs
 import reader
-from model import Transformer, position_encoding_init
+from transformer import InferTransformer, position_encoding_init
 
 
-def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
+def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
+                     output_eos=False):
     """
     Post-process the decoded sequence.
     """
@@ -47,10 +50,13 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
 def do_predict(args):
-    if args.use_cuda:
-        place = fluid.CUDAPlace(0)
-    else:
-        place = fluid.CPUPlace()
+    device_ids = list(range(args.num_devices))
+
+    @contextlib.contextmanager
+    def null_guard():
+        yield
+
+    guard = fluid.dygraph.guard() if args.eager_run else null_guard()
 
     # define the data generator
     processor = reader.DataProcessor(fpattern=args.predict_file,
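The null_guard added here lets one code path serve both execution modes: when args.eager_run is set, the body below runs under fluid.dygraph.guard(), and otherwise a do-nothing context manager is entered. A self-contained sketch of the pattern (real_guard is a hypothetical stand-in for fluid.dygraph.guard(), so the example runs without Paddle):

import contextlib

@contextlib.contextmanager
def null_guard():
    # No-op context manager: entering it changes nothing, so code written
    # as `with guard:` works whether or not a real guard is wanted.
    yield

@contextlib.contextmanager
def real_guard():
    # Hypothetical stand-in for fluid.dygraph.guard() in the diff above.
    print("entering eager mode")
    yield
    print("leaving eager mode")

eager_run = False
guard = real_guard() if eager_run else null_guard()
with guard:
    pass  # model construction and the predict loop would go here

On Python 3.7+, contextlib.nullcontext() provides the same no-op behavior without the hand-rolled helper.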
@@ -69,68 +75,61 @@ def do_predict(args):
                                      unk_mark=args.special_token[2],
                                      max_length=args.max_length,
                                      n_head=args.n_head)
-    batch_generator = processor.data_generator(phase="predict", place=place)
+    batch_generator = processor.data_generator(phase="predict")
     args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
         args.unk_idx = processor.get_vocab_summary()
     trg_idx2word = reader.DataProcessor.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)
-    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-        args.unk_idx = processor.get_vocab_summary()
 
-    with fluid.dygraph.guard(place):
+    with guard:
         # define data loader
-        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
-        test_loader.set_batch_generator(batch_generator, places=place)
+        test_loader = batch_generator
 
         # define model
-        transformer = Transformer(
-            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
-            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
-            args.d_inner_hid, args.prepostprocess_dropout,
-            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
-            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
-            args.eos_idx)
+        transformer = InferTransformer(args.src_vocab_size,
+                                       args.trg_vocab_size,
+                                       args.max_length + 1,
+                                       args.n_layer,
+                                       args.n_head,
+                                       args.d_key,
+                                       args.d_value,
+                                       args.d_model,
+                                       args.d_inner_hid,
+                                       args.prepostprocess_dropout,
+                                       args.attention_dropout,
+                                       args.relu_dropout,
+                                       args.preprocess_cmd,
+                                       args.postprocess_cmd,
+                                       args.weight_sharing,
+                                       args.bos_idx,
+                                       args.eos_idx,
+                                       beam_size=args.beam_size,
+                                       max_out_len=args.max_out_len)
 
         # load the trained model
         assert args.init_from_params, (
             "Please set init_from_params to load the infer model.")
-        model_dict, _ = fluid.load_dygraph(
-            os.path.join(args.init_from_params, "transformer"))
-        # to avoid a longer length than training, reset the size of position
-        # encoding to max_length
-        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
-            args.max_length + 1, args.d_model)
-        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
-            args.max_length + 1, args.d_model)
-        transformer.load_dict(model_dict)
-
-        # set evaluate mode
-        transformer.eval()
+        transformer.load(os.path.join(args.init_from_params, "transformer"))
 
         f = open(args.output_file, "wb")
         for input_data in test_loader():
             (src_word, src_pos, src_slf_attn_bias, trg_word,
              trg_src_attn_bias) = input_data
-            finished_seq, finished_scores = transformer.beam_search(
-                src_word,
-                src_pos,
-                src_slf_attn_bias,
-                trg_word,
-                trg_src_attn_bias,
-                bos_id=args.bos_idx,
-                eos_id=args.eos_idx,
-                beam_size=args.beam_size,
-                max_len=args.max_out_len)
-            finished_seq = finished_seq.numpy()
-            finished_scores = finished_scores.numpy()
+            finished_seq = transformer.test(inputs=(src_word, src_pos,
+                                                    src_slf_attn_bias,
+                                                    trg_src_attn_bias),
+                                            device='gpu',
+                                            device_ids=device_ids)[0]
+            finished_seq = np.transpose(finished_seq, [0, 2, 1])
             for ins in finished_seq:
                 for beam_idx, beam in enumerate(ins):
                     if beam_idx >= args.n_best: break
-                    id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
+                    id_list = post_process_seq(beam, args.bos_idx,
+                                               args.eos_idx)
                     word_list = [trg_idx2word[id] for id in id_list]
                     sequence = b" ".join(word_list) + b"\n"
                     f.write(sequence)
+            break
 
 
 if __name__ == "__main__":
...
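Two notes on the new predict loop. First, transformer.test returns the finished beams, and the np.transpose call rearranges them so the inner n_best loop can iterate beams per instance. Second, the removed block used position_encoding_init to re-seed the position-encoding weights up to max_length, per its own comment, so inference could handle inputs longer than those seen in training; that reset is presumably folded into the InferTransformer.load path now. For reference, a sketch of the standard sinusoidal table such a helper computes, assuming the usual Transformer formula and an even d_model (the repo's implementation may differ in dtype or layout):

import numpy as np

def position_encoding_init(n_position, d_model):
    """Sinusoidal position encoding:
    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    Assumes d_model is even."""
    pos = np.arange(n_position, dtype="float64")[:, None]
    dim = np.arange(d_model // 2, dtype="float64")[None, :]
    angles = pos / np.power(10000.0, 2.0 * dim / d_model)
    table = np.zeros((n_position, d_model))
    table[:, 0::2] = np.sin(angles)  # even dimensions
    table[:, 1::2] = np.cos(angles)  # odd dimensions
    return table.astype("float32")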
@@ -114,7 +114,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
     return data_inputs
 
 
-def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place):
+def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
     """
     Put all padded data needed by beam search decoder into a list.
     """
@@ -517,7 +517,7 @@ class DataProcessor(object):
         return __impl__
 
-    def data_generator(self, phase, place=None):
+    def data_generator(self, phase):
         # Any token included in dict can be used to pad, since the paddings' loss
         # will be masked out by weights and make no effect on parameter gradients.
         src_pad_idx = trg_pad_idx = self._eos_idx
@@ -540,7 +540,7 @@ class DataProcessor(object):
         def __for_predict__():
             for data in data_reader():
                 data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
-                                                  n_head, place)
+                                                  n_head)
                 yield data_inputs
 
         return __for_train__ if phase == "train" else __for_predict__
...
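With the place argument gone, data_generator returns a plain generator function that yields numpy batches, and device placement moves downstream into the model's test path. A simplified sketch of this return-a-generator-function pattern (hypothetical padding logic, not the repo's reader):

import numpy as np

def data_generator(batches, pad_idx, phase="predict"):
    # Return a generator *function*; callers invoke it to iterate,
    # mirroring `for input_data in test_loader():` in the predict script.
    def __for_predict__():
        for batch in batches:
            max_len = max(len(seq) for seq in batch)
            padded = np.array(
                [seq + [pad_idx] * (max_len - len(seq)) for seq in batch],
                dtype="int64")
            yield (padded,)
    return __for_predict__

# usage: one batch of two token-id sequences, padded to a common length
gen = data_generator([[[2, 5, 7], [2, 9]]], pad_idx=1)
for (src_word,) in gen():
    print(src_word)  # [[2 5 7] [2 9 1]]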
This diff is collapsed.
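The collapsed diff above presumably holds the new InferTransformer with the beam search decoder that the commit title announces. As a rough, framework-free illustration of the technique (a generic sketch, not the repo's implementation), length-bounded beam search over a step function looks like this:

import math

def beam_search(step_fn, bos_id, eos_id, beam_size, max_len):
    # step_fn(prefix) -> {token_id: log_prob} for the next position.
    beams = [([bos_id], 0.0)]  # (token prefix, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, score in beams:
            for tok, logp in step_fn(prefix).items():
                candidates.append((prefix + [tok], score + logp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_size]:
            # hypotheses ending in eos are done; others stay live
            (finished if prefix[-1] == eos_id else beams).append(
                (prefix, score))
        if not beams:
            break
    finished.extend(beams)  # fall back to unfinished hypotheses
    return sorted(finished, key=lambda f: f[1], reverse=True)

# toy usage: fixed next-token distribution, token 2 is eos
dist = {0: math.log(0.5), 1: math.log(0.3), 2: math.log(0.2)}
print(beam_search(lambda prefix: dist, bos_id=0, eos_id=2,
                  beam_size=2, max_len=4)[0])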
python -u train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
--validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 1 \
--use_cuda True \
--random_seed 1000 \
--save_step 10 \
--eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit
echo `date`
python -u predict.py \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 64 \
--init_from_params base_model_dygraph/step_100000/ \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
echo `date`
\ No newline at end of file
This diff is collapsed.