提交 ab1b1de9 编写于 作者: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/models into add-transformer-ppl

......@@ -34,6 +34,11 @@ class InferTaskConfig(object):
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
......@@ -59,6 +64,8 @@ class ModelHyperParams(object):
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# position value corresponding to the <pad> token.
pos_pad_idx = 0
......
......@@ -11,10 +11,25 @@ from config import InferTaskConfig, ModelHyperParams, \
from train import pad_batch_data
def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
bos_idx, eos_idx):
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
......@@ -48,7 +63,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# size of feeded batch is changing.
beam_map = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
......@@ -60,7 +75,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq
# Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
......@@ -114,8 +130,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words = np.array(
[
beam_backtrace(
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
......@@ -167,6 +182,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
predict = (predict_all[inst_idx, :, :]
......@@ -187,7 +204,10 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist()
......@@ -254,17 +274,47 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
exe, [item[0] for item in data],
encoder_program,
encoder_input_data_names, [enc_output.name],
decoder_program,
decoder_input_data_names, [predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册