import os
import random
import traceback
from itertools import chain
from pathlib import Path

import numpy as np
import paddle
from PIL import Image
from paddle.io import Dataset, WeightedRandomSampler

from .base_dataset import BaseDataset
from .builder import DATASETS


def listdir(dname):
    """Recursively collect the image files found under ``dname``.

    Args:
        dname (str): Directory to search.

    Returns:
        list[pathlib.Path]: Paths whose suffix is one of ``png``, ``jpg``,
            ``jpeg`` or ``JPG``, grouped by extension in that order.
    """
    matches = chain.from_iterable(
        Path(dname).rglob('*.' + ext) for ext in ['png', 'jpg', 'jpeg', 'JPG'])
    # On case-insensitive filesystems '*.jpg' and '*.JPG' can match the same
    # file; drop such repeats while keeping the first-seen order.
    return list(dict.fromkeys(matches))


def _list_domains(root):
    """Return the sorted names of the sub-directories of ``root``.

    Plain files sitting next to the domain folders (e.g. ``.DS_Store``) are
    skipped so that they do not consume a class index.
    """
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name)))


def _make_balanced_sampler(labels):
    """Build a sampler that draws every class with equal total probability.

    Args:
        labels (list[int]): Non-negative class index of every sample.

    Returns:
        WeightedRandomSampler: Sampler over ``len(labels)`` draws in which
            each sample is weighted by the inverse of its class frequency.
    """
    class_counts = np.bincount(labels)
    # Clamp to 1 so that a class index without any sample does not trigger a
    # divide-by-zero; weights of populated classes are unaffected.
    class_weights = 1. / np.maximum(class_counts, 1)
    weights = class_weights[labels]
    return WeightedRandomSampler(weights, len(weights))


class _DomainFolder(Dataset):
    """Shared behaviour of the folder-per-domain datasets defined below.

    Sub-classes implement ``_make_dataset(root)`` returning
    ``(samples, targets)``.
    """

    def __init__(self, root, use_sampler=False):
        self.samples, self.targets = self._make_dataset(root)
        self.use_sampler = use_sampler
        if self.use_sampler:
            self.sampler = _make_balanced_sampler(self.targets)
            self.iter_sampler = iter(self.sampler)

    def _resolve_index(self, i):
        """Return the sample index to serve for the requested position ``i``.

        With ``use_sampler`` enabled ``i`` is ignored and the next index of
        the class-balanced sampler is used instead.
        """
        if not self.use_sampler:
            return i
        try:
            return next(self.iter_sampler)
        except StopIteration:
            # The sampler ran out of draws: start a fresh pass over it.
            self.iter_sampler = iter(self.sampler)
            return next(self.iter_sampler)

    def __len__(self):
        return len(self.targets)


class ImageFolder(_DomainFolder):
    """Image paths labelled by the index of their (sorted) domain folder.

    Args:
        root (str): Directory containing one sub-directory per domain.
        use_sampler (bool): If true, ``__getitem__`` ignores its argument and
            returns class-balanced random samples.
    """

    def _make_dataset(self, root):
        """Scan ``root`` and return ``(file_paths, labels)``."""
        fnames, labels = [], []
        for idx, domain in enumerate(_list_domains(root)):
            cls_fnames = listdir(os.path.join(root, domain))
            fnames += cls_fnames
            labels += [idx] * len(cls_fnames)
        return fnames, labels

    def __getitem__(self, i):
        """Return ``(path, label)`` for one sample."""
        index = self._resolve_index(i)
        return self.samples[index], self.targets[index]


class ReferenceDataset(_DomainFolder):
    """Pairs of image paths taken from the same domain folder.

    Every image is paired with an image drawn from a random shuffle of its own
    domain, so both members of a pair share the same label.

    Args:
        root (str): Directory containing one sub-directory per domain.
        use_sampler (bool): If true, ``__getitem__`` ignores its argument and
            returns class-balanced random samples.
    """

    def _make_dataset(self, root):
        """Scan ``root`` and return ``([(path, path2), ...], labels)``."""
        fnames, fnames2, labels = [], [], []
        for idx, domain in enumerate(_list_domains(root)):
            cls_fnames = listdir(os.path.join(root, domain))
            fnames += cls_fnames
            # A shuffled copy of the same domain supplies the second image.
            fnames2 += random.sample(cls_fnames, len(cls_fnames))
            labels += [idx] * len(cls_fnames)
        return list(zip(fnames, fnames2)), labels

    def __getitem__(self, i):
        """Return ``(path, path2, label)`` for one sample."""
        index = self._resolve_index(i)
        fname, fname2 = self.samples[index]
        return fname, fname2, self.targets[index]


@DATASETS.register()
class StarGANv2Dataset(BaseDataset):
    """Dataset producing source/reference image paths with domain labels.

    In training mode every item combines a class-balanced random source image
    with a class-balanced random pair of reference images. In test mode the
    item ``idx`` combines the ``idx``-th source image with the ``idx``-th
    reference image.
    """

    def __init__(self, dataroot, is_train, preprocess, test_count=0):
        """Initialize single dataset class.

        Args:
            dataroot (str): Directory of dataset. In test mode it may contain
                ``src`` and ``ref`` sub-directories, which are then used for
                the source and reference images respectively.
            is_train (bool): Whether to build the training variant.
            preprocess (list[dict]): A sequence of data preprocess config.
            test_count (int): Upper bound on the dataset length in test mode.
                Note that the default of 0 yields an empty test dataset.
        """
        super(StarGANv2Dataset, self).__init__(preprocess)

        self.dataroot = dataroot
        self.is_train = is_train
        if self.is_train:
            self.src_loader = ImageFolder(self.dataroot, use_sampler=True)
            self.ref_loader = ReferenceDataset(self.dataroot, use_sampler=True)
            self.counts = len(self.src_loader)
        else:
            files = os.listdir(self.dataroot)
            if 'src' in files and 'ref' in files:
                self.src_loader = ImageFolder(
                    os.path.join(self.dataroot, 'src'))
                self.ref_loader = ImageFolder(
                    os.path.join(self.dataroot, 'ref'))
            else:
                self.src_loader = ImageFolder(self.dataroot)
                self.ref_loader = ImageFolder(self.dataroot)
            # Length is capped by test_count and by the shorter of the two
            # folders, because test items are paired by index.
            self.counts = min(test_count, len(self.src_loader))
            self.counts = min(self.counts, len(self.ref_loader))

    def _fetch_inputs(self):
        """Return the next ``(path, label)`` from the source loader."""
        try:
            x, y = next(self.iter_src)
        except (AttributeError, StopIteration):
            # AttributeError: iter_src is created lazily on first use.
            # StopIteration: the loader is exhausted, so restart it.
            self.iter_src = iter(self.src_loader)
            x, y = next(self.iter_src)
        return x, y

    def _fetch_refs(self):
        """Return the next ``(path, path2, label)`` from the reference loader."""
        try:
            x, x2, y = next(self.iter_ref)
        except (AttributeError, StopIteration):
            # Same lazy-create / restart handling as in _fetch_inputs.
            self.iter_ref = iter(self.ref_loader)
            x, x2, y = next(self.iter_ref)
        return x, x2, y

    def __getitem__(self, idx):
        """Build one item dict and run it through ``self.preprocess`` if set.

        Training items hold ``src_path``, ``src_cls``, ``ref_path``,
        ``ref2_path`` and ``ref_cls``; ``idx`` is not used to select them.
        Test items hold ``src_path``, ``src_cls``, ``ref_path`` and ``ref_cls``.
        """
        if self.is_train:
            x, y = self._fetch_inputs()
            x_ref, x_ref2, y_ref = self._fetch_refs()
            datas = {
                'src_path': x,
                'src_cls': y,
                'ref_path': x_ref,
                'ref2_path': x_ref2,
                'ref_cls': y_ref,
            }
        else:
            x, y = self.src_loader[idx]
            x_ref, y_ref = self.ref_loader[idx]
            datas = {
                'src_path': x,
                'src_cls': y,
                'ref_path': x_ref,
                'ref_cls': y_ref,
            }

        if hasattr(self, 'preprocess') and self.preprocess:
            datas = self.preprocess(datas)

        return datas

    def __len__(self):
        return self.counts

    def prepare_data_infos(self, dataroot):
        # Intentionally a no-op: samples are discovered by the folder loaders
        # created in __init__ (presumably overrides a BaseDataset hook).
        pass