提交 dc51469b 编写于 作者: 文幕地方's avatar 文幕地方

add encoding='utf-8'

上级 3ffaf7f2
...@@ -128,12 +128,16 @@ def evaluate(args, ...@@ -128,12 +128,16 @@ def evaluate(args,
"f1": f1_score(out_label_list, preds_list), "f1": f1_score(out_label_list, preds_list),
} }
with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout: with open(
os.path.join(args.output_dir, "test_gt.txt"), "w",
encoding='utf-8') as fout:
for lbl in out_label_list: for lbl in out_label_list:
for l in lbl: for l in lbl:
fout.write(l + "\t") fout.write(l + "\t")
fout.write("\n") fout.write("\n")
with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout: with open(
os.path.join(args.output_dir, "test_pred.txt"), "w",
encoding='utf-8') as fout:
for lbl in preds_list: for lbl in preds_list:
for l in lbl: for l in lbl:
fout.write(l + "\t") fout.write(l + "\t")
......
...@@ -37,7 +37,7 @@ def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True): ...@@ -37,7 +37,7 @@ def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
assert fp_type in ["gt", "pred"] assert fp_type in ["gt", "pred"]
key = "label" if fp_type == "gt" else "pred" key = "label" if fp_type == "gt" else "pred"
res_dict = dict() res_dict = dict()
with open(fp, "r") as fin: with open(fp, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
for _, line in enumerate(lines): for _, line in enumerate(lines):
......
...@@ -16,13 +16,13 @@ import json ...@@ -16,13 +16,13 @@ import json
def transfer_xfun_data(json_path=None, output_file=None): def transfer_xfun_data(json_path=None, output_file=None):
with open(json_path, "r") as fin: with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
json_info = json.loads(lines[0]) json_info = json.loads(lines[0])
documents = json_info["documents"] documents = json_info["documents"]
label_info = {} label_info = {}
with open(output_file, "w") as fout: with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents): for idx, document in enumerate(documents):
img_info = document["img"] img_info = document["img"]
document = document["document"] document = document["document"]
......
...@@ -92,7 +92,7 @@ def infer(args): ...@@ -92,7 +92,7 @@ def infer(args):
def load_ocr(img_folder, json_path): def load_ocr(img_folder, json_path):
import json import json
d = [] d = []
with open(json_path, "r") as fin: with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
image_name, info_str = line.split("\t") image_name, info_str = line.split("\t")
......
...@@ -59,7 +59,8 @@ def pad_sentences(tokenizer, ...@@ -59,7 +59,8 @@ def pad_sentences(tokenizer,
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0] encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference ] * difference
else: else:
assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}" assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
tokenizer.padding_side)
else: else:
if return_attention_mask: if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
...@@ -224,7 +225,7 @@ def infer(args): ...@@ -224,7 +225,7 @@ def infer(args):
# load ocr results json # load ocr results json
ocr_results = dict() ocr_results = dict()
with open(args.ocr_json_path, "r") as fin: with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
img_name, json_info = line.split("\t") img_name, json_info = line.split("\t")
...@@ -234,7 +235,10 @@ def infer(args): ...@@ -234,7 +235,10 @@ def infer(args):
infer_imgs = get_image_file_list(args.infer_imgs) infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer # loop for infer
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs): for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}]".format(idx, len(infer_imgs), img_path)) print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
......
...@@ -113,7 +113,10 @@ if __name__ == "__main__": ...@@ -113,7 +113,10 @@ if __name__ == "__main__":
# loop for infer # loop for infer
ser_engine = SerPredictor(args) ser_engine = SerPredictor(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs): for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path)) print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
......
...@@ -112,7 +112,10 @@ if __name__ == "__main__": ...@@ -112,7 +112,10 @@ if __name__ == "__main__":
# loop for infer # loop for infer
ser_re_engine = SerReSystem(args) ser_re_engine = SerReSystem(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout: with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs): for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path)) print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
......
...@@ -32,7 +32,7 @@ def set_seed(seed): ...@@ -32,7 +32,7 @@ def set_seed(seed):
def get_bio_label_maps(label_map_path): def get_bio_label_maps(label_map_path):
with open(label_map_path, "r") as fin: with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
lines = [line.strip() for line in lines] lines = [line.strip() for line in lines]
if "O" not in lines: if "O" not in lines:
......
...@@ -162,7 +162,7 @@ class XFUNDataset(Dataset): ...@@ -162,7 +162,7 @@ class XFUNDataset(Dataset):
return encoded_inputs return encoded_inputs
def read_all_lines(self, ): def read_all_lines(self, ):
with open(self.label_path, "r") as fin: with open(self.label_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
return lines return lines
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册