未验证 提交 6590b710 编写于 作者: G Guo Sheng 提交者: GitHub

To avoid the error when use_program_cache is True in Transformer inference. (#2485)

To avoid the error when use_program_cache is True in Transformer infe…rence.

To avoid the data decode error.
上级 41d194cc
......@@ -281,11 +281,11 @@ def fast_infer(args):
feed=feed_dict_list[0]
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=True)
use_program_cache=False)
seq_ids_list, seq_scores_list = [seq_ids], [
seq_scores
] if isinstance(
seq_ids, paddle.fluid.LoDTensor) else (seq_ids, seq_scores)
] if isinstance(seq_ids,
paddle.fluid.LoDTensor) else (seq_ids, seq_scores)
for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
# How to parse the results:
# Suppose the lod of seq_ids is:
......
......@@ -266,7 +266,7 @@ class DataReader(object):
with open(fpath, "rb") as f:
for line in f:
if six.PY3:
line = line.decode()
line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
......
......@@ -280,7 +280,7 @@ def fast_infer(args):
feed=feed_dict_list[0]
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=True)
use_program_cache=False)
seq_ids_list, seq_scores_list = [
seq_ids
], [seq_scores] if isinstance(
......
......@@ -875,7 +875,7 @@ def fast_decode(src_vocab_size,
accu_scores = layers.elementwise_add(
x=layers.log(topk_scores), y=pre_scores, axis=0)
# beam_search op uses lod to differentiate branches.
topk_indices = layers.lod_reset(accu_scores, pre_ids)
accu_scores = layers.lod_reset(accu_scores, pre_ids)
# topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction)
selected_ids, selected_scores, gather_idx = layers.beam_search(
......
......@@ -266,7 +266,7 @@ class DataReader(object):
with open(fpath, "rb") as f:
for line in f:
if six.PY3:
line = line.decode()
line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册