Unverified commit c2de925f, authored by Guo Sheng, committed by GitHub

Merge pull request #809 from guoshengCS/refine-transformer-token

Refine the inference to output special tokens optionally in Transformer
@@ -31,6 +31,11 @@ class InferTaskConfig(object):
     # the number of decoded sentences to output.
     n_best = 1
+    # the flags indicating whether to output the special tokens.
+    output_bos = False
+    output_eos = False
+    output_unk = False
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
@@ -56,6 +61,8 @@ class ModelHyperParams(object):
     bos_idx = 0
     # index for <eos> token
     eos_idx = 1
+    # index for <unk> token
+    unk_idx = 2
     # position value corresponding to the <pad> token.
     pos_pad_idx = 0
...
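Note that the three new flags are not handled uniformly by this patch: `output_unk` takes effect during beam search itself (by masking `<unk>` scores, as shown further down), while `output_bos` and `output_eos` are applied afterwards in `post_process_seq`. As a quick mental model, a single post-hoc filter with the same observable semantics could look like the following hypothetical helper (`strip_special` and its defaults are illustrative only, not part of the patch):

```python
# Hypothetical helper illustrating the semantics of the three flags;
# indices mirror ModelHyperParams (bos=0, eos=1, unk=2).
def strip_special(ids, output_bos=False, output_eos=False, output_unk=False,
                  bos_idx=0, eos_idx=1, unk_idx=2):
    drop = set()
    if not output_bos:
        drop.add(bos_idx)
    if not output_eos:
        drop.add(eos_idx)
    if not output_unk:
        drop.add(unk_idx)
    return [idx for idx in ids if idx not in drop]

print(strip_special([0, 5, 2, 7, 1]))  # -> [5, 7]
```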
@@ -11,10 +11,25 @@ from config import InferTaskConfig, ModelHyperParams, \
 from train import pad_batch_data
-def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
-                    decoder, dec_in_names, dec_out_names, beam_size, max_length,
-                    n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
-                    bos_idx, eos_idx):
+def translate_batch(exe,
+                    src_words,
+                    encoder,
+                    enc_in_names,
+                    enc_out_names,
+                    decoder,
+                    dec_in_names,
+                    dec_out_names,
+                    beam_size,
+                    max_length,
+                    n_best,
+                    batch_size,
+                    n_head,
+                    src_pad_idx,
+                    trg_pad_idx,
+                    bos_idx,
+                    eos_idx,
+                    unk_idx,
+                    output_unk=True):
""" """
Run the encoder program once and run the decoder program multiple times to Run the encoder program once and run the decoder program multiple times to
implement beam search externally. implement beam search externally.
@@ -48,7 +63,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
     # size of feeded batch is changing.
     beam_map = range(batch_size)
-    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
+    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
         """
         Decode and select n_best sequences for one instance by backtrace.
         """
@@ -60,7 +75,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
                 seq.append(next_ids[j][k])
                 k = prev_branchs[j][k]
             seq = seq[::-1]
-            seq = [bos_idx] + seq if add_bos else seq
+            # Add the <bos>, since next_ids don't include the <bos>.
+            seq = [bos_idx] + seq
             seqs.append(seq)
         return seqs
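For context, `beam_backtrace` recovers sequences from the per-step bookkeeping: `next_ids[j][k]` is the token emitted into beam slot `k` at step `j`, and `prev_branchs[j][k]` is the slot at step `j-1` that it extended. A toy re-run of the same walk on hand-built arrays (all data below is made up for illustration):

```python
# Toy backtrace with made-up bookkeeping for 2 beams over 2 steps.
# Step 1's slot 0 extends step 0's slot 1, and slot 1 extends slot 0.
prev_branchs = [[0, 0], [1, 0]]
next_ids = [[5, 6], [7, 8]]
bos_idx = 0

seqs = []
for i in range(2):  # n_best = beam_size = 2
    k, seq = i, []
    for j in range(len(prev_branchs) - 1, -1, -1):
        seq.append(next_ids[j][k])  # token taken at step j
        k = prev_branchs[j][k]      # hop to the parent slot
    seqs.append([bos_idx] + seq[::-1])

print(seqs)  # -> [[0, 6, 7], [0, 5, 8]]
```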
@@ -114,8 +130,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
         trg_words = np.array(
             [
-                beam_backtrace(
-                    prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
+                beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
                 for beam_idx in active_beams
             ],
             dtype="int64")
@@ -167,6 +182,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         predict_all = (predict_all + scores[beam_map].reshape(
             [len(beam_map) * beam_size, -1])).reshape(
                 [len(beam_map), beam_size, -1])
+        if not output_unk:  # To exclude the <unk> token.
+            predict_all[:, :, unk_idx] = -1e9
         active_beams = []
         for inst_idx, beam_idx in enumerate(beam_map):
             predict = (predict_all[inst_idx, :, :]
@@ -187,7 +204,10 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
     # Decode beams and select n_best sequences for each instance by backtrace.
-    seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
+    seqs = [
+        beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
+        for beam_idx in range(batch_size)
+    ]
     return seqs, scores[:, :n_best].tolist()
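The `-1e9` assignment is the standard way to veto a token in beam search: scores are accumulated log-probabilities, so a huge negative value can never survive top-k selection. A small NumPy sketch of the same masking, with made-up shapes and values:

```python
import numpy as np

# Made-up accumulated scores: 1 active instance, 2 beams, vocab of 4.
unk_idx = 2
predict_all = np.log(np.array([[[0.1, 0.2, 0.6, 0.1],
                                [0.2, 0.1, 0.5, 0.2]]]))
# Veto <unk> exactly as the patch does; -1e9 can never win top-k.
predict_all[:, :, unk_idx] = -1e9
print(predict_all.argmax(axis=-1))  # -> [[1 0]]; index 2 is never picked
```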
@@ -254,17 +274,47 @@ def main():
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    def post_process_seq(seq,
+                         bos_idx=ModelHyperParams.bos_idx,
+                         eos_idx=ModelHyperParams.eos_idx,
+                         output_bos=InferTaskConfig.output_bos,
+                         output_eos=InferTaskConfig.output_eos):
+        """
+        Post-process the beam-search decoded sequence. Truncate from the first
+        <eos> and remove the <bos> and <eos> tokens currently.
+        """
+        eos_pos = len(seq) - 1
+        for i, idx in enumerate(seq):
+            if idx == eos_idx:
+                eos_pos = i
+                break
+        seq = seq[:eos_pos + 1]
+        return filter(
+            lambda idx: (output_bos or idx != bos_idx) and \
+                        (output_eos or idx != eos_idx),
+            seq)
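Concretely, with `bos_idx=0`, `eos_idx=1`, and both output flags off, `[0, 12, 7, 1, 9, 9]` becomes `[12, 7]`: truncate after the first `<eos>`, then drop `<bos>` and `<eos>`. A self-contained restatement of that logic for a quick check (the list comprehension stands in for the patch's Python 2 `filter` call):

```python
# Standalone check of post_process_seq's semantics; indices mirror
# ModelHyperParams (bos=0, eos=1), both output flags assumed False.
def post_process(seq, bos_idx=0, eos_idx=1):
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = seq[:eos_pos + 1]  # truncate from the first <eos>
    return [idx for idx in seq if idx not in (bos_idx, eos_idx)]

print(post_process([0, 12, 7, 1, 9, 9]))  # -> [12, 7]
```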
     for batch_id, data in enumerate(test_data()):
         batch_seqs, batch_scores = translate_batch(
-            exe, [item[0] for item in data], encoder_program,
-            encoder_input_data_names, [enc_output.name], decoder_program,
-            decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
-            InferTaskConfig.max_length, InferTaskConfig.n_best,
-            len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
-            ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
-            ModelHyperParams.eos_idx)
+            exe, [item[0] for item in data],
+            encoder_program,
+            encoder_input_data_names, [enc_output.name],
+            decoder_program,
+            decoder_input_data_names, [predict.name],
+            InferTaskConfig.beam_size,
+            InferTaskConfig.max_length,
+            InferTaskConfig.n_best,
+            len(data),
+            ModelHyperParams.n_head,
+            ModelHyperParams.src_pad_idx,
+            ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.bos_idx,
+            ModelHyperParams.eos_idx,
+            ModelHyperParams.unk_idx,
+            output_unk=InferTaskConfig.output_unk)
         for i in range(len(batch_seqs)):
-            seqs = batch_seqs[i]
+            # Post-process the beam-search decoded sequences.
+            seqs = map(post_process_seq, batch_seqs[i])
             scores = batch_scores[i]
             for seq in seqs:
                 print(" ".join([trg_idx2word[idx] for idx in seq]))
...
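One portability note: the patch assumes Python 2, where `map` and `filter` return lists. Under Python 3 both return one-shot iterators, which happens to work here because each sequence is consumed exactly once, but wrapping in `list(...)` keeps the behavior identical. A minimal sketch of the Python 3 spelling, with made-up data and a stand-in post-processor:

```python
# Under Python 3 the patched line would read:
#     seqs = list(map(post_process_seq, batch_seqs[i]))
batch_seqs_i = [[0, 12, 7, 1], [0, 9, 1]]           # made-up decoded ids
post = lambda s: [t for t in s if t not in (0, 1)]  # stand-in post-process
seqs = list(map(post, batch_seqs_i))
print(seqs)  # -> [[12, 7], [9]]
```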