提交 1214fdfd 编写于 作者: G guoshengCS

Fix inference output encoding in Transformer under python3 and windows.

上级 f3c92f11
......@@ -4,6 +4,9 @@ import multiprocessing
import numpy as np
import os
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/")
from functools import partial
......@@ -307,13 +310,13 @@ def fast_infer(args):
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
hyps[i].append(b" ".join([
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
print(hyps[i][-1].decode("utf8"))
if len(hyps[i]) >= InferTaskConfig.n_best:
break
except (StopIteration, fluid.core.EOFException):
......
......@@ -11,6 +11,9 @@ if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
import six
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/")
import time
......
......@@ -3,6 +3,10 @@ import ast
import multiprocessing
import numpy as np
import os
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
from functools import partial
import paddle
......@@ -303,13 +307,13 @@ def fast_infer(args):
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
hyps[i].append(b" ".join([
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
print(hyps[i][-1].decode("utf8"))
if len(hyps[i]) >= InferTaskConfig.n_best:
break
except (StopIteration, fluid.core.EOFException):
......
......@@ -6,6 +6,9 @@ import multiprocessing
import os
import six
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
import time
import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册