# base_dataset.py
# code was heavily based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
import random
import numpy as np

from paddle.io import Dataset
from PIL import Image
import cv2

import paddle.incubate.hapi.vision.transforms as transforms
from .transforms import transforms as T
from abc import ABC, abstractmethod


class BaseDataset(Dataset, ABC):
    """Abstract base class (ABC) for all datasets.

    Subclasses must implement ``__len__`` and ``__getitem__``.
    """

    def __init__(self, cfg):
        """Save the experiment configuration on the instance.

        Args:
            cfg (dict): stores all the experiment flags; must expose a
                ``dataroot`` attribute pointing at the dataset directory.
        """
        self.cfg = cfg
        self.root = cfg.dataroot

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return one data point and its metadata.

        Args:
            index (int): a random integer for data indexing.

        Returns:
            dict: the data itself plus its metadata information, keyed
            by name.
        """
        pass


def get_params(cfg, size):
    """Sample a random crop position and horizontal-flip flag.

    Args:
        cfg: config object with ``preprocess`` (str), ``load_size`` and
            ``crop_size`` attributes.
        size (tuple): (width, height) of the source image.

    Returns:
        dict: {'crop_pos': (x, y), 'flip': bool}
    """
    width, height = size
    target_w, target_h = width, height

    if cfg.preprocess == 'resize_and_crop':
        # Both sides are resized to load_size before cropping.
        target_w = target_h = cfg.load_size
    elif cfg.preprocess == 'scale_width_and_crop':
        # Width is fixed to load_size; height keeps the aspect ratio.
        target_w = cfg.load_size
        target_h = cfg.load_size * height // width

    max_x = np.maximum(0, target_w - cfg.crop_size)
    max_y = np.maximum(0, target_h - cfg.crop_size)
    crop_pos = (random.randint(0, max_x), random.randint(0, max_y))

    return {'crop_pos': crop_pos, 'flip': random.random() > 0.5}



def get_transform(cfg,
                  params=None,
                  grayscale=False,
                  method=cv2.INTER_CUBIC,
                  convert=True):
    """Build the image preprocessing pipeline described by ``cfg``.

    Args:
        cfg: config with ``preprocess`` (str), ``load_size``,
            ``crop_size`` and ``no_flip`` attributes.
        params (dict|None): output of ``get_params``; when given, the
            crop position and flip decision are fixed instead of being
            randomly drawn per call.
        grayscale (bool): unsupported for now; only a warning is printed.
        method: cv2 interpolation flag used for resizing.
        convert (bool): if True, append Permute + Normalize so images
            become CHW tensors scaled from [0, 255] to [-1, 1].

    Returns:
        transforms.Compose: the composed transform pipeline.
    """
    transform_list = []
    if grayscale:
        # NOTE: grayscale conversion is not implemented yet.
        print('grayscale not support for now!!!')

    if 'resize' in cfg.preprocess:
        osize = (cfg.load_size, cfg.load_size)
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in cfg.preprocess:
        # NOTE: scale_width resizing is not implemented yet.
        print('scale_width not support for now!!!')

    if 'crop' in cfg.preprocess:
        if params is None:
            # Random crop position chosen independently at call time.
            transform_list.append(T.RandomCrop(cfg.crop_size))
        else:
            # Deterministic crop at the position sampled by get_params().
            transform_list.append(T.Crop(params['crop_pos'], cfg.crop_size))

    if cfg.preprocess == 'none':
        # NOTE: the identity ("no preprocessing") branch is not
        # implemented yet.
        print('preprocess not support for now!!!')

    if not cfg.no_flip:
        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            # Force the flip decided by get_params (probability 1.0).
            transform_list.append(transforms.RandomHorizontalFlip(1.0))

    if convert:
        # HWC -> CHW, then normalize [0, 255] to [-1, 1] per channel.
        transform_list += [transforms.Permute(to_rgb=True)]
        transform_list += [
            transforms.Normalize((127.5, 127.5, 127.5),
                                 (127.5, 127.5, 127.5))
        ]
    return transforms.Compose(transform_list)