# paired_dataset.py
import cv2
import paddle
import os.path
from .base_dataset import BaseDataset, get_params, get_transform
from .image_folder import make_dataset

from .builder import DATASETS


@DATASETS.register()
class PairedDataset(BaseDataset):
    """A dataset of paired images stored side by side in a single file.

    Every image found under ``<cfg.dataroot>/<cfg.phase>`` is split down the
    middle: the left half is returned as ``A`` and the right half as ``B``.
    Both halves are transformed with parameters drawn once per sample.
    """

    def __init__(self, cfg):
        """Initialize this dataset class.

        Args:
            cfg: experiment config with attribute access. Read here:
                ``dataroot``, ``phase``, ``max_dataset_size``, ``direction``,
                ``input_nc``, ``output_nc``, ``transform.load_size`` and
                ``transform.crop_size``.

        Raises:
            ValueError: if ``transform.crop_size`` is larger than
                ``transform.load_size``.
        """
        BaseDataset.__init__(self, cfg)
        # Directory holding the side-by-side images, e.g. <dataroot>/train.
        self.dir_AB = os.path.join(cfg.dataroot, cfg.phase)
        self.AB_paths = sorted(make_dataset(self.dir_AB, cfg.max_dataset_size))

        # Explicit check rather than ``assert`` so the validation still runs
        # under ``python -O``. The crop must fit inside the loaded image.
        load_size = self.cfg.transform.load_size
        crop_size = self.cfg.transform.crop_size
        if load_size < crop_size:
            raise ValueError(
                'crop_size ({}) must not be larger than load_size ({})'.format(
                    crop_size, load_size))

        # For direction 'BtoA' the channel counts of A and B are swapped
        # relative to cfg.input_nc / cfg.output_nc.
        b_to_a = self.cfg.direction == 'BtoA'
        self.input_nc = self.cfg.output_nc if b_to_a else self.cfg.input_nc
        self.output_nc = self.cfg.input_nc if b_to_a else self.cfg.output_nc

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Args:
            index (int): index into ``self.AB_paths``.

        Returns:
            dict with keys:
                A: left half of the image after the A transform.
                B: right half of the image after the B transform.
                A_paths (str): path of the source image file.
                B_paths (str): same path as ``A_paths``.

        Raises:
            IOError: if the image file cannot be read.
        """
        AB_path = self.AB_paths[index]
        AB = cv2.imread(AB_path)
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error would surface later as an opaque
        # AttributeError on ``AB.shape``.
        if AB is None:
            raise IOError('Failed to read image: {}'.format(AB_path))
        # NOTE(review): cv2.imread returns channels in BGR order; confirm that
        # get_transform / downstream consumers expect BGR rather than RGB.

        # Split the side-by-side image into left (A) and right (B) halves.
        # For an odd width, B is one column wider than A.
        h, w = AB.shape[:2]
        w2 = w // 2
        A = AB[:, :w2, :]
        B = AB[:, w2:, :]

        # Draw the transform parameters once and share them between A and B
        # so both halves are transformed consistently.
        transform_params = get_params(self.cfg.transform, (w2, h))

        A_transform = get_transform(self.cfg.transform, transform_params,
                                    grayscale=(self.input_nc == 1))
        B_transform = get_transform(self.cfg.transform, transform_params,
                                    grayscale=(self.output_nc == 1))

        A = A_transform(A)
        B = B_transform(B)

        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)

    def get_path_by_indexs(self, indexs):
        """Return the image paths for a batch of dataset indices.

        Args:
            indexs: an iterable of integer indices, or a tensor-like object
                exposing ``.numpy()`` (e.g. a paddle tensor). Tensor-like input
                is converted to a numpy array and flattened first.

        Returns:
            list of str: ``self.AB_paths[i]`` for each index, in order.
        """
        # Duck-type on ``.numpy()`` instead of ``isinstance(..., paddle.Variable)``.
        # Dygraph tensors may not be ``paddle.Variable`` instances (or the
        # name may be unavailable) depending on the Paddle release.
        if hasattr(indexs, 'numpy'):
            # Flatten so an (N, 1) index tensor still yields scalar indices.
            indexs = indexs.numpy().reshape(-1)
        return [self.AB_paths[int(index)] for index in indexs]