未验证 提交 1b3cf0da 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #4969 from WenmuZhou/fix_vqa

add encoding='utf-8'
......@@ -128,12 +128,16 @@ def evaluate(args,
"f1": f1_score(out_label_list, preds_list),
}
with open(os.path.join(args.output_dir, "test_gt.txt"), "w") as fout:
with open(
os.path.join(args.output_dir, "test_gt.txt"), "w",
encoding='utf-8') as fout:
for lbl in out_label_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
with open(os.path.join(args.output_dir, "test_pred.txt"), "w") as fout:
with open(
os.path.join(args.output_dir, "test_pred.txt"), "w",
encoding='utf-8') as fout:
for lbl in preds_list:
for l in lbl:
fout.write(l + "\t")
......
......@@ -37,7 +37,7 @@ def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
assert fp_type in ["gt", "pred"]
key = "label" if fp_type == "gt" else "pred"
res_dict = dict()
with open(fp, "r") as fin:
with open(fp, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for _, line in enumerate(lines):
......
......@@ -16,13 +16,13 @@ import json
def transfer_xfun_data(json_path=None, output_file=None):
with open(json_path, "r") as fin:
with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
json_info = json.loads(lines[0])
documents = json_info["documents"]
label_info = {}
with open(output_file, "w") as fout:
with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents):
img_info = document["img"]
document = document["document"]
......
......@@ -92,7 +92,7 @@ def infer(args):
def load_ocr(img_folder, json_path):
import json
d = []
with open(json_path, "r") as fin:
with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
......
......@@ -59,7 +59,8 @@ def pad_sentences(tokenizer,
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
assert False, f"padding_side of tokenizer just supports [\"right\"] but got {tokenizer.padding_side}"
assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
tokenizer.padding_side)
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
......@@ -224,7 +225,7 @@ def infer(args):
# load ocr results json
ocr_results = dict()
with open(args.ocr_json_path, "r") as fin:
with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
img_name, json_info = line.split("\t")
......@@ -234,7 +235,10 @@ def infer(args):
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}]".format(idx, len(infer_imgs), img_path))
......
......@@ -113,7 +113,10 @@ if __name__ == "__main__":
# loop for infer
ser_engine = SerPredictor(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
......
......@@ -112,7 +112,10 @@ if __name__ == "__main__":
# loop for infer
ser_re_engine = SerReSystem(args)
with open(os.path.join(args.output_dir, "infer_results.txt"), "w") as fout:
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
print("process: [{}/{}], {}".format(idx, len(infer_imgs), img_path))
......
......@@ -32,7 +32,7 @@ def set_seed(seed):
def get_bio_label_maps(label_map_path):
with open(label_map_path, "r") as fin:
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
......
......@@ -162,7 +162,7 @@ class XFUNDataset(Dataset):
return encoded_inputs
def read_all_lines(self, ):
with open(self.label_path, "r") as fin:
with open(self.label_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
return lines
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册