未验证 提交 74a33b7f 编写于 作者: Z zhangyubo0722 提交者: GitHub

fix gbk (#2941)

上级 b3fcc986
......@@ -36,8 +36,12 @@ class Topk(object):
try:
class_id_map = {}
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
try:
with open(class_id_map_file, "r", encoding='utf-8') as fin:
lines = fin.readlines()
except Exception as e:
with open(class_id_map_file, "r", encoding='gbk') as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(self.delimiter)
class_id_map[int(partition[0])] = str(partition[-1])
......
......@@ -18,19 +18,28 @@ import base64
import numpy as np
def get_image_list(img_file):
def get_image_list(img_file, infer_list=None):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for root, dirs, files in os.walk(img_file):
for single_file in files:
if single_file.split('.')[-1] in img_end:
imgs_lists.append(os.path.join(root, single_file))
if infer_list and not os.path.exists(infer_list):
raise Exception("not found infer list {}".format(infer_list))
if infer_list:
with open(infer_list, "r") as f:
lines = f.readlines()
for line in lines:
image_path = line.strip(" ").split()[0]
image_path = os.path.join(img_file, image_path)
imgs_lists.append(image_path)
else:
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for root, dirs, files in os.walk(img_file):
for single_file in files:
if single_file.split('.')[-1] in img_end:
imgs_lists.append(os.path.join(root, single_file))
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
......
......@@ -442,7 +442,9 @@ class Engine(object):
results = []
total_trainer = dist.get_world_size()
local_rank = dist.get_rank()
image_list = get_image_list(self.config["Infer"]["infer_imgs"])
infer_imgs = self.config["Infer"]["infer_imgs"]
infer_list = self.config["Infer"].get("infer_list", None)
image_list = get_image_list(infer_imgs, infer_list=infer_list)
# data split
image_list = image_list[local_rank::total_trainer]
......@@ -450,6 +452,7 @@ class Engine(object):
self.model.eval()
batch_data = []
image_file_list = []
save_path = self.config["Infer"].get("save_dir", None)
for idx, image_file in enumerate(image_list):
with open(image_file, 'rb') as f:
x = f.read()
......@@ -473,11 +476,11 @@ class Engine(object):
out = out["output"]
result = self.postprocess_func(out, image_file_list)
logger.info(result)
if not save_path:
logger.info(result)
results.extend(result)
batch_data.clear()
image_file_list.clear()
save_path = self.config["Infer"].get("save_dir", None)
if save_path:
save_predict_result(save_path, results)
return results
......
......@@ -24,12 +24,9 @@ def save_predict_result(save_path, result):
elif os.path.splitext(save_path)[-1] == '.json':
save_path = save_path
else:
logger.warning(
f"{save_path} is invalid input path, only files in json format are supported."
)
raise Exception(f"{save_path} is invalid input path, only files in json format are supported.")
if os.path.exists(save_path):
logger.warning(
f"The file {save_path} will be overwritten."
)
logger.warning(f"The file {save_path} will be overwritten.")
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(result, f)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册