提交 7340c88c 编写于 作者: L Lizhengo 提交者: Yibing Liu

fix open file bug in test (#3159)

* fix pred in simnet

* add init checkpoint in train and fix file open bug in window10

* Update README.md

* fix file open bug in test
上级 dcce6de0
...@@ -17,6 +17,7 @@ sys.path.append("..") ...@@ -17,6 +17,7 @@ sys.path.append("..")
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import codecs
import config import config
import utils import utils
import reader import reader
...@@ -326,7 +327,7 @@ def test(conf_dict, args): ...@@ -326,7 +327,7 @@ def test(conf_dict, args):
simnet_process = reader.SimNetProcessor(args, vocab) simnet_process = reader.SimNetProcessor(args, vocab)
# load auc method # load auc method
metric = fluid.metrics.Auc(name="auc") metric = fluid.metrics.Auc(name="auc")
with open("predictions.txt", "w") as predictions_file: with codecs.open("predictions.txt", "w", "utf-8") as predictions_file:
# Get model path # Get model path
model_path = args.init_checkpoint model_path = args.init_checkpoint
# Get device # Get device
...@@ -430,7 +431,7 @@ def infer(args): ...@@ -430,7 +431,7 @@ def infer(args):
map(lambda item: str((item[0] + 1) / 2), output[1])) map(lambda item: str((item[0] + 1) / 2), output[1]))
else: else:
preds_list += map(lambda item: str(np.argmax(item)), output[1]) preds_list += map(lambda item: str(np.argmax(item)), output[1])
with open(args.infer_result_path, "w") as infer_file: with codecs.open(args.infer_result_path, "w", "utf-8") as infer_file:
for _data, _pred in zip(simnet_process.get_infer_data(), preds_list): for _data, _pred in zip(simnet_process.get_infer_data(), preds_list):
infer_file.write(_data + "\t" + _pred + "\n") infer_file.write(_data + "\t" + _pred + "\n")
logging.info("infer result saved in %s" % logging.info("infer result saved in %s" %
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册