未验证 提交 6590b710 编写于 作者: G Guo Sheng 提交者: GitHub

To avoid the error when use_program_cache is True in Transformer inference. (#2485)

To avoid the error when use_program_cache is True in Transformer infe…rence.

To avoid the data decode error.
上级 41d194cc
...@@ -281,11 +281,11 @@ def fast_infer(args): ...@@ -281,11 +281,11 @@ def fast_infer(args):
feed=feed_dict_list[0] feed=feed_dict_list[0]
if feed_dict_list is not None else None, if feed_dict_list is not None else None,
return_numpy=False, return_numpy=False,
use_program_cache=True) use_program_cache=False)
seq_ids_list, seq_scores_list = [seq_ids], [ seq_ids_list, seq_scores_list = [seq_ids], [
seq_scores seq_scores
] if isinstance( ] if isinstance(seq_ids,
seq_ids, paddle.fluid.LoDTensor) else (seq_ids, seq_scores) paddle.fluid.LoDTensor) else (seq_ids, seq_scores)
for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list): for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
# How to parse the results: # How to parse the results:
# Suppose the lod of seq_ids is: # Suppose the lod of seq_ids is:
......
...@@ -266,7 +266,7 @@ class DataReader(object): ...@@ -266,7 +266,7 @@ class DataReader(object):
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
for line in f: for line in f:
if six.PY3: if six.PY3:
line = line.decode() line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter) fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1): self._only_src and len(fields) == 1):
......
...@@ -280,7 +280,7 @@ def fast_infer(args): ...@@ -280,7 +280,7 @@ def fast_infer(args):
feed=feed_dict_list[0] feed=feed_dict_list[0]
if feed_dict_list is not None else None, if feed_dict_list is not None else None,
return_numpy=False, return_numpy=False,
use_program_cache=True) use_program_cache=False)
seq_ids_list, seq_scores_list = [ seq_ids_list, seq_scores_list = [
seq_ids seq_ids
], [seq_scores] if isinstance( ], [seq_scores] if isinstance(
......
...@@ -875,7 +875,7 @@ def fast_decode(src_vocab_size, ...@@ -875,7 +875,7 @@ def fast_decode(src_vocab_size,
accu_scores = layers.elementwise_add( accu_scores = layers.elementwise_add(
x=layers.log(topk_scores), y=pre_scores, axis=0) x=layers.log(topk_scores), y=pre_scores, axis=0)
# beam_search op uses lod to differentiate branches. # beam_search op uses lod to differentiate branches.
topk_indices = layers.lod_reset(accu_scores, pre_ids) accu_scores = layers.lod_reset(accu_scores, pre_ids)
# topK reduction across beams, also contain special handle of # topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction) # end beams and end sentences(batch reduction)
selected_ids, selected_scores, gather_idx = layers.beam_search( selected_ids, selected_scores, gather_idx = layers.beam_search(
......
...@@ -266,7 +266,7 @@ class DataReader(object): ...@@ -266,7 +266,7 @@ class DataReader(object):
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
for line in f: for line in f:
if six.PY3: if six.PY3:
line = line.decode() line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter) fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1): self._only_src and len(fields) == 1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册