# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Directory-per-class image dataset returning ``[sample, class_index]``."""

import paddle
from PIL import Image
import os
import os.path
from typing import (Any, Callable, cast, Dict, Iterable, List, Optional,
                    Tuple, Union)

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def has_file_allowed_extension(
        filename: str, extensions: Union[str, Iterable[str]]) -> bool:
    """Return True when *filename* ends with one of *extensions*.

    The file name is lower-cased before the comparison; the extensions are
    used as given, so they should be supplied in lower case.

    Args:
        filename: Path or file name to test.
        extensions: A single suffix, or any iterable of suffixes
            (tuple, list, set, ...).

    Returns:
        Whether the lower-cased file name ends with any of the suffixes.
    """
    # ``str.endswith`` only accepts a str or a tuple, so normalise any other
    # iterable instead of letting a list raise ``TypeError``.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return filename.lower().endswith(suffixes)


def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]]=None,
        is_valid_file: Optional[Callable[[str], bool]]=None, ) -> List[Tuple[
            str, int]]:
    """Collect ``(path, class_index)`` pairs for every accepted file.

    For each class name in *class_to_idx* (in sorted order) the sub-directory
    ``directory/<class name>`` is walked recursively, following symlinks.
    Class names without a matching directory are skipped silently.

    Args:
        directory: Root directory; a leading ``~`` is expanded.
        class_to_idx: Mapping from class (sub-directory) name to its index.
        extensions: Accepted file suffixes. Mutually exclusive with
            *is_valid_file*.
        is_valid_file: Predicate receiving the full file path. Mutually
            exclusive with *extensions*.

    Returns:
        List of ``(file path, class index)`` tuples, ordered by class name,
        then by walked directory, then by file name.

    Raises:
        ValueError: If *extensions* and *is_valid_file* are both given or
            both omitted.
    """
    directory = os.path.expanduser(directory)

    # Exactly one of the two filters must be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(
                x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        # Sorting both the walk results and the file names keeps the sample
        # order deterministic regardless of filesystem listing order.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
    return instances


class ImageFolder(paddle.vision.datasets.ImageFolder):
    """Image dataset where every sub-directory of ``root`` is one class.

    The parent class is initialised first; this class then replaces
    ``self.samples`` with ``(path, class_index)`` pairs and relies on the
    parent for ``self.root``, ``self.loader`` and ``self.transform``, which
    are read here but not assigned.

    Attributes set by this class:
        target_transform: Optional callable applied to the class index.
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its index in ``classes``.
        samples: List of ``(path, class_index)`` tuples.
    """

    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=None,
                 is_valid_file=None):
        """Scan *root* and build the labelled sample list.

        Args:
            root: Directory containing one sub-directory per class.
            transform: Optional callable applied to each loaded sample.
            target_transform: Optional callable applied to each class index.
            loader: Callable loading a sample from a path; forwarded to the
                parent class.
            is_valid_file: Optional predicate on the file path. When omitted,
                files are filtered by ``IMG_EXTENSIONS`` instead.
        """
        super().__init__(
            root,
            loader=loader,
            transform=transform,
            is_valid_file=is_valid_file)
        self.target_transform = target_transform

        classes, class_to_idx = self._find_classes(root)
        # Keep the label mapping so callers can translate indices back to
        # class names.
        self.classes = classes
        self.class_to_idx = class_to_idx

        # make_dataset requires exactly one filter: fall back to the default
        # extensions only when no predicate was supplied.
        extensions = IMG_EXTENSIONS if is_valid_file is None else None
        self.samples = self.make_dataset(self.root, class_to_idx, extensions,
                                         is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """Return the sorted class names under *dir* and their indices.

        Args:
            dir: Directory whose immediate sub-directories are the classes;
                a leading ``~`` is expanded, matching ``make_dataset``.

        Returns:
            ``(classes, class_to_idx)`` where ``classes`` is the sorted list
            of sub-directory names and ``class_to_idx`` maps each name to its
            position in that list.
        """
        # Use the context manager so the directory handle is released
        # promptly even if iteration stops early on an error.
        with os.scandir(os.path.expanduser(dir)) as entries:
            classes = sorted(entry.name for entry in entries
                             if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    @staticmethod
    def make_dataset(
            directory: str,
            class_to_idx: Dict[str, int],
            extensions: Optional[Tuple[str, ...]]=None,
            is_valid_file: Optional[Callable[[str], bool]]=None, ) -> List[
                Tuple[str, int]]:
        """Delegate to the module-level ``make_dataset`` function.

        Exposed as a static method so subclasses can override how the sample
        list is built. Arguments, return value and raised errors are those of
        the module-level function.
        """
        return make_dataset(
            directory,
            class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file)

    def __getitem__(self, index):
        """Load the sample at *index* and return ``[sample, target]``.

        The sample is loaded with ``self.loader`` and passed through
        ``self.transform`` when one is set; the class index is passed through
        ``self.target_transform`` when one is set.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return [sample, target]