From d495dfa9d2620f1b73e36935b7487585e1485c25 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Tue, 19 Apr 2022 16:36:46 +0800 Subject: [PATCH] update delimiter --- ppcls/data/dataloader/imagenet_dataset.py | 11 ++++++++++- ppcls/data/postprocess/topk.py | 5 +++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py index 1166ab38..897394d9 100644 --- a/ppcls/data/dataloader/imagenet_dataset.py +++ b/ppcls/data/dataloader/imagenet_dataset.py @@ -21,6 +21,15 @@ from .common_dataset import CommonDataset class ImageNetDataset(CommonDataset): + def __init__( + self, + image_root, + cls_label_path, + transform_ops=None, + delimiter=None): + super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops) + self.delimiter = delimiter if delimiter is not None else " " + def _load_anno(self, seed=None): assert os.path.exists(self._cls_path) assert os.path.exists(self._img_root) @@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset): if seed is not None: np.random.RandomState(seed).shuffle(lines) for l in lines: - l = l.strip().split(" ") + l = l.strip().split(self.delimiter) self.images.append(os.path.join(self._img_root, l[0])) self.labels.append(np.int64(l[1])) assert os.path.exists(self.images[-1]) diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py index 9c1371bf..76987b5d 100644 --- a/ppcls/data/postprocess/topk.py +++ b/ppcls/data/postprocess/topk.py @@ -19,10 +19,11 @@ import paddle.nn.functional as F class Topk(object): - def __init__(self, topk=1, class_id_map_file=None): + def __init__(self, topk=1, class_id_map_file=None, delimiter=None): assert isinstance(topk, (int, )) self.class_id_map = self.parse_class_id_map(class_id_map_file) self.topk = topk + self.delimiter = delimiter if delimiter is not None else " " def parse_class_id_map(self, class_id_map_file): if class_id_map_file is None: @@ -38,7 +39,7 @@ class Topk(object): with open(class_id_map_file, "r") as fin: lines = fin.readlines() for line in lines: - partition = line.split("\n")[0].partition(" ") + partition = line.split("\n")[0].partition(self.delimiter) class_id_map[int(partition[0])] = str(partition[-1]) except Exception as ex: print(ex) -- GitLab