From 6590b710da7aa9860feb6f9b6e69f91b2a0d29ac Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Mon, 24 Jun 2019 11:39:50 +0800 Subject: [PATCH] To avoid the error when use_program_cache is True in Transformer inference. (#2485) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To avoid the error when use_program_cache is True in Transformer infe…rence. To avoid the data decode error. --- PaddleNLP/neural_machine_translation/transformer/infer.py | 6 +++--- PaddleNLP/neural_machine_translation/transformer/reader.py | 2 +- .../neural_machine_translation/transformer/infer.py | 2 +- .../neural_machine_translation/transformer/model.py | 2 +- .../neural_machine_translation/transformer/reader.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/PaddleNLP/neural_machine_translation/transformer/infer.py b/PaddleNLP/neural_machine_translation/transformer/infer.py index 08543f01..aaf813a5 100644 --- a/PaddleNLP/neural_machine_translation/transformer/infer.py +++ b/PaddleNLP/neural_machine_translation/transformer/infer.py @@ -281,11 +281,11 @@ def fast_infer(args): feed=feed_dict_list[0] if feed_dict_list is not None else None, return_numpy=False, - use_program_cache=True) + use_program_cache=False) seq_ids_list, seq_scores_list = [seq_ids], [ seq_scores - ] if isinstance( - seq_ids, paddle.fluid.LoDTensor) else (seq_ids, seq_scores) + ] if isinstance(seq_ids, + paddle.fluid.LoDTensor) else (seq_ids, seq_scores) for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list): # How to parse the results: # Suppose the lod of seq_ids is: diff --git a/PaddleNLP/neural_machine_translation/transformer/reader.py b/PaddleNLP/neural_machine_translation/transformer/reader.py index 10f44ade..aa4dea3e 100644 --- a/PaddleNLP/neural_machine_translation/transformer/reader.py +++ b/PaddleNLP/neural_machine_translation/transformer/reader.py @@ -266,7 +266,7 @@ class DataReader(object): with open(fpath, "rb") as f: for line in f: if 
six.PY3: - line = line.decode() + line = line.decode("utf8", errors="ignore") fields = line.strip("\n").split(self._field_delimiter) if (not self._only_src and len(fields) == 2) or ( self._only_src and len(fields) == 1): diff --git a/PaddleNLP/unarchived/neural_machine_translation/transformer/infer.py b/PaddleNLP/unarchived/neural_machine_translation/transformer/infer.py index 96b8e0a1..cf89607d 100644 --- a/PaddleNLP/unarchived/neural_machine_translation/transformer/infer.py +++ b/PaddleNLP/unarchived/neural_machine_translation/transformer/infer.py @@ -280,7 +280,7 @@ def fast_infer(args): feed=feed_dict_list[0] if feed_dict_list is not None else None, return_numpy=False, - use_program_cache=True) + use_program_cache=False) seq_ids_list, seq_scores_list = [ seq_ids ], [seq_scores] if isinstance( diff --git a/PaddleNLP/unarchived/neural_machine_translation/transformer/model.py b/PaddleNLP/unarchived/neural_machine_translation/transformer/model.py index cfd85dc5..5b19be6a 100644 --- a/PaddleNLP/unarchived/neural_machine_translation/transformer/model.py +++ b/PaddleNLP/unarchived/neural_machine_translation/transformer/model.py @@ -875,7 +875,7 @@ def fast_decode(src_vocab_size, accu_scores = layers.elementwise_add( x=layers.log(topk_scores), y=pre_scores, axis=0) # beam_search op uses lod to differentiate branches. 
- topk_indices = layers.lod_reset(accu_scores, pre_ids) + accu_scores = layers.lod_reset(accu_scores, pre_ids) # topK reduction across beams, also contain special handle of # end beams and end sentences(batch reduction) selected_ids, selected_scores, gather_idx = layers.beam_search( diff --git a/PaddleNLP/unarchived/neural_machine_translation/transformer/reader.py b/PaddleNLP/unarchived/neural_machine_translation/transformer/reader.py index 10f44ade..aa4dea3e 100644 --- a/PaddleNLP/unarchived/neural_machine_translation/transformer/reader.py +++ b/PaddleNLP/unarchived/neural_machine_translation/transformer/reader.py @@ -266,7 +266,7 @@ class DataReader(object): with open(fpath, "rb") as f: for line in f: if six.PY3: - line = line.decode() + line = line.decode("utf8", errors="ignore") fields = line.strip("\n").split(self._field_delimiter) if (not self._only_src and len(fields) == 2) or ( self._only_src and len(fields) == 1): -- GitLab