Commit 4a1803d1 authored by Dilyar, committed by Yibing Liu

simnet update to paddle v1.6 (#3711)

* Add files via upload

* Delete unicom_compute_pos_neg.py

* Delete unicom_split.py

* Add files via upload

* Update README.md

* Delete reader.py

* Delete run.sh

* Delete run_classifier.py

* Delete utils.py

* Add files via upload

* Update config.py

* Update run.sh
Parent 939c4853
@@ -22,7 +22,7 @@
|UNICOM|China Unicom customer service|customer service|
## Quick Start
#### Version requirements
- This project depends on PaddlePaddle Fluid 1.3.1; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
+ This project depends on PaddlePaddle Fluid 1.6; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
Requires Python 2.7.
#### Installing the code
@@ -36,11 +36,19 @@ cd models/PaddleNLP/similarity_net
```shell
sh download_data.sh
```
+ Or:
+ ```
+ python download.py dataset
+ ```
#### Model preparation
We have open-sourced a ```pairwise``` model (trained with the BOW network) on large-scale data. Users can download the pretrained model by running the command below; the model will be saved under ```./model_files/simnet_bow_pairwise_pretrained_model/```.
```shell
sh download_pretrained_model.sh
```
+ Or:
+ ```
+ python download.py model
+ ```
#### Evaluation
We have released our in-house test sets, covering four datasets: Baidu Zhidao, ECOM, QQSIM, and UNICOM. With the pretrained model above, users can enter the evaluate directory and run the commands below in order to get evaluation results on each test set.
@@ -137,6 +145,7 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut
├── reader.py: defines the functions for reading input data
├── utils.py: defines other commonly used utility functions
├── Config: configuration files for the various models
+ ├── download.py: script for downloading the dataset and pretrained models
```
### How to train
......
@@ -17,6 +17,7 @@ SimNet config
import six
import json
+ import io

class SimNetConfig(object):
@@ -31,7 +32,7 @@ class SimNetConfig(object):
    def _parse(self, config_path):
        try:
-             with open(config_path) as json_file:
+             with io.open(config_path) as json_file:
                config_dict = json.load(json_file)
        except Exception:
            raise IOError("Error in parsing simnet model config file '%s'" % config_path)
......
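A recurring change in this PR is replacing `open`/`codecs.open` with `io.open`: on Python 2, `io.open` accepts an `encoding` argument and yields `unicode` text, matching the behavior of Python 3's built-in `open`, so the same file-handling code runs under both interpreters. A minimal sketch (`config.json` is a placeholder path):

```python
import io

# Text mode with an explicit encoding: yields unicode on Python 2 and
# str on Python 3, so downstream string handling is identical.
with io.open("config.json", "r", encoding="utf8") as f:
    text = f.read()
```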
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script: downloads the dataset and pretrained models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
def usage():
    desc = ("\nDownload datasets and pretrained models for SimilarityNet task.\n"
            "Usage:\n"
            "   1. python download.py dataset\n"
            "   2. python download.py model\n")
    print(desc)
def md5file(fname):
    hash_md5 = hashlib.md5()
    with io.open(fname, "rb") as fin:
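        # Stream the file in 4 KB chunks so large archives never sit fully in memory.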
        for chunk in iter(lambda: fin.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def extract(fname, dir_path):
    """
    Extract a tar.gz archive into dir_path.
    """
    tar = tarfile.open(fname, "r:gz")
    file_names = tar.getnames()
    for file_name in file_names:
        tar.extract(file_name, dir_path)
        print(file_name)
    tar.close()
def download(url, filename, md5sum):
    """
    Download file and check md5
    """
    retry = 0
    retry_limit = 3
    chunk_size = 4096
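    # Keep retrying until the file exists locally and its md5 matches md5sum.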
    while not (os.path.exists(filename) and md5file(filename) == md5sum):
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError("Cannot download dataset ({0}) with retry {1} times.".
                               format(url, retry_limit))
        try:
            start = time.time()
            size = 0
            res = requests.get(url, stream=True)
            filesize = int(res.headers['content-length'])
            if res.status_code == 200:
                print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
                # save by chunk
                with io.open(filename, "wb") as fout:
                    for chunk in res.iter_content(chunk_size=chunk_size):
                        if chunk:
                            fout.write(chunk)
                            size += len(chunk)
                            pr = '>' * int(size * 50 / filesize)
                            print('\r[Process ]: %s%.2f%%' % (pr, float(size / filesize * 100)), end='')
            end = time.time()
            print("\n[CostTime]: %.2f s" % (end - start))
        except Exception as e:
            print(e)
def download_dataset(dir_path):
    BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
    DATASET_NAME = "simnet_dataset-1.0.0.tar.gz"
    DATASET_MD5 = "ec65b313bc237150ef536a8d26f3c73b"
    file_path = os.path.join(dir_path, DATASET_NAME)
    url = BASE_URL + DATASET_NAME
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    # download dataset
    print("Downloading dataset: %s" % url)
    download(url, file_path, DATASET_MD5)
    # extract dataset
    print("Extracting dataset: %s" % file_path)
    extract(file_path, dir_path)
    os.remove(file_path)
def download_model(dir_path):
    MODELS = {}
    BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
    CNN_NAME = "simnet_bow-pairwise-1.0.0.tar.gz"
    CNN_MD5 = "199a3f3af31558edcc71c3b54ea5e129"
    MODELS[CNN_NAME] = CNN_MD5
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    for model in MODELS:
        url = BASE_URL + model
        model_path = os.path.join(dir_path, model)
        print("Downloading model: %s" % url)
        # download model
        download(url, model_path, MODELS[model])
        # extract model.tar.gz
        print("Extracting model: %s" % model_path)
        extract(model_path, dir_path)
        os.remove(model_path)
if __name__ == '__main__':
    if len(sys.argv) != 2:
        usage()
        sys.exit(1)
    if sys.argv[1] == "dataset":
        pwd = os.path.join(os.path.dirname(__file__), './')
        download_dataset(pwd)
    elif sys.argv[1] == "model":
        pwd = os.path.join(os.path.dirname(__file__), './pretrain_models')
        download_model(pwd)
    else:
        usage()
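A quick usage sketch of the helpers above, mirroring the `__main__` defaults (the dataset is extracted next to the script, the model under `./pretrain_models`):

```python
# Equivalent to `python download.py dataset` and `python download.py model`.
download_dataset("./")
download_model("./pretrain_models")
```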
@@ -15,6 +15,9 @@
compute unicom
"""
+ import io

infer_results = []
labels = []
result = []
@@ -23,11 +26,11 @@ temp_query = ""
pos_num = 0.0
neg_num = 0.0
- with open("./unicom_infer_result", "r") as infer_result_file:
+ with io.open("./unicom_infer_result", "r", encoding="utf8") as infer_result_file:
    for line in infer_result_file:
        infer_results.append(line.strip().split("\t"))
- with open("./unicom_label", "r") as label_file:
+ with io.open("./unicom_label", "r", encoding="utf8") as label_file:
    for line in label_file:
        labels.append(line.strip().split("\t"))
......
@@ -15,9 +15,12 @@
split unicom file
"""
+ import io

- with open("../data/unicom", "r") as unicom_file:
-     with open("./unicom_infer", "w") as infer_file:
-         with open("./unicom_label", "w") as label_file:
+ with io.open("../data/unicom", "r", encoding="utf8") as unicom_file:
+     with io.open("./unicom_infer", "w", encoding="utf8") as infer_file:
+         with io.open("./unicom_label", "w", encoding="utf8") as label_file:
            for line in unicom_file:
                line = line.strip().split('\t')
                infer_file.write("\t".join(line[:2]) + '\n')
......
@@ -17,7 +17,7 @@ SimNet reader
import logging
import numpy as np
- import codecs
+ import io

class SimNetProcessor(object):
@@ -38,8 +38,8 @@ class SimNetProcessor(object):
        Reader with Pairwise
        """
        if mode == "valid":
-             with codecs.open(self.args.valid_data_dir, "r",
-                              "utf-8") as file:
+             with io.open(self.args.valid_data_dir, "r",
+                          encoding="utf8") as file:
                for line in file:
                    query, title, label = line.strip().split("\t")
                    if len(query) == 0 or len(title) == 0 or len(
@@ -62,7 +62,7 @@ class SimNetProcessor(object):
                    title = [0]
                yield [query, title]
        elif mode == "test":
-             with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
+             with io.open(self.args.test_data_dir, "r", encoding="utf8") as file:
                for line in file:
                    query, title, label = line.strip().split("\t")
                    if len(query) == 0 or len(title) == 0 or len(
@@ -85,8 +85,8 @@ class SimNetProcessor(object):
                    title = [0]
                yield [query, title]
        else:
-             with codecs.open(self.args.train_data_dir, "r",
-                              "utf-8") as file:
+             with io.open(self.args.train_data_dir, "r",
+                          encoding="utf8") as file:
                for line in file:
                    query, pos_title, neg_title = line.strip().split("\t")
                    if len(query) == 0 or len(pos_title) == 0 or len(
@@ -119,8 +119,8 @@ class SimNetProcessor(object):
        Reader with Pointwise
        """
        if mode == "valid":
-             with codecs.open(self.args.valid_data_dir, "r",
-                              "utf-8") as file:
+             with io.open(self.args.valid_data_dir, "r",
+                          encoding="utf8") as file:
                for line in file:
                    query, title, label = line.strip().split("\t")
                    if len(query) == 0 or len(title) == 0 or len(
@@ -143,7 +143,7 @@ class SimNetProcessor(object):
                    title = [0]
                yield [query, title]
        elif mode == "test":
-             with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
+             with io.open(self.args.test_data_dir, "r", encoding="utf8") as file:
                for line in file:
                    query, title, label = line.strip().split("\t")
                    if len(query) == 0 or len(title) == 0 or len(
@@ -166,8 +166,8 @@ class SimNetProcessor(object):
                    title = [0]
                yield [query, title]
        else:
-             with codecs.open(self.args.train_data_dir, "r",
-                              "utf-8") as file:
+             with io.open(self.args.train_data_dir, "r",
+                          encoding="utf8") as file:
                for line in file:
                    query, title, label = line.strip().split("\t")
                    if len(query) == 0 or len(title) == 0 or len(
@@ -200,7 +200,7 @@ class SimNetProcessor(object):
        """
        get infer reader
        """
-         with codecs.open(self.args.infer_data_dir, "r", "utf-8") as file:
+         with io.open(self.args.infer_data_dir, "r", encoding="utf8") as file:
            for line in file:
                query, title = line.strip().split("\t")
                if len(query) == 0 or len(title) == 0:
@@ -224,7 +224,7 @@ class SimNetProcessor(object):
        """
        get infer data
        """
-         with codecs.open(self.args.infer_data_dir, "r", "utf-8") as file:
+         with io.open(self.args.infer_data_dir, "r", encoding="utf8") as file:
            for line in file:
                query, title = line.strip().split("\t")
                if len(query) == 0 or len(title) == 0:
@@ -238,7 +238,7 @@ class SimNetProcessor(object):
        """
        if self.valid_label.size == 0:
            labels = []
-             with codecs.open(self.args.valid_data_dir, "r", "utf-8") as f:
+             with io.open(self.args.valid_data_dir, "r", encoding="utf8") as f:
                for line in f:
                    labels.append([int(line.strip().split("\t")[-1])])
            self.valid_label = np.array(labels)
@@ -250,7 +250,7 @@ class SimNetProcessor(object):
        """
        if self.test_label.size == 0:
            labels = []
-             with codecs.open(self.args.test_data_dir, "r", "utf-8") as f:
+             with io.open(self.args.test_data_dir, "r", encoding="utf8") as f:
                for line in f:
                    labels.append([int(line.strip().split("\t")[-1])])
            self.test_label = np.array(labels)
......
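As the hunks above show, every reader expects UTF-8 text with tab-separated fields. A minimal sketch of the line formats the readers parse (the example strings are invented):

```python
# Pairwise/pointwise valid and test files: query \t title \t label
line = u"how to cook rice\tbest way to cook rice\t1"
query, title, label = line.strip().split("\t")

# Pairwise train files: query \t positive title \t negative title
line = u"how to cook rice\tbest way to cook rice\thow to fix a bike"
query, pos_title, neg_title = line.strip().split("\t")
```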
@@ -39,11 +39,13 @@ import config
import utils
import reader
import models.matching.paddle_layers as layers
- import codecs
+ import io
+ import logging

from utils import ArgConfig
+ from models.model_check import check_version
+ from models.model_check import check_cuda
- import logging

def create_model(args, pyreader_name, is_inference = False, is_pointwise = False):
    """
@@ -99,8 +101,6 @@ def train(conf_dict, args):
    vocab = utils.load_vocab(args.vocab_path)
    # get vocab size
    conf_dict['dict_size'] = len(vocab)
-     # Get data layer
-     data = layers.DataLayer()
    # Load network structure dynamically
    net = utils.import_class("../models/matching",
                             conf_dict["net"]["module_name"],
@@ -182,7 +182,7 @@ def train(conf_dict, args):
        return auc and acc
        """
        # Get Batch Data
-         batch_data = paddle.batch(get_valid_examples, args.batch_size, drop_last=False)
+         batch_data = fluid.io.batch(get_valid_examples, args.batch_size, drop_last=False)
        test_pyreader.decorate_paddle_reader(batch_data)
        test_pyreader.start()
        pred_list = []
@@ -219,8 +219,8 @@ def train(conf_dict, args):
    ce_info = []
    train_exe = exe
    for epoch_id in range(args.epoch):
-         train_batch_data = paddle.batch(
-             paddle.reader.shuffle(
+         train_batch_data = fluid.io.batch(
+             fluid.io.shuffle(
                get_train_examples, buf_size=10000),
            args.batch_size,
            drop_last=False)
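Paddle 1.6 exposes the reader decorators `paddle.batch` and `paddle.reader.shuffle` as `fluid.io.batch` and `fluid.io.shuffle`, which is what these hunks switch to. A minimal sketch, with `sample_reader` standing in for `get_train_examples`:

```python
import paddle.fluid as fluid

def sample_reader():
    # A stand-in sample generator: yields one training sample at a time.
    for i in range(1000):
        yield [i], [i + 1], [i + 2]

# Shuffle within a 10000-sample buffer, then group into batches.
train_batch_data = fluid.io.batch(
    fluid.io.shuffle(sample_reader, buf_size=10000),
    batch_size=128,
    drop_last=False)
```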
@@ -343,8 +343,7 @@ def test(conf_dict, args):
    startup_prog = fluid.Program()
    get_test_examples = simnet_process.get_reader("test")
-     batch_data = paddle.batch(get_test_examples, args.batch_size, drop_last=False)
+     batch_data = fluid.io.batch(get_test_examples, args.batch_size, drop_last=False)
    test_prog = fluid.Program()
    conf_dict['dict_size'] = len(vocab)
@@ -355,7 +354,7 @@ def test(conf_dict, args):
    metric = fluid.metrics.Auc(name="auc")
-     with codecs.open("predictions.txt", "w", "utf-8") as predictions_file:
+     with io.open("predictions.txt", "w", encoding="utf8") as predictions_file:
        if args.task_mode == "pairwise":
            with fluid.program_guard(test_prog, startup_prog):
                with fluid.unique_name.guard():
@@ -365,6 +364,7 @@ def test(conf_dict, args):
                        is_inference=True)
                    left_feat, pos_score = net.predict(left, pos_right)
                    pred = pos_score
+                     #print(pred)
                    test_prog = test_prog.clone(for_test=True)
        else:
@@ -398,11 +398,11 @@ def test(conf_dict, args):
                if args.task_mode == "pairwise":
                    pred_list += list(map(lambda item: float(item[0]), output[0]))
                    predictions_file.write("\n".join(
-                         map(lambda item: str((item[0] + 1) / 2), output[0])) + "\n")
+                         map(lambda item: str((item[0] + 1) / 2).decode(), output[0])) + "\n")
                else:
                    pred_list += map(lambda item: item, output[0])
                    predictions_file.write("\n".join(
-                         map(lambda item: str(np.argmax(item)), output[0])) + "\n")
+                         map(lambda item: str(np.argmax(item)).decode(), output[0])) + "\n")
            except fluid.core.EOFException:
                test_pyreader.reset()
                break
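The added `.decode()` calls matter because this code targets Python 2.7 (per the README): there, `str(...)` produces a byte string, while a file opened with `io.open(..., encoding="utf8")` accepts only `unicode`. A minimal sketch of the failure and the fix (Python 2 only; on Python 3 `str` has no `.decode`):

```python
# Python 2.7 only
import io

with io.open("predictions.txt", "w", encoding="utf8") as f:
    score = 0.73
    # f.write(str(score))        # TypeError: expects unicode, got str (bytes)
    f.write(str(score).decode())  # bytes -> unicode, so the write succeeds
```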
@@ -446,7 +446,7 @@ def infer(conf_dict, args):
    startup_prog = fluid.Program()
    get_infer_examples = simnet_process.get_infer_reader
-     batch_data = paddle.batch(get_infer_examples, args.batch_size, drop_last=False)
+     batch_data = fluid.io.batch(get_infer_examples, args.batch_size, drop_last=False)
    test_prog = fluid.Program()
@@ -490,13 +490,13 @@ def infer(conf_dict, args):
            output = test_exe.run(program=test_prog, fetch_list=fetch_list)
            if args.task_mode == "pairwise":
                preds_list += list(
-                     map(lambda item: str((item[0] + 1) / 2), output[0]))
+                     map(lambda item: str((item[0] + 1) / 2).decode(), output[0]))
            else:
-                 preds_list += map(lambda item: str(np.argmax(item)), output[0])
+                 preds_list += map(lambda item: str(np.argmax(item)).decode(), output[0])
        except fluid.core.EOFException:
            infer_pyreader.reset()
            break
-     with codecs.open(args.infer_result_path, "w", "utf-8") as infer_file:
+     with io.open(args.infer_result_path, "w", encoding="utf8") as infer_file:
        for _data, _pred in zip(simnet_process.get_infer_data(), preds_list):
            infer_file.write(_data + "\t" + _pred + "\n")
    logging.info("infer result saved in %s" %
@@ -516,15 +516,8 @@ if __name__ == "__main__":
    args = args.build_conf()
    utils.print_arguments(args)
-     try:
-         if fluid.is_compiled_with_cuda() != True and args.use_cuda == True:
-             print(
-                 "\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\nPlease: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
-             )
-             sys.exit(1)
-     except Exception as e:
-         pass
+     check_cuda(args.use_cuda)
+     check_version()
    utils.init_log("./log/TextSimilarityNet")
    conf_dict = config.SimNetConfig(args)
    if args.do_train:
@@ -535,4 +528,4 @@ if __name__ == "__main__":
        infer(conf_dict, args)
    else:
        raise ValueError(
            "one of do_train and do_test and do_infer must be True")
\ No newline at end of file
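The inline CUDA check is replaced by shared helpers. A hedged sketch of what `check_cuda` presumably does, reconstructed from the deleted inline logic (the real implementation lives in `models/model_check.py` and may differ; `check_version` likely validates the installed PaddlePaddle release in the same spirit):

```python
import sys
import paddle.fluid as fluid

def check_cuda(use_cuda):
    # Mirrors the removed inline check: refuse to run with use_cuda=True
    # on a CPU-only PaddlePaddle build.
    try:
        if use_cuda and not fluid.is_compiled_with_cuda():
            print("use_cuda is set to True, but this PaddlePaddle build has no "
                  "CUDA support. Install paddlepaddle-gpu or set use_cuda=False.")
            sys.exit(1)
    except Exception:
        pass
```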
- #encoding=utf-8
+ # -*- encoding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,6 @@ import sys
import re
import os
import six
- import codecs
import numpy as np
import logging
import logging.handlers
@@ -56,9 +55,9 @@ def get_result_file(args):
        result_file: merge sample and predict result
    """
-     with codecs.open(args.test_data_dir, "r", "utf-8") as test_file:
-         with codecs.open("predictions.txt", "r", "utf-8") as predictions_file:
-             with codecs.open(args.test_result_path, "w", "utf-8") as test_result_file:
+     with io.open(args.test_data_dir, "r", encoding="utf8") as test_file:
+         with io.open("predictions.txt", "r", encoding="utf8") as predictions_file:
+             with io.open(args.test_result_path, "w", encoding="utf8") as test_result_file:
                test_datas = [line.strip("\n") for line in test_file]
                predictions = [line.strip("\n") for line in predictions_file]
                for test_data, prediction in zip(test_datas, predictions):
......