#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # #Licensed under the Apache License, Version 2.0 (the "License"); #you may not use this file except in compliance with the License. #You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # #Unless required by applicable law or agreed to in writing, software #distributed under the License is distributed on an "AS IS" BASIS, #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. from PIL import Image import os import os.path import sys IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff') def has_file_allowed_extension(filename, extensions): """Checks if a file is an allowed extension. Args: filename (string): path to a file extensions (tuple of strings): extensions to consider (lowercase) Returns: bool: True if the filename ends with one of given extensions """ return filename.lower().endswith(extensions) def is_image_file(filename): """Checks if a file is an allowed image extension. Args: filename (string): path to a file Returns: bool: True if the filename ends with a known image extension """ return has_file_allowed_extension(filename, IMG_EXTENSIONS) def make_dataset(dir, class_to_idx, extensions): images = [] dir = os.path.expanduser(dir) for target in sorted(class_to_idx.keys()): d = os.path.join(dir, target) if not os.path.isdir(d): continue for root, _, fnames in sorted(os.walk(d)): for fname in sorted(fnames): if has_file_allowed_extension(fname, extensions): path = os.path.join(root, fname) item = (path, class_to_idx[target]) images.append(item) return images class DatasetFolder(object): """A generic data loader where the samples are arranged in this way: :: root/class_x/xxx.ext root/class_x/xxy.ext root/class_x/xxz.ext root/class_y/123.ext root/class_y/nsdf3.ext root/class_y/asd932_.ext Args: root (string): Root directory path. loader (callable): A function to load a sample given its path. extensions (tuple[string]): A list of allowed extensions. transform (callable, optional): A function/transform that takes in a sample and returns a transformed version. E.g, ``transforms.RandomCrop`` for images. target_transform (callable, optional): A function/transform that takes in the target and transforms it. Attributes: classes (list): List of the class names. class_to_idx (dict): Dict with items (class_name, class_index). samples (list): List of (sample path, class_index) tuples targets (list): The class_index value for each image in the dataset """ def __init__(self, root, loader, extensions, transform=None, target_transform=None): self.root = root self.transform = transform self.target_transform = target_transform classes, class_to_idx = self._find_classes(self.root) samples = make_dataset(self.root, class_to_idx, extensions) if len(samples) == 0: raise (RuntimeError( "Found 0 files in subfolders of: " + self.root + "\n" "Supported extensions are: " + ",".join(extensions))) self.loader = loader self.extensions = extensions self.classes = classes self.class_to_idx = class_to_idx self.samples = samples self.targets = [s[1] for s in samples] def _find_classes(self, dir): """ Finds the class folders in a dataset. Args: dir (string): Root directory path. Returns: tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. Ensures: No class is a subdirectory of another. """ if sys.version_info >= (3, 5): # Faster and available in Python 3.5 and above classes = [d.name for d in os.scandir(dir) if d.is_dir()] else: classes = [ d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) ] classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} return classes, class_to_idx def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (sample, target) where target is class_index of the target class. """ path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target def __len__(self): return len(self.samples) def pil_loader(path): # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: img = Image.open(f) return img.convert('RGB') def default_loader(path): return pil_loader(path) class ImageFolder(DatasetFolder): """A generic data loader where the images are arranged in this way: :: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. Attributes: classes (list): List of the class names. class_to_idx (dict): Dict with items (class_name, class_index). imgs (list): List of (image path, class_index) tuples """ def __init__(self, root, transform=None, target_transform=None, loader=default_loader): super(ImageFolder, self).__init__( root, loader, IMG_EXTENSIONS, transform=transform, target_transform=target_transform) self.imgs = self.samples