Unverified commit 5bc68857, authored by Guo Sheng, committed via GitHub

Merge pull request #2005 from guoshengCS/fix-transformer-core

Remove core api in Transformer
......@@ -281,10 +281,10 @@ def fast_infer(args):
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=True)
seq_ids_list, seq_scores_list = [seq_ids], [
seq_scores
] if isinstance(
seq_ids, paddle.fluid.core.LoDTensor) else (seq_ids, seq_scores)
seq_ids_list, seq_scores_list = [
seq_ids
], [seq_scores] if isinstance(
seq_ids, paddle.fluid.LoDTensor) else (seq_ids, seq_scores)
for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
# How to parse the results:
# Suppose the lod of seq_ids is:
......
......@@ -875,7 +875,7 @@ def fast_decode(src_vocab_size,
accu_scores = layers.elementwise_add(
x=layers.log(topk_scores), y=pre_scores, axis=0)
# beam_search op uses lod to differentiate branches.
topk_indices = layers.lod_reset(topk_indices, pre_ids)
topk_indices = layers.lod_reset(accu_scores, pre_ids)
# topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction)
selected_ids, selected_scores, gather_idx = layers.beam_search(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.