diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py index 1166ab3851b0c594469f135c1c4b3c7bc921ac5f..897394d99008df1b62e7a5371839c1f659069262 100644 --- a/ppcls/data/dataloader/imagenet_dataset.py +++ b/ppcls/data/dataloader/imagenet_dataset.py @@ -21,6 +21,15 @@ from .common_dataset import CommonDataset class ImageNetDataset(CommonDataset): + def __init__( + self, + image_root, + cls_label_path, + transform_ops=None, + delimiter=None): + super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops) + self.delimiter = delimiter if delimiter is not None else " " + def _load_anno(self, seed=None): assert os.path.exists(self._cls_path) assert os.path.exists(self._img_root) @@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset): if seed is not None: np.random.RandomState(seed).shuffle(lines) for l in lines: - l = l.strip().split(" ") + l = l.strip().split(self.delimiter) self.images.append(os.path.join(self._img_root, l[0])) self.labels.append(np.int64(l[1])) assert os.path.exists(self.images[-1]) diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py index 9c1371bfd11f4c93f06c82436e88e0ff20a57b35..76987b5d558c75cc3412958ff899992e0caa712f 100644 --- a/ppcls/data/postprocess/topk.py +++ b/ppcls/data/postprocess/topk.py @@ -19,10 +19,11 @@ import paddle.nn.functional as F class Topk(object): - def __init__(self, topk=1, class_id_map_file=None): + def __init__(self, topk=1, class_id_map_file=None, delimiter=None): assert isinstance(topk, (int, )) self.class_id_map = self.parse_class_id_map(class_id_map_file) self.topk = topk + self.delimiter = delimiter if delimiter is not None else " " def parse_class_id_map(self, class_id_map_file): if class_id_map_file is None: @@ -38,7 +39,7 @@ class Topk(object): with open(class_id_map_file, "r") as fin: lines = fin.readlines() for line in lines: - partition = line.split("\n")[0].partition(" ") + partition = line.split("\n")[0].partition(self.delimiter) class_id_map[int(partition[0])] = str(partition[-1]) except Exception as ex: print(ex)