Commit 3e9fccea authored by guosheng

Align outputs between fast_infer and the original Python infer in Transformer

Parent 82ba5c03
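
The changes below are debugging scaffolding for that alignment: both inference paths are pinned to the first batch, special tokens (BOS/EOS/UNK) are kept in the output, and per-hypothesis scores are printed so the two decoders can be compared directly. A minimal sketch of the comparison this enables, assuming each path yields a (hypotheses, scores) pair per source sentence (the function and argument names here are illustrative, not part of the diff):

import numpy as np

def check_alignment(python_infer_outputs, fast_infer_outputs, atol=1e-5):
    # Each argument is a list of (list_of_sentences, list_of_scores),
    # one entry per source sentence; shapes are assumptions for illustration.
    for (py_hyps, py_scores), (ft_hyps, ft_scores) in zip(
            python_infer_outputs, fast_infer_outputs):
        assert py_hyps == ft_hyps, "decoded token sequences differ"
        assert np.allclose(py_scores, ft_scores, atol=atol), "scores differ"
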
......@@ -42,9 +42,9 @@ class InferTaskConfig(object):
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
output_bos = True  # False
output_eos = True  # False
output_unk = True  # False
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
......
......@@ -275,11 +275,11 @@ def translate_batch(exe,
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::-1]]
top_scores_ids = np.asarray(
sorted(
top_scores_ids,
lambda x, y: x / predict_all.shape[-1] - y / predict_all.shape[-1]
)) # sort by pre_branch and score to compare with fast_infer
# top_scores_ids = np.asarray(
# sorted(
# top_scores_ids,
# lambda x, y: x / predict_all.shape[-1] - y / predict_all.shape[-1]
# )) # sort by pre_branch and score to compare with fast_infer
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
......@@ -368,6 +368,7 @@ def infer(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
......@@ -394,6 +395,8 @@ def infer(args):
seq)
for batch_id, data in enumerate(test_data.batch_generator()):
if batch_id != 0:
continue
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
......@@ -422,6 +425,8 @@ def infer(args):
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
print scores
exit(0)
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
......@@ -522,12 +527,15 @@ def fast_infer(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
for batch_id, data in enumerate(test_data.batch_generator()):
if batch_id != 0:
continue
data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields,
encoder_util_input_fields + fast_decoder_util_input_fields,
......@@ -540,6 +548,7 @@ def fast_infer(args):
# print np.array(seq_ids)#, np.array(seq_scores)
# print seq_ids.lod()#, seq_scores.lod()
hyps = [[] for i in range(len(data))]
scores = [[] for i in range(len(data))]
for i in range(len(seq_ids.lod()[0]) - 1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
......@@ -550,8 +559,11 @@ def fast_infer(args):
trg_idx2word[idx]
for idx in np.array(seq_ids)[sub_start:sub_end]
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print hyps[i]
print scores[i]
print len(hyps[i]), [len(hyp.split()) for hyp in hyps[i]]
exit(0)
if __name__ == "__main__":
......
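
The fast_infer hunk above unpacks beam-search results from a two-level LoD tensor: level 0 groups hypotheses by source sentence and level 1 delimits the tokens of each hypothesis, with the accumulated score read off at the last token. A standalone sketch of that traversal with made-up data (the LoD values and idx2word vocabulary are hypothetical):

import numpy as np

lod = [[0, 2, 4], [0, 3, 7, 9, 12]]             # 2 sentences, 2 hypotheses each
flat_ids = np.array([5, 8, 2, 5, 9, 8, 2, 7, 2, 7, 9, 2])
flat_scores = np.linspace(-1.0, -12.0, num=12)  # made-up accumulated scores
idx2word = {2: "<e>", 5: "a", 7: "c", 8: "b", 9: "d"}  # hypothetical vocab

hyps = [[] for _ in range(len(lod[0]) - 1)]
scores = [[] for _ in range(len(lod[0]) - 1)]
for i in range(len(lod[0]) - 1):               # for each source sentence
    for j in range(lod[0][i], lod[0][i + 1]):  # for each of its hypotheses
        sub_start, sub_end = lod[1][j], lod[1][j + 1]
        hyps[i].append(" ".join(
            idx2word[idx] for idx in flat_ids[sub_start:sub_end]))
        # the hypothesis score is the accumulated score at its last token
        scores[i].append(flat_scores[sub_end - 1])
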
......@@ -123,15 +123,15 @@ def multi_head_attention(queries,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
global FLAG
if FLAG:
print "hehehehehe"
layers.Print(scaled_q)
layers.Print(k)
layers.Print(v)
layers.Print(product)
layers.Print(weights)
FLAG = False
# global FLAG
# if FLAG:
# print "hehehehehe"
# layers.Print(scaled_q)
# layers.Print(k)
# layers.Print(v)
# layers.Print(product)
# layers.Print(weights)
# FLAG = False
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
......@@ -694,7 +694,7 @@ def fast_decode(
src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
enc_output=pre_enc_output,
caches=pre_caches)
layers.Print(logits)
# layers.Print(logits)
topk_scores, topk_indices = layers.topk(logits, k=beam_size)
# layers.Print(topk_scores)
# layers.Print(topk_indices)
......@@ -708,6 +708,7 @@ def fast_decode(
topk_indices = layers.lod_reset(topk_indices, pre_ids)
selected_ids, selected_scores = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_scores,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
......@@ -735,12 +736,16 @@ def fast_decode(
y=attn_post_softmax_shape_delta),
slf_attn_post_softmax_shape)
max_len_cond = layers.less_than(x=step_idx, y=max_len)
all_finish_cond = layers.less_than(x=step_idx, y=max_len)
layers.logical_or(x=max_len_cond, y=all_finish_cond, out=cond)
finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
eos_idx)
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
# layers.Print(length_cond)
# layers.Print(finish_cond)
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
layers.Print(step_idx)
# finished_ids, finished_scores = layers.beam_search_decode(ids, scores,
# eos_idx)
finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=beam_size, end_id=eos_idx)
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
......
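
The key fix in fast_decode above is the loop-exit condition: the removed code OR-ed two copies of the same length test, so the while-loop could never stop early once every beam entry had emitted EOS, whereas the new code AND-s the length test with a check that beam_search still returned live candidates. A plain-Python analogue of the corrected condition (the function below is an illustration, not the Fluid API):

def should_continue(step_idx, max_len, selected_ids):
    # Decode another step only while under the length budget AND the beam
    # still holds unfinished candidates; mirrors
    # layers.logical_and(x=length_cond, y=finish_cond, out=cond) above.
    length_cond = step_idx < max_len     # less_than(step_idx, max_len)
    finish_cond = len(selected_ids) > 0  # logical_not(is_empty(selected_ids))
    return length_cond and finish_cond
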
......@@ -288,6 +288,7 @@ def train(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False)
train_data = read_multiple(reader=train_data.batch_generator)
......@@ -315,6 +316,7 @@ def train(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=ModelHyperParams.max_length,
clip_last_batch=False,
shuffle=False,
shuffle_batch=False)
......
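
The max_length=ModelHyperParams.max_length argument threaded into the data readers above presumably filters out sequences longer than the model's position-encoding table at loading time. A minimal sketch of such a filter, assuming a reader that yields (source, target) token lists (the generator below is hypothetical, not the DataReader API):

def filter_by_max_length(sentence_pairs, max_length):
    # Drop pairs whose source or target exceeds max_length tokens.
    for src, trg in sentence_pairs:
        if len(src) <= max_length and len(trg) <= max_length:
            yield src, trg
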