Commit 3e9fccea authored by G guosheng

Align the outputs of fast_infer with the original Python infer in Transformer

Parent 82ba5c03
```diff
@@ -42,9 +42,9 @@ class InferTaskConfig(object):
     # the number of decoded sentences to output.
     n_best = 1
     # the flags indicating whether to output the special tokens.
-    output_bos = False
-    output_eos = False
-    output_unk = False
+    output_bos = True  # False
+    output_eos = True  # False
+    output_unk = True  # False
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
```
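Setting these flags to True makes the decoded output keep the start, end, and unknown tokens, so fast_infer and the Python infer can be compared token for token. For context, a minimal sketch of how such flags are typically consumed when mapping ids back to words (the function and its arguments are illustrative, not code from this repo):

```python
def ids_to_words(ids, idx2word, bos_idx, eos_idx, unk_idx,
                 output_bos=True, output_eos=True, output_unk=True):
    """Convert token ids to a sentence, dropping special tokens
    unless the corresponding output_* flag keeps them."""
    special = {bos_idx: output_bos, eos_idx: output_eos, unk_idx: output_unk}
    return " ".join(idx2word[i] for i in ids if special.get(i, True))
```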
```diff
@@ -275,11 +275,11 @@ def translate_batch(exe,
         top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
         top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
             -1]]
-        top_scores_ids = np.asarray(
-            sorted(
-                top_scores_ids,
-                lambda x, y: x / predict_all.shape[-1] - y / predict_all.shape[-1]
-            ))  # sort by pre_branch and score to compare with fast_infer
+        # top_scores_ids = np.asarray(
+        #     sorted(
+        #         top_scores_ids,
+        #         lambda x, y: x / predict_all.shape[-1] - y / predict_all.shape[-1]
+        #     ))  # sort by pre_branch and score to compare with fast_infer
         top_scores = predict[top_scores_ids]
         scores[beam_idx] = top_scores
         prev_branchs[beam_idx].append(top_scores_ids /
```
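The surviving two lines above implement a linear-time top-k: `np.argpartition` first isolates the `beam_size` largest entries, and a small `argsort` then orders just those. A self-contained NumPy illustration (the data here is made up):

```python
import numpy as np

predict = np.array([0.1, 0.7, 0.3, 0.9, 0.2, 0.5])  # flattened beam scores
beam_size = 3

# argpartition places the beam_size largest scores in the last slots
# (in arbitrary order), without fully sorting the array.
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
# argsort over just those beam_size entries, reversed for descending order.
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::-1]]

print(top_scores_ids)           # [3 1 5]
print(predict[top_scores_ids])  # [0.9 0.7 0.5]
```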
```diff
@@ -368,6 +368,7 @@ def infer(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
         clip_last_batch=False)
     trg_idx2word = test_data.load_dict(
```
```diff
@@ -394,6 +395,8 @@
             seq)
     for batch_id, data in enumerate(test_data.batch_generator()):
+        if batch_id != 0:
+            continue
         batch_seqs, batch_scores = translate_batch(
             exe,
             [item[0] for item in data],
```
```diff
@@ -422,6 +425,8 @@
             scores = batch_scores[i]
             for seq in seqs:
                 print(" ".join([trg_idx2word[idx] for idx in seq]))
+            print scores
+            exit(0)
 
 
 def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
```
```diff
@@ -522,12 +527,15 @@ def fast_infer(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
         clip_last_batch=False)
     trg_idx2word = test_data.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)
     for batch_id, data in enumerate(test_data.batch_generator()):
+        if batch_id != 0:
+            continue
         data_input = prepare_batch_input(
             data, encoder_data_input_fields + fast_decoder_data_input_fields,
             encoder_util_input_fields + fast_decoder_util_input_fields,
```
```diff
@@ -540,6 +548,7 @@ def fast_infer(args):
         # print np.array(seq_ids)#, np.array(seq_scores)
         # print seq_ids.lod()#, seq_scores.lod()
         hyps = [[] for i in range(len(data))]
+        scores = [[] for i in range(len(data))]
         for i in range(len(seq_ids.lod()[0]) - 1):  # for each source sentence
             start = seq_ids.lod()[0][i]
             end = seq_ids.lod()[0][i + 1]
@@ -550,8 +559,11 @@
                     trg_idx2word[idx]
                     for idx in np.array(seq_ids)[sub_start:sub_end]
                 ]))
+                scores[i].append(np.array(seq_scores)[sub_end - 1])
             print hyps[i]
+            print scores[i]
             print len(hyps[i]), [len(hyp.split()) for hyp in hyps[i]]
+        exit(0)
 
 
 if __name__ == "__main__":
```
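Together with the `if batch_id != 0: continue` and `exit(0)` guards, these prints limit both decoders to the first batch so hypotheses and scores can be compared side by side. `seq_ids` comes back with a two-level LoD: level 0 groups beam hypotheses per source sentence, level 1 marks each hypothesis's token span in the flat tensor, which is why the accumulated score is read at `sub_end - 1` (the last token). A plain-Python sketch of walking such offsets, with made-up data:

```python
# Illustrative two-level LoD offsets, mimicking seq_ids.lod():
# level 0 groups beam hypotheses per source sentence,
# level 1 marks token spans of each hypothesis in the flat data array.
lod = [[0, 2, 4],          # 2 source sentences, 2 hypotheses each
       [0, 3, 7, 9, 12]]   # token offsets of the 4 hypotheses
data = list(range(12))     # flat token ids

for i in range(len(lod[0]) - 1):            # for each source sentence
    start, end = lod[0][i], lod[0][i + 1]
    for j in range(start, end):             # for each candidate hypothesis
        sub_start, sub_end = lod[1][j], lod[1][j + 1]
        print(i, data[sub_start:sub_end])   # tokens of this hypothesis
```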
```diff
@@ -123,15 +123,15 @@ def multi_head_attention(queries,
             act="softmax")
         weights = layers.reshape(
             x=weights, shape=product.shape, actual_shape=post_softmax_shape)
-        global FLAG
-        if FLAG:
-            print "hehehehehe"
-            layers.Print(scaled_q)
-            layers.Print(k)
-            layers.Print(v)
-            layers.Print(product)
-            layers.Print(weights)
-            FLAG = False
+        # global FLAG
+        # if FLAG:
+        #     print "hehehehehe"
+        #     layers.Print(scaled_q)
+        #     layers.Print(k)
+        #     layers.Print(v)
+        #     layers.Print(product)
+        #     layers.Print(weights)
+        #     FLAG = False
         if dropout_rate:
             weights = layers.dropout(
                 weights, dropout_prob=dropout_rate, is_test=False)
```
```diff
@@ -694,7 +694,7 @@ def fast_decode(
                     src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
                 enc_output=pre_enc_output,
                 caches=pre_caches)
-            layers.Print(logits)
+            # layers.Print(logits)
             topk_scores, topk_indices = layers.topk(logits, k=beam_size)
             # layers.Print(topk_scores)
             # layers.Print(topk_indices)
```
```diff
@@ -708,6 +708,7 @@
             topk_indices = layers.lod_reset(topk_indices, pre_ids)
             selected_ids, selected_scores = layers.beam_search(
                 pre_ids=pre_ids,
+                pre_scores=pre_scores,
                 ids=topk_indices,
                 scores=accu_scores,
                 beam_size=beam_size,
```
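Passing `pre_scores` gives `beam_search` the scores accumulated up to the previous step, which it needs to rank candidates consistently with the Python decoder. Conceptually, the per-step accumulation behind `accu_scores` looks like this (a plain-NumPy sketch of the idea, not the fluid op itself):

```python
import numpy as np

def accumulate(pre_scores, topk_log_probs):
    # Each live beam's new candidate scores are its accumulated score
    # so far plus the log-probabilities of its top-k next words.
    return pre_scores[:, None] + topk_log_probs

pre_scores = np.array([-1.2, -2.3])                 # 2 live beams
topk_log_probs = np.log([[0.6, 0.3], [0.5, 0.4]])   # top-2 words per beam
print(accumulate(pre_scores, topk_log_probs))
```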
```diff
@@ -735,12 +736,16 @@
                     y=attn_post_softmax_shape_delta),
                 slf_attn_post_softmax_shape)
 
-            max_len_cond = layers.less_than(x=step_idx, y=max_len)
-            all_finish_cond = layers.less_than(x=step_idx, y=max_len)
-            layers.logical_or(x=max_len_cond, y=all_finish_cond, out=cond)
+            length_cond = layers.less_than(x=step_idx, y=max_len)
+            finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
+            # layers.Print(length_cond)
+            # layers.Print(finish_cond)
+            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
+            layers.Print(step_idx)
 
-        finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
-                                                                  eos_idx)
+        # finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
+        #                                                           eos_idx)
+        finished_ids, finished_scores = layers.beam_search_decode(
+            ids, scores, beam_size=beam_size, end_id=eos_idx)
         return finished_ids, finished_scores
 
     finished_ids, finished_scores = beam_search()
```
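The old termination computed the same `less_than` twice, so the `logical_or` never reflected whether any beam was still alive; the new condition keeps looping only while the step count is below `max_len` and `selected_ids` is non-empty. In plain Python the new predicate amounts to (illustrative stand-ins for the fluid ops):

```python
def should_continue(step_idx, max_len, selected_ids):
    length_cond = step_idx < max_len          # layers.less_than
    finish_cond = len(selected_ids) > 0       # logical_not(is_empty)
    return length_cond and finish_cond        # layers.logical_and -> cond
```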
```diff
@@ -288,6 +288,7 @@ def train(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
         clip_last_batch=False)
 
     train_data = read_multiple(reader=train_data.batch_generator)
@@ -315,6 +316,7 @@
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
         clip_last_batch=False,
         shuffle=False,
         shuffle_batch=False)
```