Commit 1214fdfd authored by guoshengCS

Fix inference output encoding in Transformer under Python 3 and Windows.

Parent f3c92f11
@@ -4,6 +4,9 @@ import multiprocessing
 import numpy as np
 import os
 import sys
+if sys.version[0] == '2':
+    reload(sys)
+    sys.setdefaultencoding("utf-8")
 sys.path.append("../../")
 sys.path.append("../../models/neural_machine_translation/transformer/")
 from functools import partial
@@ -307,13 +310,13 @@ def fast_infer(args):
                 for j in range(end - start):  # for each candidate
                     sub_start = seq_ids.lod()[1][start + j]
                     sub_end = seq_ids.lod()[1][start + j + 1]
-                    hyps[i].append(" ".join([
+                    hyps[i].append(b" ".join([
                         trg_idx2word[idx]
                         for idx in post_process_seq(
                             np.array(seq_ids)[sub_start:sub_end])
                     ]))
                     scores[i].append(np.array(seq_scores)[sub_end - 1])
-                    print(hyps[i][-1])
+                    print(hyps[i][-1].decode("utf8"))
                     if len(hyps[i]) >= InferTaskConfig.n_best:
                         break
         except (StopIteration, fluid.core.EOFException):
......
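The hunk above changes how each decoded hypothesis is built and printed: the tokens looked up in trg_idx2word are joined as bytes and only decoded to text at print time, so the same code works whether the vocabulary entries are Python 2 str or UTF-8 encoded bytes under Python 3, and Windows consoles receive a properly decoded string. A minimal sketch of the pattern, using a hypothetical two-entry vocabulary in place of the real trg_idx2word:

```python
# Minimal sketch of the bytes-join / decode-on-print pattern from this commit.
# The vocabulary here is hypothetical; in the real script trg_idx2word is
# loaded from the target vocab file and maps ids to UTF-8 encoded tokens.
trg_idx2word = {0: u"你好".encode("utf8"), 1: u"world".encode("utf8")}
seq = [0, 1]  # stands in for one post-processed candidate sequence

hyp = b" ".join(trg_idx2word[idx] for idx in seq)  # keep the hypothesis as bytes
print(hyp.decode("utf8"))                          # decode once, only for display
```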
@@ -11,6 +11,9 @@ if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
 import six
 import sys
+if sys.version[0] == '2':
+    reload(sys)
+    sys.setdefaultencoding("utf-8")
 sys.path.append("../../")
 sys.path.append("../../models/neural_machine_translation/transformer/")
 import time
......
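The import-time guard added in each touched script only takes effect under Python 2, where implicit str/unicode conversions default to ASCII and can raise UnicodeDecodeError on non-ASCII vocabulary entries; switching the default codec to UTF-8 avoids that without affecting Python 3. A standalone sketch of the guard as used in the patch (it is a no-op on Python 3, where setdefaultencoding no longer exists):

```python
import sys

# Same guard as in the patch: only the Python 2 branch does anything.
if sys.version[0] == '2':
    reload(sys)                      # site.py removes sys.setdefaultencoding; reload(sys) restores it
    sys.setdefaultencoding("utf-8")  # implicit str<->unicode conversions now use UTF-8 instead of ASCII
```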
@@ -3,6 +3,10 @@ import ast
 import multiprocessing
 import numpy as np
 import os
+import sys
+if sys.version[0] == '2':
+    reload(sys)
+    sys.setdefaultencoding("utf-8")
 from functools import partial
 import paddle
@@ -303,13 +307,13 @@ def fast_infer(args):
                 for j in range(end - start):  # for each candidate
                     sub_start = seq_ids.lod()[1][start + j]
                     sub_end = seq_ids.lod()[1][start + j + 1]
-                    hyps[i].append(" ".join([
+                    hyps[i].append(b" ".join([
                         trg_idx2word[idx]
                         for idx in post_process_seq(
                             np.array(seq_ids)[sub_start:sub_end])
                     ]))
                     scores[i].append(np.array(seq_scores)[sub_end - 1])
-                    print(hyps[i][-1])
+                    print(hyps[i][-1].decode("utf8"))
                     if len(hyps[i]) >= InferTaskConfig.n_best:
                         break
         except (StopIteration, fluid.core.EOFException):
......
@@ -6,6 +6,9 @@ import multiprocessing
 import os
 import six
 import sys
+if sys.version[0] == '2':
+    reload(sys)
+    sys.setdefaultencoding("utf-8")
 import time
 import numpy as np
......