diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py index 2b475f83f787036f4fda88c6c98317a357100637..50cc40d4948215ddf7bedcd12ab03084e165d313 100644 --- a/ppcls/data/postprocess/topk.py +++ b/ppcls/data/postprocess/topk.py @@ -36,11 +36,15 @@ class Topk(object): try: class_id_map = {} - with open(class_id_map_file, "r") as fin: - lines = fin.readlines() - for line in lines: - partition = line.split("\n")[0].partition(self.delimiter) - class_id_map[int(partition[0])] = str(partition[-1]) + try: + with open(class_id_map_file, "r", encoding='utf-8') as fin: + lines = fin.readlines() + except Exception as e: + with open(class_id_map_file, "r", encoding='gbk') as fin: + lines = fin.readlines() + for line in lines: + partition = line.split("\n")[0].partition(self.delimiter) + class_id_map[int(partition[0])] = str(partition[-1]) except Exception as ex: print(ex) class_id_map = None diff --git a/ppcls/data/utils/get_image_list.py b/ppcls/data/utils/get_image_list.py index 34d12c3c9c351870f1fc6ad1c798bb8e0d894e5f..9b6de0690a576181d1ef53a162d7903a3d04058f 100644 --- a/ppcls/data/utils/get_image_list.py +++ b/ppcls/data/utils/get_image_list.py @@ -18,19 +18,28 @@ import base64 import numpy as np -def get_image_list(img_file): +def get_image_list(img_file, infer_list=None): imgs_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - - img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] - if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: - imgs_lists.append(img_file) - elif os.path.isdir(img_file): - for root, dirs, files in os.walk(img_file): - for single_file in files: - if single_file.split('.')[-1] in img_end: - imgs_lists.append(os.path.join(root, single_file)) + if infer_list and not os.path.exists(infer_list): + raise Exception("not found infer list {}".format(infer_list)) + if infer_list: + with open(infer_list, "r") as f: + lines = f.readlines() + for line in lines: + image_path = line.strip(" ").split()[0] + image_path = os.path.join(img_file, image_path) + imgs_lists.append(image_path) + else: + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for root, dirs, files in os.walk(img_file): + for single_file in files: + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(root, single_file)) if len(imgs_lists) == 0: raise Exception("not found any img file in {}".format(img_file)) imgs_lists = sorted(imgs_lists) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py old mode 100755 new mode 100644 index 3f93ceb7ababd97781966ceff46b88be62d58e33..851571010237c25ca60c03d6572c1385fd846794 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -442,7 +442,9 @@ class Engine(object): results = [] total_trainer = dist.get_world_size() local_rank = dist.get_rank() - image_list = get_image_list(self.config["Infer"]["infer_imgs"]) + infer_imgs = self.config["Infer"]["infer_imgs"] + infer_list = self.config["Infer"].get("infer_list", None) + image_list = get_image_list(infer_imgs, infer_list=infer_list) # data split image_list = image_list[local_rank::total_trainer] @@ -450,6 +452,7 @@ class Engine(object): self.model.eval() batch_data = [] image_file_list = [] + save_path = self.config["Infer"].get("save_dir", None) for idx, image_file in enumerate(image_list): with open(image_file, 'rb') as f: x = f.read() @@ -473,11 +476,11 @@ class Engine(object): out = out["output"] result = self.postprocess_func(out, image_file_list) - logger.info(result) + if not save_path: + logger.info(result) results.extend(result) batch_data.clear() image_file_list.clear() - save_path = self.config["Infer"].get("save_dir", None) if save_path: save_predict_result(save_path, results) return results diff --git a/ppcls/utils/save_result.py b/ppcls/utils/save_result.py index 863113ce876595b856bd85b816dce4b828745cde..f470b801ebb45db997a508bc39b04d5be95487a9 100644 --- a/ppcls/utils/save_result.py +++ b/ppcls/utils/save_result.py @@ -24,12 +24,9 @@ def save_predict_result(save_path, result): elif os.path.splitext(save_path)[-1] == '.json': save_path = save_path else: - logger.warning( - f"{save_path} is invalid input path, only files in json format are supported." - ) + raise Exception(f"{save_path} is invalid input path, only files in json format are supported.") + if os.path.exists(save_path): - logger.warning( - f"The file {save_path} will be overwritten." - ) + logger.warning(f"The file {save_path} will be overwritten.") with open(save_path, 'w', encoding='utf-8') as f: json.dump(result, f)