# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import random
from abc import ABC, abstractmethod

import cv2
import numpy as np
import paddle.vision.transforms as transforms
from paddle.io import Dataset
from PIL import Image

from .transforms import transforms as T


class BaseDataset(Dataset, ABC):
    """Abstract base class (ABC) for datasets.

    Subclasses must implement ``__len__`` and ``__getitem__``.
    """

    def __init__(self, cfg):
        """Initialize the class; save the options in the class.

        Args:
            cfg -- experiment config. It is read by attribute access
                (``cfg.dataroot``), so it must be an attribute-style config
                object rather than a plain ``dict``.
        """
        self.cfg = cfg
        self.root = cfg.dataroot

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Args:
            index -- a random integer for data indexing.

        Returns:
            A dictionary of data with their names. It usually contains the
            data itself and its metadata information.
        """
        pass


def get_params(cfg, size):
    """Sample crop-position and flip parameters for one image.

    The result can be passed as ``params`` to ``get_transform`` so that the
    crop position and flip decision are fixed instead of re-sampled.

    Args:
        cfg -- config providing ``preprocess``, ``load_size`` and
            ``crop_size``.
        size -- ``(w, h)`` of the input image.

    Returns:
        dict with ``'crop_pos'``: an ``(x, y)`` pair with ``x`` in
        ``[0, max(0, new_w - crop_size)]`` and ``y`` in
        ``[0, max(0, new_h - crop_size)]``, and ``'flip'``: a bool.
    """
    w, h = size
    new_w, new_h = w, h
    if cfg.preprocess == 'resize_and_crop':
        new_w = new_h = cfg.load_size
    elif cfg.preprocess == 'scale_width_and_crop':
        # Width becomes load_size; height is scaled by the same factor.
        new_w = cfg.load_size
        new_h = cfg.load_size * h // w

    # Builtin max on scalars: no need for np.maximum, and the bound stays a
    # plain Python number for random.randint (upper bound is inclusive).
    x = random.randint(0, max(0, new_w - cfg.crop_size))
    y = random.randint(0, max(0, new_h - cfg.crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(cfg,
                  params=None,
                  grayscale=False,
                  method=cv2.INTER_CUBIC,
                  convert=True):
    """Build the preprocessing pipeline selected by ``cfg.preprocess``.

    Args:
        cfg -- config providing ``preprocess``, ``load_size``, ``crop_size``
            and ``no_flip``.
        params -- optional dict produced by ``get_params``. When given, the
            crop uses ``params['crop_pos']`` and flipping follows
            ``params['flip']``; when ``None``, random crop/flip transforms
            are used instead.
        grayscale -- not supported yet; only prints a message.
        method -- interpolation flag passed to ``transforms.Resize``.
        convert -- if True, append ``Permute(to_rgb=True)`` followed by a
            ``Normalize`` with mean 0 and std 255.

    Returns:
        ``transforms.Compose`` wrapping the selected transforms, in order:
        resize, crop, flip, convert.
    """
    transform_list = []
    if grayscale:
        print('grayscale not support for now!!!')

    if 'resize' in cfg.preprocess:
        osize = (cfg.load_size, cfg.load_size)
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in cfg.preprocess:
        # NOTE(review): get_params() computes crop bounds for the scaled size
        # in 'scale_width_and_crop' mode, but no scaling transform is added
        # here, so a pre-sampled crop_pos may not match the actual image.
        print('scale_width not support for now!!!')

    if 'crop' in cfg.preprocess:
        if params is None:
            transform_list.append(T.RandomCrop(cfg.crop_size))
        else:
            transform_list.append(T.Crop(params['crop_pos'], cfg.crop_size))

    if cfg.preprocess == 'none':
        print('preprocess not support for now!!!')

    if not cfg.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            # Probability 1.0: always flip, honoring the pre-sampled decision.
            transform_list.append(transforms.RandomHorizontalFlip(1.0))

    if convert:
        transform_list.append(transforms.Permute(to_rgb=True))
        # (x - 0) / 255 per channel; presumably maps 0-255 pixel values into
        # [0, 1]. For a [-1, 1] range use
        # transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)).
        transform_list.append(
            transforms.Normalize((0., 0., 0.), (255., 255., 255.)))

    return transforms.Compose(transform_list)